From a56711796fb1f67904436f1f3b38772d79dde666 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 20 Nov 2024 21:39:01 +0700 Subject: [PATCH 001/100] cmd/cli: set DNS using Windows API --- cmd/cli/os_windows.go | 55 +++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index b9412b6..1a22b0f 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -4,12 +4,13 @@ import ( "errors" "fmt" "net" + "net/netip" "os" "slices" - "strconv" "strings" "sync" + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" @@ -30,14 +31,6 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e return setDNS(iface, nameservers) } -func setDnsPowershellCmd(iface *net.Interface, nameservers []string) string { - nss := make([]string, 0, len(nameservers)) - for _, ns := range nameservers { - nss = append(nss, strconv.Quote(ns)) - } - return fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ServerAddresses (%s)", iface.Index, strings.Join(nss, ",")) -} - // setDNS sets the dns server for the provided network interface func setDNS(iface *net.Interface, nameservers []string) error { if len(nameservers) == 0 { @@ -65,9 +58,36 @@ func setDNS(iface *net.Interface, nameservers []string) error { } } }) - out, err := powershell(setDnsPowershellCmd(iface, nameservers)) + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return fmt.Errorf("%w: %s", err, string(out)) + return fmt.Errorf("setDNS: %w", err) + } + var ( + serversV4 []netip.Addr + serversV6 []netip.Addr + ) + for _, ns := range nameservers { + if addr, err := netip.ParseAddr(ns); err == nil { + if addr.Is4() { + serversV4 = append(serversV4, addr) + } else { + serversV6 = append(serversV6, addr) + } + } + } + + if len(serversV4) == 0 && len(serversV6) == 0 { + return errors.New("invalid DNS nameservers") + } + if len(serversV4) > 0 { + if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil { + return fmt.Errorf("could not set DNS ipv4: %w", err) + } + } + if len(serversV6) > 0 { + if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil { + return fmt.Errorf("could not set DNS ipv6: %w", err) + } } return nil } @@ -96,11 +116,16 @@ func resetDNS(iface *net.Interface) error { } }) - // Restoring DHCP settings. - cmd := fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ResetServerAddresses", iface.Index) - out, err := powershell(cmd) + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return fmt.Errorf("%w: %s", err, string(out)) + return fmt.Errorf("resetDNS: %w", err) + } + // Restoring DHCP settings. + if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil { + return fmt.Errorf("could not reset DNS ipv4: %w", err) + } + if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil { + return fmt.Errorf("could not reset DNS ipv6: %w", err) } // If there's static DNS saved, restoring it. From 71e327653a65f239f124e75d9f3b94374a6350b4 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 21 Nov 2024 16:52:59 +0700 Subject: [PATCH 002/100] cmd/cli: check local DNS using Windows API --- cmd/cli/cli.go | 19 ++----------------- cmd/cli/os_windows.go | 12 ++++++++++-- cmd/cli/service_others.go | 3 +++ cmd/cli/service_windows.go | 22 ++++++++++++++++++++++ cmd/cli/service_windows_test.go | 25 +++++++++++++++++++++++++ 5 files changed, 62 insertions(+), 19 deletions(-) create mode 100644 cmd/cli/service_windows_test.go diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 502014e..b0ae022 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -709,7 +709,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, return nil }) // Windows forwarders file. - if windowsHasLocalDnsServerRunning() { + if hasLocalDnsServerRunning() { files = append(files, absHomeDir(windowsForwardersFilename)) } // Binary itself. @@ -2107,7 +2107,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata cdMode := cdUID != "" nextdnsMode := nextdns != "" // For Windows server with local Dns server running, we can only try on random local IP. - hasLocalDnsServer := windowsHasLocalDnsServerRunning() + hasLocalDnsServer := hasLocalDnsServerRunning() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { @@ -2614,21 +2614,6 @@ func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.M return c.ExchangeContext(ctx, msg, addr) } -// powershell runs the given powershell command. -func powershell(cmd string) ([]byte, error) { - out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() - return bytes.TrimSpace(out), err -} - -// windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. -func windowsHasLocalDnsServerRunning() bool { - if runtime.GOOS == "windows" { - _, err := powershell("Get-Process -Name DNS") - return err == nil - } - return false -} - // absHomeDir returns the absolute path to given filename using home directory as root dir. func absHomeDir(filename string) string { if homedir != "" { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 1a22b0f..aa44418 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -1,11 +1,13 @@ package cli import ( + "bytes" "errors" "fmt" "net" "net/netip" "os" + "os/exec" "slices" "strings" "sync" @@ -39,7 +41,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { setDNSOnce.Do(func() { // If there's a Dns server running, that means we are on AD with Dns feature enabled. // Configuring the Dns server to forward queries to ctrld instead. - if windowsHasLocalDnsServerRunning() { + if hasLocalDnsServerRunning() { file := absHomeDir(windowsForwardersFilename) oldForwardersContent, _ := os.ReadFile(file) hasLocalIPv6Listener := needLocalIPv6Listener() @@ -101,7 +103,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { func resetDNS(iface *net.Interface) error { resetDNSOnce.Do(func() { // See corresponding comment in setDNS. - if windowsHasLocalDnsServerRunning() { + if hasLocalDnsServerRunning() { file := absHomeDir(windowsForwardersFilename) content, err := os.ReadFile(file) if err != nil { @@ -241,3 +243,9 @@ func removeDnsServerForwarders(nameservers []string) error { } return nil } + +// powershell runs the given powershell command. +func powershell(cmd string) ([]byte, error) { + out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput() + return bytes.TrimSpace(out), err +} diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index f4d73e5..2303e30 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -13,3 +13,6 @@ func hasElevatedPrivilege() (bool, error) { func openLogFile(path string, flags int) (*os.File, error) { return os.OpenFile(path, flags, os.FileMode(0o600)) } + +// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. +func hasLocalDnsServerRunning() bool { return false } diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index d4e2449..af4f317 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -2,7 +2,9 @@ package cli import ( "os" + "strings" "syscall" + "unsafe" "golang.org/x/sys/windows" ) @@ -79,3 +81,23 @@ func openLogFile(path string, mode int) (*os.File, error) { return os.NewFile(uintptr(handle), path), nil } + +const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{})) + +// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. +func hasLocalDnsServerRunning() bool { + h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0) + if e != nil { + return false + } + p := windows.ProcessEntry32{Size: processEntrySize} + for { + e := windows.Process32Next(h, &p) + if e != nil { + return false + } + if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" { + return true + } + } +} diff --git a/cmd/cli/service_windows_test.go b/cmd/cli/service_windows_test.go new file mode 100644 index 0000000..67c2725 --- /dev/null +++ b/cmd/cli/service_windows_test.go @@ -0,0 +1,25 @@ +package cli + +import ( + "testing" + "time" +) + +func Test_hasLocalDnsServerRunning(t *testing.T) { + start := time.Now() + hasDns := hasLocalDnsServerRunning() + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + hasDnsPowershell := hasLocalDnsServerRunningPowershell() + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + if hasDns != hasDnsPowershell { + t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns) + } +} + +func hasLocalDnsServerRunningPowershell() bool { + _, err := powershell("Get-Process -Name DNS") + return err == nil +} From 9b6a308958d4d7420c066e3f69a0afecbfdac10c Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 21 Nov 2024 20:24:46 +0700 Subject: [PATCH 003/100] cmd/cli: get AD domain using Windows API --- cmd/cli/ad_windows.go | 18 +++++++++++++----- cmd/cli/ad_windows_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 cmd/cli/ad_windows_test.go diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index d7374d0..316414d 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -1,8 +1,11 @@ package cli import ( - "fmt" "strings" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" "github.com/Control-D-Inc/ctrld" ) @@ -40,10 +43,15 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool { // getActiveDirectoryDomain returns AD domain name of this computer. func getActiveDirectoryDomain() (string, error) { - cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }" - output, err := powershell(cmd) + var domain *uint16 + var status uint32 + err := syscall.NetGetJoinInformation(nil, &domain, &status) if err != nil { - return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output)) + return "", err } - return string(output), nil + defer syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) + if status == syscall.NetSetupDomainName { + return windows.UTF16PtrToString(domain), nil + } + return "", nil } diff --git a/cmd/cli/ad_windows_test.go b/cmd/cli/ad_windows_test.go new file mode 100644 index 0000000..392abbd --- /dev/null +++ b/cmd/cli/ad_windows_test.go @@ -0,0 +1,36 @@ +package cli + +import ( + "fmt" + "testing" + "time" +) + +func Test_getActiveDirectoryDomain(t *testing.T) { + start := time.Now() + domain, err := getActiveDirectoryDomain() + if err != nil { + t.Fatal(err) + } + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + domainPowershell, err := getActiveDirectoryDomainPowershell() + if err != nil { + t.Fatal(err) + } + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + if domain != domainPowershell { + t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain) + } +} + +func getActiveDirectoryDomainPowershell() (string, error) { + cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }" + output, err := powershell(cmd) + if err != nil { + return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output)) + } + return string(output), nil +} From 5e9b4244e7c3e11509599fbfd9bff7c08dd93739 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 25 Nov 2024 18:38:40 +0700 Subject: [PATCH 004/100] cmd/cli: get physical interfaces using Windows WMI --- cmd/cli/net_windows.go | 60 ++++++++++++++++++++++++++++++------- cmd/cli/net_windows_test.go | 42 ++++++++++++++++++++++++++ go.mod | 3 ++ go.sum | 6 ++++ 4 files changed, 101 insertions(+), 10 deletions(-) create mode 100644 cmd/cli/net_windows_test.go diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index dc13b08..7174a1f 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -1,10 +1,13 @@ package cli import ( - "bufio" - "bytes" "net" - "strings" + + "github.com/microsoft/wmi/pkg/base/host" + "github.com/microsoft/wmi/pkg/base/instance" + "github.com/microsoft/wmi/pkg/base/query" + "github.com/microsoft/wmi/pkg/constant" + "github.com/microsoft/wmi/pkg/hardware/network/netadapter" ) func patchNetIfaceName(iface *net.Interface) error { @@ -20,15 +23,52 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo // validInterfacesMap returns a set of all physical interfaces. func validInterfacesMap() map[string]struct{} { - out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name") - if err != nil { - return nil - } m := make(map[string]struct{}) - scanner := bufio.NewScanner(bytes.NewReader(out)) - for scanner.Scan() { - ifaceName := strings.TrimSpace(scanner.Text()) + for _, ifaceName := range validInterfaces() { m[ifaceName] = struct{}{} } return m } + +// validInterfaces returns a list of all physical interfaces. +func validInterfaces() []string { + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("MSFT_NetAdapter") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if err != nil { + mainLog.Load().Err(err).Msg("failed to get wmi network adapter") + return nil + } + var adapters []string + for _, i := range instances { + adapter, err := netadapter.NewNetworkAdapter(i) + if err != nil { + mainLog.Load().Err(err).Msg("failed to get network adapter") + continue + } + // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) + // + // "Indicates if a connector is present on the network adapter. This value is set to TRUE + // if this is a physical adapter or FALSE if this is not a physical adapter." + physical, err := adapter.GetPropertyConnectorPresent() + if err != nil { + mainLog.Load().Err(err).Msg("failed to get network adapter connector present property") + continue + } + if !physical { + continue + } + ifaceIdx, err := adapter.GetInterfaceIndex() + if err != nil { + mainLog.Load().Err(err).Msg("failed to get interface index") + continue + } + iff, err := net.InterfaceByIndex(int(ifaceIdx)) + if err != nil { + mainLog.Load().Err(err).Msg("failed to get interface") + continue + } + adapters = append(adapters, iff.Name) + } + return adapters +} diff --git a/cmd/cli/net_windows_test.go b/cmd/cli/net_windows_test.go new file mode 100644 index 0000000..a15f119 --- /dev/null +++ b/cmd/cli/net_windows_test.go @@ -0,0 +1,42 @@ +package cli + +import ( + "bufio" + "bytes" + "slices" + "strings" + "testing" + "time" +) + +func Test_validInterfaces(t *testing.T) { + verbose = 3 + initConsoleLogging() + start := time.Now() + ifaces := validInterfaces() + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + ifacesPowershell := validInterfacesPowershell() + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + slices.Sort(ifaces) + slices.Sort(ifacesPowershell) + if !slices.Equal(ifaces, ifacesPowershell) { + t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces) + } +} + +func validInterfacesPowershell() []string { + out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name") + if err != nil { + return nil + } + var res []string + scanner := bufio.NewScanner(bytes.NewReader(out)) + for scanner.Scan() { + ifaceName := strings.TrimSpace(scanner.Text()) + res = append(res, ifaceName) + } + return res +} diff --git a/go.mod b/go.mod index 84b58c4..58b67c5 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 github.com/kardianos/service v1.2.1 github.com/mdlayher/ndp v1.0.1 + github.com/microsoft/wmi v0.24.5 github.com/miekg/dns v1.1.58 github.com/minio/selfupdate v0.6.0 github.com/olekukonko/tablewriter v0.0.5 @@ -49,6 +50,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 // indirect + github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect @@ -72,6 +74,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect diff --git a/go.sum b/go.sum index ebb9042..cb1d9ee 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,8 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= @@ -227,6 +229,8 @@ github.com/mdlayher/packet v1.1.2 h1:3Up1NG6LZrsgDVn6X4L9Ge/iyRyxFEFD9o6Pr3Q1nQY github.com/mdlayher/packet v1.1.2/go.mod h1:GEu1+n9sG5VtiRE4SydOmX5GTwyyYlteZiFU+x0kew4= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI= +github.com/microsoft/wmi v0.24.5 h1:NT+WqhjKbEcg3ldmDsRMarWgHGkpeW+gMopSCfON0kM= +github.com/microsoft/wmi v0.24.5/go.mod h1:1zbdSF0A+5OwTUII5p3hN7/K6KF2m3o27pSG6Y51VU8= github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY= github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU= @@ -245,6 +249,7 @@ github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFu github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -478,6 +483,7 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= From 6837176ec7c49b76b66ab2e4dc7055310c672457 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 25 Nov 2024 20:06:42 +0700 Subject: [PATCH 005/100] cmd/cli: get static DNS using syscall --- cmd/cli/os_windows.go | 25 +++++++++----- cmd/cli/os_windows_test.go | 68 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 9 deletions(-) create mode 100644 cmd/cli/os_windows_test.go diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index aa44418..5ff9360 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -13,14 +13,15 @@ import ( "sync" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) const ( - v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` - v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` + v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` ) var ( @@ -177,25 +178,31 @@ func currentDNS(iface *net.Interface) []string { func currentStaticDNS(iface *net.Interface) ([]string, error) { luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) if err != nil { - return nil, err + return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err) } guid, err := luid.GUID() if err != nil { - return nil, err + return nil, fmt.Errorf("luid.GUID: %w", err) } var ns []string for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} { - interfaceKeyPath := path + guid.String() found := false + interfaceKeyPath := path + guid.String() + k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE) + if err != nil { + return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err) + } for _, key := range []string{"NameServer", "ProfileNameServer"} { if found { continue } - cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key) - out, err := powershell(cmd) - if err == nil && len(out) > 0 { + value, _, err := k.GetStringValue(key) + if err != nil && !errors.Is(err, registry.ErrNotExist) { + return nil, fmt.Errorf("%s: %w", key, err) + } + if len(value) > 0 { found = true - for _, e := range strings.Split(string(out), ",") { + for _, e := range strings.Split(value, ",") { ns = append(ns, strings.TrimRight(e, "\x00")) } } diff --git a/cmd/cli/os_windows_test.go b/cmd/cli/os_windows_test.go new file mode 100644 index 0000000..40be5ed --- /dev/null +++ b/cmd/cli/os_windows_test.go @@ -0,0 +1,68 @@ +package cli + +import ( + "fmt" + "net" + "slices" + "strings" + "testing" + "time" + + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func Test_currentStaticDNS(t *testing.T) { + iface, err := net.InterfaceByName(defaultIfaceName()) + if err != nil { + t.Fatal(err) + } + start := time.Now() + staticDns, err := currentStaticDNS(iface) + if err != nil { + t.Fatal(err) + } + t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds()) + + start = time.Now() + staticDnsPowershell, err := currentStaticDnsPowershell(iface) + if err != nil { + t.Fatal(err) + } + t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds()) + + slices.Sort(staticDns) + slices.Sort(staticDnsPowershell) + if !slices.Equal(staticDns, staticDnsPowershell) { + t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns) + } +} + +func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) { + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return nil, err + } + guid, err := luid.GUID() + if err != nil { + return nil, err + } + var ns []string + for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} { + interfaceKeyPath := path + guid.String() + found := false + for _, key := range []string{"NameServer", "ProfileNameServer"} { + if found { + continue + } + cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key) + out, err := powershell(cmd) + if err == nil && len(out) > 0 { + found = true + for _, e := range strings.Split(string(out), ",") { + ns = append(ns, strings.TrimRight(e, "\x00")) + } + } + } + } + return ns, nil +} From 8360bdc50ada11c9df1456d11bbf330af87c160d Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 27 Nov 2024 16:00:38 +0700 Subject: [PATCH 006/100] cmd/cli: add split route AD top level domain on Windows The sub-domains are matched using wildcard domain rule, but this rule won't match top level domain, causing requests are forwarded to ControlD upstreams. To fix this, add the split route for top level domain explicitly. --- cmd/cli/ad_windows.go | 17 ++++++++++++----- cmd/cli/ad_windows_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index 316414d..475ba09 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -24,19 +24,26 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool { // Network rules are lowercase during toml config marshaling, // lowercase the domain here too for consistency. domain = strings.ToLower(domain) + domainRuleAdded := addSplitDnsRule(cfg, domain) + wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, ".")) + return domainRuleAdded || wildcardDomainRuleRuleAdded +} + +// addSplitDnsRule adds split-rule for given domain if there's no existed rule. +// The return value indicates whether the split-rule was added or not. +func addSplitDnsRule(cfg *ctrld.Config, domain string) bool { for n, lc := range cfg.Listener { if lc.Policy == nil { lc.Policy = &ctrld.ListenerPolicyConfig{} } - domainRule := "*." + strings.TrimPrefix(domain, ".") for _, rule := range lc.Policy.Rules { - if _, ok := rule[domainRule]; ok { - mainLog.Load().Debug().Msgf("domain rule already exist for listener.%s", n) + if _, ok := rule[domain]; ok { + mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n) return false } } - mainLog.Load().Debug().Msgf("adding active directory domain for listener.%s", n) - lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}}) + mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n) + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) } return true } diff --git a/cmd/cli/ad_windows_test.go b/cmd/cli/ad_windows_test.go index 392abbd..6fe7f41 100644 --- a/cmd/cli/ad_windows_test.go +++ b/cmd/cli/ad_windows_test.go @@ -4,6 +4,10 @@ import ( "fmt" "testing" "time" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/testhelper" + "github.com/stretchr/testify/assert" ) func Test_getActiveDirectoryDomain(t *testing.T) { @@ -34,3 +38,35 @@ func getActiveDirectoryDomainPowershell() (string, error) { } return string(output), nil } + +func Test_addSplitDnsRule(t *testing.T) { + newCfg := func(domains ...string) *ctrld.Config { + cfg := testhelper.SampleConfig(t) + lc := cfg.Listener["0"] + for _, domain := range domains { + lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}}) + } + return cfg + } + tests := []struct { + name string + cfg *ctrld.Config + domain string + added bool + }{ + {"added", newCfg(), "example.com", true}, + {"TLD existed", newCfg("example.com"), "*.example.com", true}, + {"wildcard existed", newCfg("*.example.com"), "example.com", true}, + {"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false}, + {"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + added := addSplitDnsRule(tc.cfg, tc.domain) + assert.Equal(t, tc.added, added) + }) + } +} From 70ab8032a0006a8bab03bc3fa5e6b9968693faac Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 4 Dec 2024 17:36:04 +0700 Subject: [PATCH 007/100] cmd/cli: silent WMI query The log is being printed by the wmi library, which may cause confusion. --- cmd/cli/net_windows.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index 7174a1f..2077b85 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -1,7 +1,10 @@ package cli import ( + "io" + "log" "net" + "os" "github.com/microsoft/wmi/pkg/base/host" "github.com/microsoft/wmi/pkg/base/instance" @@ -32,6 +35,8 @@ func validInterfacesMap() map[string]struct{} { // validInterfaces returns a list of all physical interfaces. func validInterfaces() []string { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) From 17941882a94568fc921b5ef1e603e393fad62573 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 5 Dec 2024 15:00:06 +0700 Subject: [PATCH 008/100] cmd/cli: split-route SRV record to OS resolver Since SRV record is mostly useful in AD environment. Even in non-AD one, the OS resolver could still resolve the query for external services. Users who want special treatment can still specify domain rules to forward requests to ControlD upstreams explicitly. --- cmd/cli/dns_proxy.go | 12 ++++++++++++ cmd/cli/dns_proxy_test.go | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 6611975..f195f62 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -445,6 +445,10 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } } else { switch { + case isSrvLookup(req.msg): + upstreams = []string{upstreamOS} + upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams) case isPrivatePtrLookup(req.msg): isLanOrPtrQuery = true if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil { @@ -1059,6 +1063,14 @@ func isLanHostnameQuery(m *dns.Msg) bool { strings.HasSuffix(name, ".lan") } +// isSrvLookup reports whether DNS message is a SRV query. +func isSrvLookup(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + return m.Question[0].Qtype == dns.TypeSRV +} + // isWanClient reports whether the input is a WAN address. func isWanClient(na net.Addr) bool { var ip netip.Addr diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 877fb71..6e7a431 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -414,6 +414,26 @@ func Test_isPrivatePtrLookup(t *testing.T) { } } +func Test_isSrvLookup(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + isSrvLookup bool + }{ + {"SRV", newDnsMsgWithHostname("foo", dns.TypeSRV), true}, + {"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isSrvLookup(tc.msg); tc.isSrvLookup != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got) + } + }) + } +} + func Test_isWanClient(t *testing.T) { tests := []struct { name string From 09426dcd3626eb66fb4709952330d766a324bf9e Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 5 Dec 2024 17:06:06 +0700 Subject: [PATCH 009/100] cmd/cli: new flow for LAN hostname query If there is no explicit rules for LAN hostname queries, using OS resolver instead of forwarding requests to remote upstreams. --- cmd/cli/dns_proxy.go | 10 ++++++---- cmd/cli/dns_proxy_test.go | 3 +++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index f195f62..a69f5b5 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -456,7 +456,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { res.clientInfo = true return res } - upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs) ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams) case isLanHostnameQuery(req.msg): isLanOrPtrQuery = true @@ -465,7 +465,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { res.clientInfo = true return res } - upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + upstreams = []string{upstreamOS} + upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams) default: ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams) @@ -605,7 +606,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return res } -func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { +func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { if len(p.localUpstreams) > 0 { tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams)) tmp = append(tmp, p.localUpstreams...) @@ -1060,7 +1061,8 @@ func isLanHostnameQuery(m *dns.Msg) bool { name := strings.TrimSuffix(q.Name, ".") return !strings.Contains(name, ".") || strings.HasSuffix(name, ".domain") || - strings.HasSuffix(name, ".lan") + strings.HasSuffix(name, ".lan") || + strings.HasSuffix(name, ".local") } // isSrvLookup reports whether DNS message is a SRV query. diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 6e7a431..9deb9ed 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -365,6 +365,9 @@ func Test_isLanHostnameQuery(t *testing.T) { {"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false}, {"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false}, {"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false}, + {".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true}, + {".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true}, + {".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true}, } for _, tc := range tests { tc := tc From ed39269c8061a989550856038343f96726b625af Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 5 Dec 2024 22:37:06 +0700 Subject: [PATCH 010/100] Implementing new initializing OS resolver logic Since the nameservers that we got during startup are the good ones that work, saving it for later usage if we could not find available ones. --- resolver.go | 90 +++++++++++++++++++----------------------------- resolver_test.go | 59 ++++++++++++++++--------------- 2 files changed, 67 insertions(+), 82 deletions(-) diff --git a/resolver.go b/resolver.go index f54edfb..fd97e48 100644 --- a/resolver.go +++ b/resolver.go @@ -85,60 +85,41 @@ func availableNameservers() []string { func InitializeOsResolver() []string { return initializeOsResolver(availableNameservers()) } + +// initializeOsResolver performs logic for choosing OS resolver nameserver. +// The logic: +// +// - First available LAN servers are saved and store. +// - Later calls, if no LAN servers available, the saved servers above will be used. func initializeOsResolver(servers []string) []string { var ( - nss []string + lanNss []string publicNss []string ) - var ( - lastLanServer netip.Addr - curLanServer netip.Addr - curLanServerAvailable bool - ) - if p := or.currentLanServer.Load(); p != nil { - curLanServer = *p - or.currentLanServer.Store(nil) - } - if p := or.lastLanServer.Load(); p != nil { - lastLanServer = *p - or.lastLanServer.Store(nil) - } + for _, ns := range servers { addr, err := netip.ParseAddr(ns) if err != nil { continue } server := net.JoinHostPort(ns, "53") - // Always use new public nameserver. - if !isLanAddr(addr) { - publicNss = append(publicNss, server) - nss = append(nss, server) - continue - } - // For LAN server, storing only current and last LAN server if any. - if addr.Compare(curLanServer) == 0 { - curLanServerAvailable = true + if isLanAddr(addr) { + lanNss = append(lanNss, server) } else { - if addr.Compare(lastLanServer) == 0 { - or.lastLanServer.Store(&addr) - } else { - if or.currentLanServer.CompareAndSwap(nil, &addr) { - nss = append(nss, server) - } - } + publicNss = append(publicNss, server) } } - // Store current LAN server as last one only if it's still available. - if curLanServerAvailable && curLanServer.IsValid() { - or.lastLanServer.Store(&curLanServer) - nss = append(nss, net.JoinHostPort(curLanServer.String(), "53")) + if len(lanNss) > 0 { + // Saved first initialized LAN servers. + or.initializedLanServers.CompareAndSwap(nil, &lanNss) } - if len(publicNss) == 0 { - publicNss = append(publicNss, controldPublicDnsWithPort) - nss = append(nss, controldPublicDnsWithPort) + if len(lanNss) == 0 { + or.lanServers.Store(or.initializedLanServers.Load()) + } else { + or.lanServers.Store(&lanNss) } - or.publicServer.Store(&publicNss) - return nss + or.publicServers.Store(&publicNss) + return slices.Concat(lanNss, publicNss) } // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. @@ -185,9 +166,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { } type osResolver struct { - currentLanServer atomic.Pointer[netip.Addr] - lastLanServer atomic.Pointer[netip.Addr] - publicServer atomic.Pointer[[]string] + initializedLanServers atomic.Pointer[[]string] + lanServers atomic.Pointer[[]string] + publicServers atomic.Pointer[[]string] } type osResolverResult struct { @@ -201,13 +182,10 @@ type osResolverResult struct { // Query is sent to all nameservers concurrently, and the first // success response will be returned. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - publicServers := *o.publicServer.Load() - nss := make([]string, 0, 2) - if p := o.currentLanServer.Load(); p != nil { - nss = append(nss, net.JoinHostPort(p.String(), "53")) - } - if p := o.lastLanServer.Load(); p != nil { - nss = append(nss, net.JoinHostPort(p.String(), "53")) + publicServers := *o.publicServers.Load() + var nss []string + if p := o.lanServers.Load(); p != nil { + nss = append(nss, (*p)...) } numServers := len(nss) + len(publicServers) if numServers == 0 { @@ -467,17 +445,19 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { r := &osResolver{} - nss := slices.Sorted(slices.Values(nameservers)) - for i, ns := range nss { + var publicNss []string + var lanNss []string + for _, ns := range slices.Sorted(slices.Values(nameservers)) { ip, _, _ := net.SplitHostPort(ns) addr, _ := netip.ParseAddr(ip) if isLanAddr(addr) { - r.currentLanServer.Store(&addr) - nss = slices.Delete(nss, i, i+1) - break + lanNss = append(lanNss, ns) + } else { + publicNss = append(publicNss, ns) } } - r.publicServer.Store(&nss) + r.lanServers.Store(&lanNss) + r.publicServers.Store(&publicNss) return r } diff --git a/resolver_test.go b/resolver_test.go index 7b1a49d..0db05f6 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -20,7 +20,7 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() resolver := &osResolver{} - resolver.publicServer.Store(&[]string{"127.0.0.127:5353"}) + resolver.publicServers.Store(&[]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -74,7 +74,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { } }() resolver := &osResolver{} - resolver.publicServer.Store(&ns) + resolver.publicServers.Store(&ns) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg) @@ -156,38 +156,43 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H func Test_initializeOsResolver(t *testing.T) { lanServer1 := "192.168.1.1" lanServer2 := "10.0.10.69" + lanServer3 := "192.168.40.1" wanServer := "1.1.1.1" + lanServers := []string{net.JoinHostPort(lanServer1, "53"), net.JoinHostPort(lanServer2, "53")} publicServers := []string{net.JoinHostPort(wanServer, "53")} - // First initialization. + or = newResolverWithNameserver(defaultNameservers()) + + // First initialization, initialized servers are saved. + initializeOsResolver([]string{lanServer1, lanServer2, wanServer}) + p := or.initializedLanServers.Load() + assert.NotNil(t, p) + t.Logf("%v - %v", *p, lanServers) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) + + // No new LAN servers, but lanServer2 gone, initialized servers not changed. initializeOsResolver([]string{lanServer1, wanServer}) - p := or.currentLanServer.Load() + p = or.initializedLanServers.Load() assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer1, "53")})) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - // No new LAN server, current LAN server -> last LAN server. - initializeOsResolver([]string{lanServer1, wanServer}) - p = or.currentLanServer.Load() - assert.Nil(t, p) - p = or.lastLanServer.Load() + // New LAN servers, they are used, initialized servers not changed. + initializeOsResolver([]string{lanServer3, wanServer}) + p = or.initializedLanServers.Load() assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer3, "53")})) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - // New LAN server detected. - initializeOsResolver([]string{lanServer2, lanServer1, wanServer}) - p = or.currentLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer2, p.String()) - p = or.lastLanServer.Load() - assert.NotNil(t, p) - assert.Equal(t, lanServer1, p.String()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) - - // No LAN server available. + // No LAN server available, initialized servers will be used. initializeOsResolver([]string{wanServer}) - assert.Nil(t, or.currentLanServer.Load()) - assert.Nil(t, or.lastLanServer.Load()) - assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers)) + p = or.initializedLanServers.Load() + assert.NotNil(t, p) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) } From a9f76322bdf3c5ccc3e9548b12231a9e93fa6658 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 6 Dec 2024 13:53:26 +0700 Subject: [PATCH 011/100] Bump quic-go to v0.48.2 For fixing GO-2024-3302 (CVE-2024-53259) --- config_quic.go | 4 ++-- go.mod | 16 ++++++++-------- go.sum | 32 ++++++++++++++++---------------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/config_quic.go b/config_quic.go index a6dd8b7..a46780a 100644 --- a/config_quic.go +++ b/config_quic.go @@ -34,7 +34,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() { } func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { - rt := &http3.RoundTripper{} + rt := &http3.Transport{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { _, port, _ := net.SplitHostPort(addr) @@ -64,7 +64,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) return conn, err } - runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) { + runtime.SetFinalizer(rt, func(rt *http3.Transport) { rt.CloseIdleConnections() }) return rt diff --git a/go.mod b/go.mod index 58b67c5..1f797e8 100644 --- a/go.mod +++ b/go.mod @@ -28,16 +28,16 @@ require ( github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_model v0.5.0 github.com/prometheus/prom2json v1.3.3 - github.com/quic-go/quic-go v0.42.0 + github.com/quic-go/quic-go v0.48.2 github.com/rs/zerolog v1.28.0 github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/net v0.27.0 - golang.org/x/sync v0.7.0 - golang.org/x/sys v0.22.0 + golang.org/x/net v0.28.0 + golang.org/x/sync v0.8.0 + golang.org/x/sys v0.23.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.74.0 ) @@ -78,7 +78,7 @@ require ( github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/common v0.48.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect - github.com/quic-go/qpack v0.4.0 // indirect + github.com/quic-go/qpack v0.5.1 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/spf13/afero v1.9.5 // indirect @@ -90,10 +90,10 @@ require ( go.uber.org/mock v0.4.0 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect - golang.org/x/crypto v0.25.0 // indirect - golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/mod v0.19.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.17.0 // indirect golang.org/x/tools v0.23.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index cb1d9ee..5e073b9 100644 --- a/go.sum +++ b/go.sum @@ -266,10 +266,10 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/prometheus/prom2json v1.3.3 h1:IYfSMiZ7sSOfliBoo89PcufjWO4eAR0gznGcETyaUgo= github.com/prometheus/prom2json v1.3.3/go.mod h1:Pv4yIPktEkK7btWsrUTWDDDrnpUrAELaOCj+oFwlgmc= -github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= -github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM= -github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= +github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= +github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= +github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= +github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -343,8 +343,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -355,8 +355,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -414,8 +414,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -435,8 +435,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -486,8 +486,8 @@ golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -498,8 +498,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From f5ba8be182d106291b40589ca9fe81b27a80cae6 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 11 Dec 2024 14:10:55 +0700 Subject: [PATCH 012/100] Use ControlD Public DNS when non-available This logic was missed when new initializing OS resolver logic was implemented. While at it, also adding this test case to prevent regression. --- resolver.go | 3 +++ resolver_test.go | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/resolver.go b/resolver.go index fd97e48..f3b7a10 100644 --- a/resolver.go +++ b/resolver.go @@ -118,6 +118,9 @@ func initializeOsResolver(servers []string) []string { } else { or.lanServers.Store(&lanNss) } + if len(publicNss) == 0 { + publicNss = append(publicNss, controldPublicDnsWithPort) + } or.publicServers.Store(&publicNss) return slices.Concat(lanNss, publicNss) } diff --git a/resolver_test.go b/resolver_test.go index 0db05f6..e0b5508 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -167,7 +167,6 @@ func Test_initializeOsResolver(t *testing.T) { initializeOsResolver([]string{lanServer1, lanServer2, wanServer}) p := or.initializedLanServers.Load() assert.NotNil(t, p) - t.Logf("%v - %v", *p, lanServers) assert.True(t, slices.Equal(*p, lanServers)) assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) @@ -195,4 +194,12 @@ func Test_initializeOsResolver(t *testing.T) { assert.True(t, slices.Equal(*p, lanServers)) assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) + + // No Public server, ControlD Public DNS will be used. + initializeOsResolver([]string{}) + p = or.initializedLanServers.Load() + assert.NotNil(t, p) + assert.True(t, slices.Equal(*p, lanServers)) + assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) + assert.True(t, slices.Equal(*or.publicServers.Load(), []string{controldPublicDnsWithPort})) } From a63a30c76bee3360f4f53dbdf4aca3ef15766111 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 6 Dec 2024 20:47:52 +0700 Subject: [PATCH 013/100] all: add sending logs to ControlD API --- cmd/cli/cli.go | 2 +- cmd/cli/prog.go | 2 +- internal/controld/config.go | 128 +++++++++++++++++++++++++----------- 3 files changed, 92 insertions(+), 40 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index b0ae022..89c79de 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1269,7 +1269,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() // Performs self-uninstallation if the ControlD device does not exist. - var uer *controld.UtilityErrorResponse + var uer *controld.ErrorResponse if errors.As(err, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { _ = uninstallInvalidCdUID(p, cdLogger, false) } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 07c7677..4e7df33 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -1044,7 +1044,7 @@ func dnsChanged(iface *net.Interface, nameservers []string) bool { // selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then. func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) { - var uer *controld.UtilityErrorResponse + var uer *controld.ErrorResponse if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode { p.stopDnsWatchers() diff --git a/internal/controld/config.go b/internal/controld/config.go index 1bc2512..348dc54 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -25,8 +25,12 @@ import ( const ( apiDomainCom = "api.controld.com" apiDomainDev = "api.controld.dev" - resolverDataURLCom = "https://api.controld.com/utility" - resolverDataURLDev = "https://api.controld.dev/utility" + apiURLCom = "https://api.controld.com" + apiURLDev = "https://api.controld.dev" + resolverDataURLCom = apiURLCom + "/utility" + resolverDataURLDev = apiURLDev + "/utility" + logURLCom = apiURLCom + "/logs" + logURLDev = apiURLDev + "/logs" InvalidConfigCode = 40402 ) @@ -49,14 +53,14 @@ type utilityResponse struct { } `json:"body"` } -type UtilityErrorResponse struct { +type ErrorResponse struct { ErrorField struct { Message string `json:"message"` Code int `json:"code"` } `json:"error"` } -func (u UtilityErrorResponse) Error() string { +func (u ErrorResponse) Error() string { return u.ErrorField.Message } @@ -71,6 +75,12 @@ type UtilityOrgRequest struct { Hostname string `json:"hostname"` } +// LogsRequest contains request data for sending runtime logs to API. +type LogsRequest struct { + UID string `json:"uid"` + LogFile string `json:"log_file"` +} + // FetchResolverConfig fetch Control D config for given uid. func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) { uid, clientID := ParseRawUID(rawUID) @@ -123,6 +133,81 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade } req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") + transport := apiTransport(cdDev) + client := http.Client{ + Timeout: 10 * time.Second, + Transport: transport, + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err) + } + defer resp.Body.Close() + d := json.NewDecoder(resp.Body) + if resp.StatusCode != http.StatusOK { + errResp := &ErrorResponse{} + if err := d.Decode(errResp); err != nil { + return nil, err + } + return nil, errResp + } + + ur := &utilityResponse{} + if err := d.Decode(ur); err != nil { + return nil, err + } + return &ur.Body.Resolver, nil +} + +// SendLogs sends runtime log to ControlD API. +func SendLogs(req *LogsRequest, cdDev bool) error { + body, _ := json.Marshal(req) + return postLogAPI(cdDev, bytes.NewReader(body)) +} + +func postLogAPI(cdDev bool, body io.Reader) error { + apiUrl := logURLCom + if cdDev { + apiUrl = logURLDev + } + req, err := http.NewRequest("POST", apiUrl, body) + if err != nil { + return fmt.Errorf("http.NewRequest: %w", err) + } + transport := apiTransport(cdDev) + client := http.Client{ + Timeout: 10 * time.Second, + Transport: transport, + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("postLogAPI client.Do: %w", err) + } + defer resp.Body.Close() + d := json.NewDecoder(resp.Body) + if resp.StatusCode != http.StatusOK { + errResp := &ErrorResponse{} + if err := d.Decode(errResp); err != nil { + return err + } + return errResp + } + _, _ = io.Copy(io.Discard, resp.Body) + return nil +} + +// ParseRawUID parse the input raw UID, returning real UID and ClientID. +// The raw UID can have 2 forms: +// +// - +// - / +func ParseRawUID(rawUID string) (string, string) { + uid, clientID, _ := strings.Cut(rawUID, "/") + return uid, clientID +} + +// apiTransport returns an HTTP transport for connecting to ControlD API endpoint. +func apiTransport(cdDev bool) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { apiDomain := apiDomainCom @@ -143,41 +228,8 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade d := &ctrldnet.ParallelDialer{} return d.DialContext(ctx, network, addrs) } - if router.Name() == ddwrt.Name || runtime.GOOS == "android" { transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()} } - client := http.Client{ - Timeout: 10 * time.Second, - Transport: transport, - } - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("client.Do: %w", err) - } - defer resp.Body.Close() - d := json.NewDecoder(resp.Body) - if resp.StatusCode != http.StatusOK { - errResp := &UtilityErrorResponse{} - if err := d.Decode(errResp); err != nil { - return nil, err - } - return nil, errResp - } - - ur := &utilityResponse{} - if err := d.Decode(ur); err != nil { - return nil, err - } - return &ur.Body.Resolver, nil -} - -// ParseRawUID parse the input raw UID, returning real UID and ClientID. -// The raw UID can have 2 forms: -// -// - -// - / -func ParseRawUID(rawUID string) (string, string) { - uid, clientID, _ := strings.Cut(rawUID, "/") - return uid, clientID + return transport } From cd5619a05bfa9548e5eedf8ce53fff70bbab8760 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 9 Dec 2024 23:17:00 +0700 Subject: [PATCH 014/100] cmd/cli: add internal logging So in case of no logging enabled, useful data could be sent to ControlD server for further troubleshooting. --- cmd/cli/cli.go | 5 ++ cmd/cli/commands.go | 88 +++++++++++++++++++++++++++ cmd/cli/control_server.go | 46 ++++++++++++++ cmd/cli/log_writer.go | 120 +++++++++++++++++++++++++++++++++++++ cmd/cli/log_writer_test.go | 49 +++++++++++++++ cmd/cli/prog.go | 32 +++++----- 6 files changed, 326 insertions(+), 14 deletions(-) create mode 100644 cmd/cli/commands.go create mode 100644 cmd/cli/log_writer.go create mode 100644 cmd/cli/log_writer_test.go diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 89c79de..7a54367 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1100,6 +1100,8 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, }, } rootCmd.AddCommand(upgradeCmd) + + initLogCmd() } // isMobile reports whether the current OS is a mobile platform. @@ -1231,6 +1233,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // so it's able to log information in processCDFlags. initLogging() + // Initializing internal logging after global logging. + p.initInternalLogging() + mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) mainLog.Load().Info().Msgf("os: %s", osVersion()) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go new file mode 100644 index 0000000..a98fbb4 --- /dev/null +++ b/cmd/cli/commands.go @@ -0,0 +1,88 @@ +package cli + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "path/filepath" + + "github.com/spf13/cobra" +) + +func initLogCmd() { + logSendCmd := &cobra.Command{ + Use: "send", + Short: "Send runtime debug logs to ControlD", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + dir, err := socketDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(sendLogsPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to send logs") + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + mainLog.Load().Notice().Msg("runtime logs sent successfully") + case http.StatusServiceUnavailable: + mainLog.Load().Warn().Msg("runtime logs could only be sent once per minute") + default: + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to read response body") + } + mainLog.Load().Error().Msg("failed to send logs") + mainLog.Load().Error().Msg(string(buf)) + } + }, + } + logViewCmd := &cobra.Command{ + Use: "view", + Short: "View current runtime debug logs", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + dir, err := socketDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(viewLogsPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to get logs") + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusMovedPermanently: + mainLog.Load().Warn().Msg("runtime debugs log is not enabled") + mainLog.Load().Warn().Msg(`ctrld may be run without "--cd" flag or logging is already enabled`) + return + case http.StatusBadRequest: + mainLog.Load().Warn().Msg("runtime debugs log is not available") + return + case http.StatusOK: + } + var logs logViewResponse + if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to decode view logs result") + } + fmt.Println(logs.Data) + }, + } + logCmd := &cobra.Command{ + Use: "log", + Short: "Manage runtime debug logs", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + logSendCmd.Use, + }, + } + logCmd.AddCommand(logSendCmd) + logCmd.AddCommand(logViewCmd) + rootCmd.AddCommand(logCmd) +} diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index c31fd13..b6deed5 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -2,6 +2,7 @@ package cli import ( "context" + "encoding/base64" "encoding/json" "net" "net/http" @@ -25,6 +26,8 @@ const ( deactivationPath = "/deactivation" cdPath = "/cd" ifacePath = "/iface" + viewLogsPath = "/logs/view" + sendLogsPath = "/logs/send" ) type controlServer struct { @@ -211,6 +214,49 @@ func (p *prog) registerControlServerHandler() { } w.WriteHeader(http.StatusBadRequest) })) + p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + data, err := p.logContent() + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + if len(data) == 0 { + w.WriteHeader(http.StatusMovedPermanently) + return + } + if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + })) + p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + if time.Since(p.internalLogSent) < logSentInterval { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + data, err := p.logContent() + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + if len(data) == 0 { + w.WriteHeader(http.StatusMovedPermanently) + return + } + logFile := base64.StdEncoding.EncodeToString(data) + req := &controld.LogsRequest{ + UID: cdUID, + LogFile: logFile, + } + mainLog.Load().Debug().Msg("sending log file to ControlD server") + if err := controld.SendLogs(req, cdDev); err != nil { + mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } else { + mainLog.Load().Debug().Msg("sending log file successfully") + } + p.internalLogSent = time.Now() + })) } func jsonResponse(next http.Handler) http.Handler { diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go new file mode 100644 index 0000000..f84b231 --- /dev/null +++ b/cmd/cli/log_writer.go @@ -0,0 +1,120 @@ +package cli + +import ( + "bytes" + "errors" + "os" + "sync" + "time" + + "github.com/rs/zerolog" + + "github.com/Control-D-Inc/ctrld" +) + +const ( + logWriterSize = 1024 * 1024 * 5 // 5 MB + logWriterInitialSize = 32 // 32 B + logSentInterval = time.Minute + logTruncatedMarker = "...\n" +) + +type logViewResponse struct { + Data string `json:"data"` +} + +// logWriter is an internal buffer to keep track of runtime log when no logging is enabled. +type logWriter struct { + mu sync.Mutex + buf bytes.Buffer + size int +} + +// newLogWriter creates an internal log writer with a fixed buffer size. +func newLogWriter() *logWriter { + lw := &logWriter{size: logWriterSize} + return lw +} + +func (lw *logWriter) Write(p []byte) (int, error) { + lw.mu.Lock() + defer lw.mu.Unlock() + + // If writing p causes overflows, discard old data. + if lw.buf.Len()+len(p) > lw.size { + buf := lw.buf.Bytes() + buf = buf[:logWriterInitialSize] + if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 { + buf = buf[:idx] + } + lw.buf.Reset() + lw.buf.Write(buf) + lw.buf.WriteString(logTruncatedMarker) // indicate that the log was truncated. + } + // If p is bigger than buffer size, truncate p by half until its size is smaller. + for len(p)+lw.buf.Len() > lw.size { + p = p[len(p)/2:] + } + return lw.buf.Write(p) +} + +// initInternalLogging performs internal logging if there's no log enabled. +func (p *prog) initInternalLogging() { + if !p.needInternalLogging() { + return + } + p.initInternalLogWriterOnce.Do(func() { + mainLog.Load().Notice().Msg("internal logging enabled") + lw := newLogWriter() + p.internalLogWriter = lw + p.internalLogSent = time.Now().Add(-logSentInterval) + }) + p.mu.Lock() + lw := p.internalLogWriter + p.mu.Unlock() + multi := zerolog.MultiLevelWriter(lw) + l := mainLog.Load().Output(multi).With().Logger() + mainLog.Store(&l) + ctrld.ProxyLogger.Store(&l) + if verbose == 0 { + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } +} + +// needInternalLogging reports whether prog needs to run internal logging. +func (p *prog) needInternalLogging() bool { + // Do not run in non-cd mode. + if cdUID == "" { + return false + } + // Do not run if there's already log file. + if p.cfg.Service.LogPath != "" { + return false + } + return true +} + +func (p *prog) logContent() ([]byte, error) { + var data []byte + if p.needInternalLogging() { + p.mu.Lock() + lw := p.internalLogWriter + p.mu.Unlock() + if lw == nil { + return nil, errors.New("nil internal log writer") + } + lw.mu.Lock() + data = lw.buf.Bytes() + lw.mu.Unlock() + } else { + if p.cfg.Service.LogPath == "" { + return nil, nil + } + buf, err := os.ReadFile(p.cfg.Service.LogPath) + if err != nil { + return nil, err + } + data = buf + } + return data, nil +} diff --git a/cmd/cli/log_writer_test.go b/cmd/cli/log_writer_test.go new file mode 100644 index 0000000..6882ea0 --- /dev/null +++ b/cmd/cli/log_writer_test.go @@ -0,0 +1,49 @@ +package cli + +import ( + "strings" + "sync" + "testing" +) + +func Test_logWriter_Write(t *testing.T) { + size := 64 + lw := &logWriter{size: size} + lw.buf.Grow(lw.size) + data := strings.Repeat("A", size) + lw.Write([]byte(data)) + if lw.buf.String() != data { + t.Fatalf("unexpected buf content: %v", lw.buf.String()) + } + newData := "B" + halfData := strings.Repeat("A", len(data)/2) + logTruncatedMarker + lw.Write([]byte(newData)) + if lw.buf.String() != halfData+newData { + t.Fatalf("unexpected new buf content: %v", lw.buf.String()) + } + + bigData := strings.Repeat("B", 256) + expected := halfData + strings.Repeat("B", 16) + lw.Write([]byte(bigData)) + if lw.buf.String() != expected { + t.Fatalf("unexpected big buf content: %v", lw.buf.String()) + } +} + +func Test_logWriter_ConcurrentWrite(t *testing.T) { + size := 64 + lw := &logWriter{size: size} + n := 10 + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + lw.Write([]byte(strings.Repeat("A", i))) + }() + } + wg.Wait() + if lw.buf.Len() > lw.size { + t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String()) + } +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 4e7df33..c6146ea 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -85,20 +85,23 @@ type prog struct { dnsWatcherClosedOnce sync.Once dnsWatcherStopCh chan struct{} - cfg *ctrld.Config - localUpstreams []string - ptrNameservers []string - appCallback *AppCallback - cache dnscache.Cacher - cacheFlushDomainsMap map[string]struct{} - sema semaphore - ciTable *clientinfo.Table - um *upstreamMonitor - router router.Router - ptrLoopGuard *loopGuard - lanLoopGuard *loopGuard - metricsQueryStats atomic.Bool - queryFromSelfMap sync.Map + cfg *ctrld.Config + localUpstreams []string + ptrNameservers []string + appCallback *AppCallback + cache dnscache.Cacher + cacheFlushDomainsMap map[string]struct{} + sema semaphore + ciTable *clientinfo.Table + um *upstreamMonitor + router router.Router + ptrLoopGuard *loopGuard + lanLoopGuard *loopGuard + metricsQueryStats atomic.Bool + queryFromSelfMap sync.Map + initInternalLogWriterOnce sync.Once + internalLogWriter *logWriter + internalLogSent time.Time selfUninstallMu sync.Mutex refusedQueryCount int @@ -517,6 +520,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } go p.apiConfigReload() p.postRun() + p.initInternalLogging() } wg.Wait() } From f71dd789158af7088233ac24c9297da376000d25 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 10 Dec 2024 17:34:31 +0700 Subject: [PATCH 015/100] cmd/cli: move cobra commands to separated file So each command initialization/logic can be read/update more easily. --- cmd/cli/cli.go | 987 +---------------------------------------- cmd/cli/commands.go | 1033 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 1043 insertions(+), 977 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 7a54367..21d2873 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -30,18 +30,14 @@ import ( "github.com/go-playground/validator/v10" "github.com/kardianos/service" "github.com/miekg/dns" - "github.com/minio/selfupdate" - "github.com/olekukonko/tablewriter" "github.com/pelletier/go-toml/v2" "github.com/rs/zerolog" "github.com/spf13/cobra" - "github.com/spf13/pflag" "github.com/spf13/viper" "tailscale.com/logtail/backoff" "tailscale.com/net/netmon" "github.com/Control-D-Inc/ctrld" - "github.com/Control-D-Inc/ctrld/internal/clientinfo" "github.com/Control-D-Inc/ctrld/internal/controld" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" "github.com/Control-D-Inc/ctrld/internal/router" @@ -129,978 +125,17 @@ func initCLI() { rootCmd.SetHelpCommand(&cobra.Command{Hidden: true}) rootCmd.CompletionOptions.HiddenDefaultCmd = true - runCmd := &cobra.Command{ - Use: "run", - Short: "Run the DNS proxy server", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - RunCobraCommand(cmd) - }, - } - runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon") - runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") - runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") - runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") - runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") - runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") - runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") - runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") - runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") - runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") - runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") - runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") - runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") - _ = runCmd.Flags().MarkHidden("dev") - runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") - _ = runCmd.Flags().MarkHidden("homedir") - runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) - _ = runCmd.Flags().MarkHidden("iface") - runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) - - runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true} - rootCmd.AddCommand(runCmd) - - startCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "start", - Short: "Install and start the ctrld service", - Long: `Install and start the ctrld service - -NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - checkStrFlagEmpty(cmd, cdUidFlagName) - checkStrFlagEmpty(cmd, cdOrgFlagName) - validateCdAndNextDNSFlags() - sc := &service.Config{} - *sc = *svcConfig - osArgs := os.Args[2:] - if os.Args[1] == "service" { - osArgs = os.Args[3:] - } - setDependencies(sc) - sc.Arguments = append([]string{"run"}, osArgs...) - - p := &prog{ - router: router.New(&cfg, cdUID != ""), - cfg: &cfg, - } - s, err := newService(p, sc) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - - status, err := s.Status() - isCtrldRunning := status == service.StatusRunning - isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) - - // Get current running iface, if any. - var currentIface string - - // If pin code was set, do not allow running start command. - if isCtrldRunning { - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - currentIface = runningIface(s) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - reportSetDnsOk := func(sockDir string) { - if cc := newSocketControlClient(ctx, s, sockDir); cc != nil { - if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { - if iface == "auto" { - iface = defaultIfaceName() - } - logger := mainLog.Load().With().Str("iface", iface).Logger() - logger.Debug().Msg("setting DNS successfully") - } - } - } - - // No config path, generating config in HOME directory. - noConfigStart := isNoConfigStart(cmd) - writeDefaultConfig := !noConfigStart && configBase64 == "" - - logServerStarted := make(chan struct{}) - // A buffer channel to gather log output from runCmd and report - // to user in case self-check process failed. - runCmdLogCh := make(chan string, 256) - ud, err := userHomeDir() - sockDir := ud - if err != nil { - mainLog.Load().Warn().Msg("log server did not start") - close(logServerStarted) - } else { - setWorkingDirectory(sc, ud) - if configPath == "" && writeDefaultConfig { - defaultConfigFile = filepath.Join(ud, defaultConfigFile) - } - sc.Arguments = append(sc.Arguments, "--homedir="+ud) - if d, err := socketDir(); err == nil { - sockDir = d - } - sockPath := filepath.Join(sockDir, ctrldLogUnixSock) - _ = os.Remove(sockPath) - go func() { - defer func() { - close(runCmdLogCh) - _ = os.Remove(sockPath) - }() - close(logServerStarted) - if conn := runLogServer(sockPath); conn != nil { - // Enough buffer for log message, we don't produce - // such long log message, but just in case. - buf := make([]byte, 1024) - for { - n, err := conn.Read(buf) - if err != nil { - return - } - msg := string(buf[:n]) - if _, _, found := strings.Cut(msg, msgExit); found { - cancel() - } - runCmdLogCh <- msg - } - } - }() - } - <-logServerStarted - - if !startOnly { - startOnly = len(osArgs) == 0 - } - // If user run "ctrld start" and ctrld is already installed, starting existing service. - if startOnly && isCtrldInstalled { - tryReadingConfigWithNotice(false, true) - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) - } - - initLogging() - tasks := []task{ - {s.Stop, false}, - resetDnsTask(p, s, isCtrldInstalled, currentIface), - {func() error { - // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { - if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { - return err - } - return nil - }) - return nil - }, false}, - {s.Start, true}, - {noticeWritingControlDConfig, false}, - } - mainLog.Load().Notice().Msg("Starting existing ctrld service") - if doTasks(tasks) { - mainLog.Load().Notice().Msg("Service started") - sockDir, err := socketDir() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Failed to get socket directory") - os.Exit(1) - } - reportSetDnsOk(sockDir) - } else { - mainLog.Load().Error().Err(err).Msg("Failed to start existing ctrld service") - os.Exit(1) - } - return - } - - if cdUID != "" { - doValidateCdRemoteConfig(cdUID) - } else if uid := cdUIDFromProvToken(); uid != "" { - cdUID = uid - mainLog.Load().Debug().Msg("using uid from provision token") - removeOrgFlagsFromArgs(sc) - // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. - sc.Arguments = append(sc.Arguments, "--cd="+cdUID) - } - if cdUID != "" { - validateCdUpstreamProtocol() - } - - if err := p.router.ConfigureService(sc); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to configure service on router") - } - - if configPath != "" { - v.SetConfigFile(configPath) - } - - tryReadingConfigWithNotice(writeDefaultConfig, true) - - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) - } - - initLogging() - - if nextdns != "" { - removeNextDNSFromArgs(sc) - } - - // Explicitly passing config, so on system where home directory could not be obtained, - // or sub-process env is different with the parent, we still behave correctly and use - // the expected config file. - if configPath == "" { - sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile) - } - - if router.Name() != "" && iface != "" { - mainLog.Load().Debug().Msg("cleaning up router before installing") - _ = p.router.Cleanup() - } - - tasks := []task{ - {s.Stop, false}, - {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, - {func() error { return ensureUninstall(s) }, false}, - resetDnsTask(p, s, isCtrldInstalled, currentIface), - {func() error { - // Save current DNS so we can restore later. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { - if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { - return err - } - return nil - }) - return nil - }, false}, - {s.Install, false}, - {s.Start, true}, - // Note that startCmd do not actually write ControlD config, but the config file was - // generated after s.Start, so we notice users here for consistent with nextdns mode. - {noticeWritingControlDConfig, false}, - } - mainLog.Load().Notice().Msg("Starting service") - if doTasks(tasks) { - if err := p.router.Install(sc); err != nil { - mainLog.Load().Warn().Err(err).Msg("post installation failed, please check system/service log for details error") - return - } - - ok, status, err := selfCheckStatus(ctx, s, sockDir) - switch { - case ok && status == service.StatusRunning: - mainLog.Load().Notice().Msg("Service started") - default: - marker := bytes.Repeat([]byte("="), 32) - // If ctrld service is not running, emitting log obtained from ctrld process. - if status != service.StatusRunning || ctx.Err() != nil { - mainLog.Load().Error().Msg("ctrld service may not have started due to an error or misconfiguration, service log:") - _, _ = mainLog.Load().Write(marker) - haveLog := false - for msg := range runCmdLogCh { - _, _ = mainLog.Load().Write([]byte(strings.ReplaceAll(msg, msgExit, ""))) - haveLog = true - } - // If we're unable to get log from "ctrld run", notice users about it. - if !haveLog { - mainLog.Load().Write([]byte(`"`)) - } - } - // Report any error if occurred. - if err != nil { - _, _ = mainLog.Load().Write(marker) - msg := fmt.Sprintf("An error occurred while performing test query: %s", err) - mainLog.Load().Write([]byte(msg)) - } - // If ctrld service is running but selfCheckStatus failed, it could be related - // to user's system firewall configuration, notice users about it. - if status == service.StatusRunning && err == nil { - _, _ = mainLog.Load().Write(marker) - mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`)) - mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) - } - - _, _ = mainLog.Load().Write(marker) - uninstall(p, s) - os.Exit(1) - } - reportSetDnsOk(sockDir) - } - }, - } - // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". - startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") - startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") - startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") - startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") - startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") - startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") - startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") - startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") - startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") - startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") - startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") - startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") - _ = startCmd.Flags().MarkHidden("dev") - startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) - startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") - startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) - startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`) - startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service") - _ = startCmd.Flags().MarkHidden("start_only") - - routerCmd := &cobra.Command{ - Use: "setup", - Run: func(cmd *cobra.Command, _ []string) { - exe, err := os.Executable() - if err != nil { - mainLog.Load().Fatal().Msgf("could not find executable path: %v", err) - os.Exit(1) - } - flags := make([]string, 0) - cmd.Flags().Visit(func(flag *pflag.Flag) { - flags = append(flags, fmt.Sprintf("--%s=%s", flag.Name, flag.Value)) - }) - cmdArgs := []string{"start"} - cmdArgs = append(cmdArgs, flags...) - command := exec.Command(exe, cmdArgs...) - command.Stdout = os.Stdout - command.Stderr = os.Stderr - command.Stdin = os.Stdin - if err := command.Run(); err != nil { - mainLog.Load().Fatal().Msg(err.Error()) - } - }, - } - routerCmd.Flags().AddFlagSet(startCmd.Flags()) - routerCmd.Hidden = true - rootCmd.AddCommand(routerCmd) - - stopCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "stop", - Short: "Stop the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - initLogging() - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - if doTasks([]task{{s.Stop, true}}) { - p.router.Cleanup() - p.resetDNS() - if router.WaitProcessExited() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - for { - select { - case <-ctx.Done(): - mainLog.Load().Error().Msg("timeout while waiting for service to stop") - return - default: - } - time.Sleep(time.Second) - if status, _ := s.Status(); status == service.StatusStopped { - break - } - } - } - mainLog.Load().Notice().Msg("Service stopped") - } - }, - } - stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`) - stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`) - _ = stopCmd.Flags().MarkHidden("pin") - - restartCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "restart", - Short: "Restart the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - cdUID = curCdUID() - cdMode := cdUID != "" - - p := &prog{router: router.New(&cfg, cdMode)} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("service not installed") - return - } - initLogging() - - if cdMode { - doValidateCdRemoteConfig(cdUID) - } - - iface = runningIface(s) - tasks := []task{ - {s.Stop, false}, - {s.Start, true}, - } - if doTasks(tasks) { - dir, err := socketDir() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") - return - } - cc := newSocketControlClient(context.TODO(), s, dir) - if cc == nil { - mainLog.Load().Notice().Msg("Service was not restarted") - os.Exit(1) - } - _, _ = cc.post(ifacePath, nil) - mainLog.Load().Notice().Msg("Service restarted") - } - }, - } - - reloadCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "reload", - Short: "Reload the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - dir, err := socketDir() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") - } - cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - resp, err := cc.post(reloadPath, nil) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld") - } - defer resp.Body.Close() - switch resp.StatusCode { - case http.StatusOK: - mainLog.Load().Notice().Msg("Service reloaded") - case http.StatusCreated: - s, err := newService(&prog{}, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.") - mainLog.Load().Warn().Msg("Restarting service") - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - mainLog.Load().Warn().Msg("Service not installed") - return - } - restartCmd.Run(cmd, args) - default: - buf, err := io.ReadAll(resp.Body) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("could not read response from control server") - } - mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf)) - } - }, - } - statusCmd := &cobra.Command{ - Use: "status", - Short: "Show status of the ctrld service", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - s, err := newService(&prog{}, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - status, err := s.Status() - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - os.Exit(1) - } - switch status { - case service.StatusUnknown: - mainLog.Load().Notice().Msg("Unknown status") - os.Exit(2) - case service.StatusRunning: - mainLog.Load().Notice().Msg("Service is running") - os.Exit(0) - case service.StatusStopped: - mainLog.Load().Notice().Msg("Service is stopped") - os.Exit(1) - } - }, - } - if runtime.GOOS == "darwin" { - // On darwin, running status command without privileges may return wrong information. - statusCmd.PreRun = func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - } - } - - uninstallCmd := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "uninstall", - Short: "Stop and uninstall the ctrld service", - Long: `Stop and uninstall the ctrld service. - -NOTE: Uninstalling will set DNS to values provided by DHCP.`, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - if iface == "" { - iface = "auto" - } - if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { - os.Exit(deactivationPinInvalidExitCode) - } - uninstall(p, s) - if cleanup { - var files []string - // Config file. - files = append(files, v.ConfigFileUsed()) - // Log file and backup log file. - // For safety, only process if log file path is absolute. - if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) { - files = append(files, logFile) - oldLogFile := logFile + oldLogSuffix - if _, err := os.Stat(oldLogFile); err == nil { - files = append(files, oldLogFile) - } - } - // Socket files. - if dir, _ := socketDir(); dir != "" { - files = append(files, filepath.Join(dir, ctrldControlUnixSock)) - files = append(files, filepath.Join(dir, ctrldLogUnixSock)) - } - // Static DNS settings files. - withEachPhysicalInterfaces("", "", func(i *net.Interface) error { - file := savedStaticDnsSettingsFilePath(i) - if _, err := os.Stat(file); err == nil { - files = append(files, file) - } - return nil - }) - // Windows forwarders file. - if hasLocalDnsServerRunning() { - files = append(files, absHomeDir(windowsForwardersFilename)) - } - // Binary itself. - bin, _ := os.Executable() - if bin != "" && supportedSelfDelete { - files = append(files, bin) - } - // Backup file after upgrading. - oldBin := bin + oldBinSuffix - if _, err := os.Stat(oldBin); err == nil { - files = append(files, oldBin) - } - for _, file := range files { - if file == "" { - continue - } - if err := os.Remove(file); err != nil { - if os.IsNotExist(err) { - continue - } - mainLog.Load().Warn().Err(err).Msg("failed to remove file") - } else { - mainLog.Load().Debug().Msgf("file removed: %s", file) - } - } - if err := selfDeleteExe(); err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to remove file") - } else { - if !supportedSelfDelete { - mainLog.Load().Debug().Msgf("file removed: %s", bin) - } - } - } - }, - } - uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`) - uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for uninstalling ctrld`) - _ = uninstallCmd.Flags().MarkHidden("pin") - uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`) - - listIfacesCmd := &cobra.Command{ - Use: "list", - Short: "List network interfaces of the host", - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - err := netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { - fmt.Printf("Index : %d\n", i.Index) - fmt.Printf("Name : %s\n", i.Name) - addrs, _ := i.Addrs() - for i, ipaddr := range addrs { - if i == 0 { - fmt.Printf("Addrs : %v\n", ipaddr) - continue - } - fmt.Printf(" %v\n", ipaddr) - } - for i, dns := range currentDNS(i.Interface) { - if i == 0 { - fmt.Printf("DNS : %s\n", dns) - continue - } - fmt.Printf(" : %s\n", dns) - } - println() - }) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - } - }, - } - interfacesCmd := &cobra.Command{ - Use: "interfaces", - Short: "Manage network interfaces", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - listIfacesCmd.Use, - }, - } - interfacesCmd.AddCommand(listIfacesCmd) - - serviceCmd := &cobra.Command{ - Use: "service", - Short: "Manage ctrld service", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - startCmd.Use, - stopCmd.Use, - restartCmd.Use, - reloadCmd.Use, - statusCmd.Use, - uninstallCmd.Use, - interfacesCmd.Use, - }, - } - serviceCmd.AddCommand(startCmd) - serviceCmd.AddCommand(stopCmd) - serviceCmd.AddCommand(restartCmd) - serviceCmd.AddCommand(reloadCmd) - serviceCmd.AddCommand(statusCmd) - serviceCmd.AddCommand(uninstallCmd) - serviceCmd.AddCommand(interfacesCmd) - rootCmd.AddCommand(serviceCmd) - startCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "start", - Short: "Quick start service and configure DNS on interface", - Long: `Quick start service and configure DNS on interface - -NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, - Run: func(cmd *cobra.Command, args []string) { - if len(os.Args) == 2 { - startOnly = true - } - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - startCmd.Run(cmd, args) - }, - } - startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`) - startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) - rootCmd.AddCommand(startCmdAlias) - stopCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "stop", - Short: "Quick stop service and remove DNS from interface", - Run: func(cmd *cobra.Command, args []string) { - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - stopCmd.Run(cmd, args) - }, - } - stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) - stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) - rootCmd.AddCommand(stopCmdAlias) - - restartCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "restart", - Short: "Restart the ctrld service", - Run: func(cmd *cobra.Command, args []string) { - restartCmd.Run(cmd, args) - }, - } - rootCmd.AddCommand(restartCmdAlias) - - reloadCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "reload", - Short: "Reload the ctrld service", - Run: func(cmd *cobra.Command, args []string) { - reloadCmd.Run(cmd, args) - }, - } - rootCmd.AddCommand(reloadCmdAlias) - - statusCmdAlias := &cobra.Command{ - Use: "status", - Short: "Show status of the ctrld service", - Args: cobra.NoArgs, - Run: statusCmd.Run, - } - rootCmd.AddCommand(statusCmdAlias) - - uninstallCmdAlias := &cobra.Command{ - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Use: "uninstall", - Short: "Stop and uninstall the ctrld service", - Long: `Stop and uninstall the ctrld service. - -NOTE: Uninstalling will set DNS to values provided by DHCP.`, - Run: func(cmd *cobra.Command, args []string) { - if !cmd.Flags().Changed("iface") { - os.Args = append(os.Args, "--iface="+ifaceStartStop) - } - iface = ifaceStartStop - uninstallCmd.Run(cmd, args) - }, - } - uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) - uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags()) - rootCmd.AddCommand(uninstallCmdAlias) - - listClientsCmd := &cobra.Command{ - Use: "list", - Short: "List clients that ctrld discovered", - Args: cobra.NoArgs, - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Run: func(cmd *cobra.Command, args []string) { - dir, err := socketDir() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") - } - cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - resp, err := cc.post(listClientsPath, nil) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to get clients list") - } - defer resp.Body.Close() - - var clients []*clientinfo.Client - if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to decode clients list result") - } - map2Slice := func(m map[string]struct{}) []string { - s := make([]string, 0, len(m)) - for k := range m { - if k == "" { // skip empty source from output. - continue - } - s = append(s, k) - } - sort.Strings(s) - return s - } - // If metrics is enabled, server set this for all clients, so we can check only the first one. - // Ideally, we may have a field in response to indicate that query count should be shown, but - // it would break earlier version of ctrld, which only look list of clients in response. - withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount - data := make([][]string, len(clients)) - for i, c := range clients { - row := []string{ - c.IP.String(), - c.Hostname, - c.Mac, - strings.Join(map2Slice(c.Source), ","), - } - if withQueryCount { - row = append(row, strconv.FormatInt(c.QueryCount, 10)) - } - data[i] = row - } - table := tablewriter.NewWriter(os.Stdout) - headers := []string{"IP", "Hostname", "Mac", "Discovered"} - if withQueryCount { - headers = append(headers, "Queries") - } - table.SetHeader(headers) - table.SetAutoFormatHeaders(false) - table.AppendBulk(data) - table.Render() - }, - } - clientsCmd := &cobra.Command{ - Use: "clients", - Short: "Manage clients", - Args: cobra.OnlyValidArgs, - ValidArgs: []string{ - listClientsCmd.Use, - }, - } - clientsCmd.AddCommand(listClientsCmd) - rootCmd.AddCommand(clientsCmd) - - const ( - upgradeChannelDev = "dev" - upgradeChannelProd = "prod" - upgradeChannelDefault = "default" - ) - upgradeChannel := map[string]string{ - upgradeChannelDefault: "https://dl.controld.dev", - upgradeChannelDev: "https://dl.controld.dev", - upgradeChannelProd: "https://dl.controld.com", - } - if isStableVersion(curVersion()) { - upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd] - } - upgradeCmd := &cobra.Command{ - Use: "upgrade", - Short: "Upgrading ctrld to latest version", - ValidArgs: []string{upgradeChannelDev, upgradeChannelProd}, - Args: cobra.MaximumNArgs(1), - PreRun: func(cmd *cobra.Command, args []string) { - checkHasElevatedPrivilege() - }, - Run: func(cmd *cobra.Command, args []string) { - bin, err := os.Executable() - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path") - } - sc := &service.Config{} - *sc = *svcConfig - sc.Executable = bin - readConfig(false) - v.Unmarshal(&cfg) - p := &prog{router: router.New(&cfg, runInCdMode())} - s, err := newService(p, sc) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - return - } - - svcInstalled := true - if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { - svcInstalled = false - } - oldBin := bin + oldBinSuffix - baseUrl := upgradeChannel[upgradeChannelDefault] - if len(args) > 0 { - channel := args[0] - switch channel { - case upgradeChannelProd, upgradeChannelDev: // ok - default: - mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev) - } - baseUrl = upgradeChannel[channel] - } - dlUrl := upgradeUrl(baseUrl) - mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) - resp, err := http.Get(dlUrl) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to download binary") - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - mainLog.Load().Fatal().Msgf("could not download binary: %s", http.StatusText(resp.StatusCode)) - } - mainLog.Load().Debug().Msg("Updating current binary") - if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil { - if rerr := selfupdate.RollbackError(err); rerr != nil { - mainLog.Load().Error().Err(rerr).Msg("could not rollback old binary") - } - mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") - } - - doRestart := func() bool { - if !svcInstalled { - return true - } - tasks := []task{ - {s.Stop, false}, - {s.Start, false}, - } - if doTasks(tasks) { - if dir, err := socketDir(); err == nil { - if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { - _, _ = cc.post(ifacePath, nil) - return true - } - } - } - return false - } - if svcInstalled { - mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") - } - if doRestart() { - _ = os.Remove(oldBin) - _ = os.Chmod(bin, 0755) - ver := "unknown version" - out, err := exec.Command(bin, "--version").CombinedOutput() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version") - } - if after, found := strings.CutPrefix(string(out), "ctrld version "); found { - ver = after - } - mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver) - return - } - - mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin) - if err := os.Remove(bin); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to remove new binary") - } - if err := os.Rename(oldBin, bin); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to restore old binary") - } - if doRestart() { - mainLog.Load().Notice().Msg("Restored previous binary successfully") - return - } - }, - } - rootCmd.AddCommand(upgradeCmd) - + initRunCmd() + startCmd := initStartCmd() + stopCmd := initStopCmd() + restartCmd := initRestartCmd() + reloadCmd := initReloadCmd(restartCmd) + statusCmd := initStatusCmd() + uninstallCmd := initUninstallCmd() + interfacesCmd := initInterfacesCmd() + initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd) + initClientsCmd() + initUpgradeCmd() initLogCmd() } diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index a98fbb4..f5fbd5b 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -1,16 +1,37 @@ package cli import ( + "bytes" + "context" "encoding/json" + "errors" "fmt" "io" + "net" "net/http" + "net/netip" + "os" + "os/exec" "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "time" + "github.com/kardianos/service" + "github.com/minio/selfupdate" + "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" + "github.com/spf13/pflag" + "tailscale.com/net/netmon" + + "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/clientinfo" + "github.com/Control-D-Inc/ctrld/internal/router" ) -func initLogCmd() { +func initLogCmd() *cobra.Command { logSendCmd := &cobra.Command{ Use: "send", Short: "Send runtime debug logs to ControlD", @@ -85,4 +106,1014 @@ func initLogCmd() { logCmd.AddCommand(logSendCmd) logCmd.AddCommand(logViewCmd) rootCmd.AddCommand(logCmd) + + return logCmd +} + +func initRunCmd() *cobra.Command { + runCmd := &cobra.Command{ + Use: "run", + Short: "Run the DNS proxy server", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + RunCobraCommand(cmd) + }, + } + runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon") + runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") + runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") + runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") + runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") + runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") + runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") + runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") + runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") + runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") + runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") + runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") + runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = runCmd.Flags().MarkHidden("dev") + runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") + _ = runCmd.Flags().MarkHidden("homedir") + runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + _ = runCmd.Flags().MarkHidden("iface") + runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) + + runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true} + rootCmd.AddCommand(runCmd) + + return runCmd +} + +func initStartCmd() *cobra.Command { + startCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "start", + Short: "Install and start the ctrld service", + Long: `Install and start the ctrld service + +NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + checkStrFlagEmpty(cmd, cdUidFlagName) + checkStrFlagEmpty(cmd, cdOrgFlagName) + validateCdAndNextDNSFlags() + sc := &service.Config{} + *sc = *svcConfig + osArgs := os.Args[2:] + if os.Args[1] == "service" { + osArgs = os.Args[3:] + } + setDependencies(sc) + sc.Arguments = append([]string{"run"}, osArgs...) + + p := &prog{ + router: router.New(&cfg, cdUID != ""), + cfg: &cfg, + } + s, err := newService(p, sc) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + + status, err := s.Status() + isCtrldRunning := status == service.StatusRunning + isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) + + // Get current running iface, if any. + var currentIface string + + // If pin code was set, do not allow running start command. + if isCtrldRunning { + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + os.Exit(deactivationPinInvalidExitCode) + } + currentIface = runningIface(s) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + reportSetDnsOk := func(sockDir string) { + if cc := newSocketControlClient(ctx, s, sockDir); cc != nil { + if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK { + if iface == "auto" { + iface = defaultIfaceName() + } + logger := mainLog.Load().With().Str("iface", iface).Logger() + logger.Debug().Msg("setting DNS successfully") + } + } + } + + // No config path, generating config in HOME directory. + noConfigStart := isNoConfigStart(cmd) + writeDefaultConfig := !noConfigStart && configBase64 == "" + + logServerStarted := make(chan struct{}) + // A buffer channel to gather log output from runCmd and report + // to user in case self-check process failed. + runCmdLogCh := make(chan string, 256) + ud, err := userHomeDir() + sockDir := ud + if err != nil { + mainLog.Load().Warn().Msg("log server did not start") + close(logServerStarted) + } else { + setWorkingDirectory(sc, ud) + if configPath == "" && writeDefaultConfig { + defaultConfigFile = filepath.Join(ud, defaultConfigFile) + } + sc.Arguments = append(sc.Arguments, "--homedir="+ud) + if d, err := socketDir(); err == nil { + sockDir = d + } + sockPath := filepath.Join(sockDir, ctrldLogUnixSock) + _ = os.Remove(sockPath) + go func() { + defer func() { + close(runCmdLogCh) + _ = os.Remove(sockPath) + }() + close(logServerStarted) + if conn := runLogServer(sockPath); conn != nil { + // Enough buffer for log message, we don't produce + // such long log message, but just in case. + buf := make([]byte, 1024) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + msg := string(buf[:n]) + if _, _, found := strings.Cut(msg, msgExit); found { + cancel() + } + runCmdLogCh <- msg + } + } + }() + } + <-logServerStarted + + if !startOnly { + startOnly = len(osArgs) == 0 + } + // If user run "ctrld start" and ctrld is already installed, starting existing service. + if startOnly && isCtrldInstalled { + tryReadingConfigWithNotice(false, true) + if err := v.Unmarshal(&cfg); err != nil { + mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + } + + initLogging() + tasks := []task{ + {s.Stop, false}, + resetDnsTask(p, s, isCtrldInstalled, currentIface), + {func() error { + // Save current DNS so we can restore later. + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { + return err + } + return nil + }) + return nil + }, false}, + {s.Start, true}, + {noticeWritingControlDConfig, false}, + } + mainLog.Load().Notice().Msg("Starting existing ctrld service") + if doTasks(tasks) { + mainLog.Load().Notice().Msg("Service started") + sockDir, err := socketDir() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Failed to get socket directory") + os.Exit(1) + } + reportSetDnsOk(sockDir) + } else { + mainLog.Load().Error().Err(err).Msg("Failed to start existing ctrld service") + os.Exit(1) + } + return + } + + if cdUID != "" { + doValidateCdRemoteConfig(cdUID) + } else if uid := cdUIDFromProvToken(); uid != "" { + cdUID = uid + mainLog.Load().Debug().Msg("using uid from provision token") + removeOrgFlagsFromArgs(sc) + // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. + sc.Arguments = append(sc.Arguments, "--cd="+cdUID) + } + if cdUID != "" { + validateCdUpstreamProtocol() + } + + if err := p.router.ConfigureService(sc); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to configure service on router") + } + + if configPath != "" { + v.SetConfigFile(configPath) + } + + tryReadingConfigWithNotice(writeDefaultConfig, true) + + if err := v.Unmarshal(&cfg); err != nil { + mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + } + + initLogging() + + if nextdns != "" { + removeNextDNSFromArgs(sc) + } + + // Explicitly passing config, so on system where home directory could not be obtained, + // or sub-process env is different with the parent, we still behave correctly and use + // the expected config file. + if configPath == "" { + sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile) + } + + if router.Name() != "" && iface != "" { + mainLog.Load().Debug().Msg("cleaning up router before installing") + _ = p.router.Cleanup() + } + + tasks := []task{ + {s.Stop, false}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, + {func() error { return ensureUninstall(s) }, false}, + resetDnsTask(p, s, isCtrldInstalled, currentIface), + {func() error { + // Save current DNS so we can restore later. + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil { + return err + } + return nil + }) + return nil + }, false}, + {s.Install, false}, + {s.Start, true}, + // Note that startCmd do not actually write ControlD config, but the config file was + // generated after s.Start, so we notice users here for consistent with nextdns mode. + {noticeWritingControlDConfig, false}, + } + mainLog.Load().Notice().Msg("Starting service") + if doTasks(tasks) { + if err := p.router.Install(sc); err != nil { + mainLog.Load().Warn().Err(err).Msg("post installation failed, please check system/service log for details error") + return + } + + ok, status, err := selfCheckStatus(ctx, s, sockDir) + switch { + case ok && status == service.StatusRunning: + mainLog.Load().Notice().Msg("Service started") + default: + marker := bytes.Repeat([]byte("="), 32) + // If ctrld service is not running, emitting log obtained from ctrld process. + if status != service.StatusRunning || ctx.Err() != nil { + mainLog.Load().Error().Msg("ctrld service may not have started due to an error or misconfiguration, service log:") + _, _ = mainLog.Load().Write(marker) + haveLog := false + for msg := range runCmdLogCh { + _, _ = mainLog.Load().Write([]byte(strings.ReplaceAll(msg, msgExit, ""))) + haveLog = true + } + // If we're unable to get log from "ctrld run", notice users about it. + if !haveLog { + mainLog.Load().Write([]byte(`"`)) + } + } + // Report any error if occurred. + if err != nil { + _, _ = mainLog.Load().Write(marker) + msg := fmt.Sprintf("An error occurred while performing test query: %s", err) + mainLog.Load().Write([]byte(msg)) + } + // If ctrld service is running but selfCheckStatus failed, it could be related + // to user's system firewall configuration, notice users about it. + if status == service.StatusRunning && err == nil { + _, _ = mainLog.Load().Write(marker) + mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`)) + mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) + } + + _, _ = mainLog.Load().Write(marker) + uninstall(p, s) + os.Exit(1) + } + reportSetDnsOk(sockDir) + } + }, + } + // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". + startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") + startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") + startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") + startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint") + startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint") + startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") + startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") + startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") + startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") + startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") + startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") + startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = startCmd.Flags().MarkHidden("dev") + startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") + startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) + startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`) + startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service") + _ = startCmd.Flags().MarkHidden("start_only") + + routerCmd := &cobra.Command{ + Use: "setup", + Run: func(cmd *cobra.Command, _ []string) { + exe, err := os.Executable() + if err != nil { + mainLog.Load().Fatal().Msgf("could not find executable path: %v", err) + os.Exit(1) + } + flags := make([]string, 0) + cmd.Flags().Visit(func(flag *pflag.Flag) { + flags = append(flags, fmt.Sprintf("--%s=%s", flag.Name, flag.Value)) + }) + cmdArgs := []string{"start"} + cmdArgs = append(cmdArgs, flags...) + command := exec.Command(exe, cmdArgs...) + command.Stdout = os.Stdout + command.Stderr = os.Stderr + command.Stdin = os.Stdin + if err := command.Run(); err != nil { + mainLog.Load().Fatal().Msg(err.Error()) + } + }, + } + routerCmd.Flags().AddFlagSet(startCmd.Flags()) + routerCmd.Hidden = true + rootCmd.AddCommand(routerCmd) + + startCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "start", + Short: "Quick start service and configure DNS on interface", + Long: `Quick start service and configure DNS on interface + +NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`, + Run: func(cmd *cobra.Command, args []string) { + if len(os.Args) == 2 { + startOnly = true + } + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = ifaceStartStop + startCmd.Run(cmd, args) + }, + } + startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) + rootCmd.AddCommand(startCmdAlias) + + return startCmd +} + +func initStopCmd() *cobra.Command { + stopCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "stop", + Short: "Stop the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + initLogging() + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + os.Exit(deactivationPinInvalidExitCode) + } + if doTasks([]task{{s.Stop, true}}) { + p.router.Cleanup() + p.resetDNS() + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + return + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } + mainLog.Load().Notice().Msg("Service stopped") + } + }, + } + stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`) + stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`) + _ = stopCmd.Flags().MarkHidden("pin") + + stopCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "stop", + Short: "Quick stop service and remove DNS from interface", + Run: func(cmd *cobra.Command, args []string) { + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = ifaceStartStop + stopCmd.Run(cmd, args) + }, + } + stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) + stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) + rootCmd.AddCommand(stopCmdAlias) + + return stopCmd +} + +func initRestartCmd() *cobra.Command { + restartCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "restart", + Short: "Restart the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + readConfig(false) + v.Unmarshal(&cfg) + cdUID = curCdUID() + cdMode := cdUID != "" + + p := &prog{router: router.New(&cfg, cdMode)} + s, err := newService(p, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + initLogging() + + if cdMode { + doValidateCdRemoteConfig(cdUID) + } + + iface = runningIface(s) + tasks := []task{ + {s.Stop, false}, + {s.Start, true}, + } + if doTasks(tasks) { + dir, err := socketDir() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") + return + } + cc := newSocketControlClient(context.TODO(), s, dir) + if cc == nil { + mainLog.Load().Notice().Msg("Service was not restarted") + os.Exit(1) + } + _, _ = cc.post(ifacePath, nil) + mainLog.Load().Notice().Msg("Service restarted") + } + }, + } + + restartCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "restart", + Short: "Restart the ctrld service", + Run: func(cmd *cobra.Command, args []string) { + restartCmd.Run(cmd, args) + }, + } + rootCmd.AddCommand(restartCmdAlias) + + return restartCmd +} + +func initReloadCmd(restartCmd *cobra.Command) *cobra.Command { + reloadCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + dir, err := socketDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(reloadPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld") + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + mainLog.Load().Notice().Msg("Service reloaded") + case http.StatusCreated: + s, err := newService(&prog{}, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.") + mainLog.Load().Warn().Msg("Restarting service") + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("Service not installed") + return + } + restartCmd.Run(cmd, args) + default: + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("could not read response from control server") + } + mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf)) + } + }, + } + + reloadCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Run: func(cmd *cobra.Command, args []string) { + reloadCmd.Run(cmd, args) + }, + } + rootCmd.AddCommand(reloadCmdAlias) + + return reloadCmd +} + +func initStatusCmd() *cobra.Command { + statusCmd := &cobra.Command{ + Use: "status", + Short: "Show status of the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + s, err := newService(&prog{}, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + status, err := s.Status() + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + os.Exit(1) + } + switch status { + case service.StatusUnknown: + mainLog.Load().Notice().Msg("Unknown status") + os.Exit(2) + case service.StatusRunning: + mainLog.Load().Notice().Msg("Service is running") + os.Exit(0) + case service.StatusStopped: + mainLog.Load().Notice().Msg("Service is stopped") + os.Exit(1) + } + }, + } + if runtime.GOOS == "darwin" { + // On darwin, running status command without privileges may return wrong information. + statusCmd.PreRun = func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + } + } + + statusCmdAlias := &cobra.Command{ + Use: "status", + Short: "Show status of the ctrld service", + Args: cobra.NoArgs, + Run: statusCmd.Run, + } + rootCmd.AddCommand(statusCmdAlias) + + return statusCmd +} + +func initUninstallCmd() *cobra.Command { + uninstallCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "uninstall", + Short: "Stop and uninstall the ctrld service", + Long: `Stop and uninstall the ctrld service. + +NOTE: Uninstalling will set DNS to values provided by DHCP.`, + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + if iface == "" { + iface = "auto" + } + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { + os.Exit(deactivationPinInvalidExitCode) + } + uninstall(p, s) + if cleanup { + var files []string + // Config file. + files = append(files, v.ConfigFileUsed()) + // Log file and backup log file. + // For safety, only process if log file path is absolute. + if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) { + files = append(files, logFile) + oldLogFile := logFile + oldLogSuffix + if _, err := os.Stat(oldLogFile); err == nil { + files = append(files, oldLogFile) + } + } + // Socket files. + if dir, _ := socketDir(); dir != "" { + files = append(files, filepath.Join(dir, ctrldControlUnixSock)) + files = append(files, filepath.Join(dir, ctrldLogUnixSock)) + } + // Static DNS settings files. + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + file := savedStaticDnsSettingsFilePath(i) + if _, err := os.Stat(file); err == nil { + files = append(files, file) + } + return nil + }) + // Windows forwarders file. + if hasLocalDnsServerRunning() { + files = append(files, absHomeDir(windowsForwardersFilename)) + } + // Binary itself. + bin, _ := os.Executable() + if bin != "" && supportedSelfDelete { + files = append(files, bin) + } + // Backup file after upgrading. + oldBin := bin + oldBinSuffix + if _, err := os.Stat(oldBin); err == nil { + files = append(files, oldBin) + } + for _, file := range files { + if file == "" { + continue + } + if err := os.Remove(file); err != nil { + if os.IsNotExist(err) { + continue + } + mainLog.Load().Warn().Err(err).Msg("failed to remove file") + } else { + mainLog.Load().Debug().Msgf("file removed: %s", file) + } + } + if err := selfDeleteExe(); err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to remove file") + } else { + if !supportedSelfDelete { + mainLog.Load().Debug().Msgf("file removed: %s", bin) + } + } + } + }, + } + uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`) + uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for uninstalling ctrld`) + _ = uninstallCmd.Flags().MarkHidden("pin") + uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`) + + uninstallCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Use: "uninstall", + Short: "Stop and uninstall the ctrld service", + Long: `Stop and uninstall the ctrld service. + +NOTE: Uninstalling will set DNS to values provided by DHCP.`, + Run: func(cmd *cobra.Command, args []string) { + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } + iface = ifaceStartStop + uninstallCmd.Run(cmd, args) + }, + } + uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) + uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags()) + rootCmd.AddCommand(uninstallCmdAlias) + + return uninstallCmd +} + +func initInterfacesCmd() *cobra.Command { + listIfacesCmd := &cobra.Command{ + Use: "list", + Short: "List network interfaces of the host", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + err := netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { + fmt.Printf("Index : %d\n", i.Index) + fmt.Printf("Name : %s\n", i.Name) + addrs, _ := i.Addrs() + for i, ipaddr := range addrs { + if i == 0 { + fmt.Printf("Addrs : %v\n", ipaddr) + continue + } + fmt.Printf(" %v\n", ipaddr) + } + for i, dns := range currentDNS(i.Interface) { + if i == 0 { + fmt.Printf("DNS : %s\n", dns) + continue + } + fmt.Printf(" : %s\n", dns) + } + println() + }) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + } + }, + } + interfacesCmd := &cobra.Command{ + Use: "interfaces", + Short: "Manage network interfaces", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + listIfacesCmd.Use, + }, + } + interfacesCmd.AddCommand(listIfacesCmd) + + return interfacesCmd +} + +func initClientsCmd() *cobra.Command { + listClientsCmd := &cobra.Command{ + Use: "list", + Short: "List clients that ctrld discovered", + Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Run: func(cmd *cobra.Command, args []string) { + dir, err := socketDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(listClientsPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to get clients list") + } + defer resp.Body.Close() + + var clients []*clientinfo.Client + if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to decode clients list result") + } + map2Slice := func(m map[string]struct{}) []string { + s := make([]string, 0, len(m)) + for k := range m { + if k == "" { // skip empty source from output. + continue + } + s = append(s, k) + } + sort.Strings(s) + return s + } + // If metrics is enabled, server set this for all clients, so we can check only the first one. + // Ideally, we may have a field in response to indicate that query count should be shown, but + // it would break earlier version of ctrld, which only look list of clients in response. + withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount + data := make([][]string, len(clients)) + for i, c := range clients { + row := []string{ + c.IP.String(), + c.Hostname, + c.Mac, + strings.Join(map2Slice(c.Source), ","), + } + if withQueryCount { + row = append(row, strconv.FormatInt(c.QueryCount, 10)) + } + data[i] = row + } + table := tablewriter.NewWriter(os.Stdout) + headers := []string{"IP", "Hostname", "Mac", "Discovered"} + if withQueryCount { + headers = append(headers, "Queries") + } + table.SetHeader(headers) + table.SetAutoFormatHeaders(false) + table.AppendBulk(data) + table.Render() + }, + } + clientsCmd := &cobra.Command{ + Use: "clients", + Short: "Manage clients", + Args: cobra.OnlyValidArgs, + ValidArgs: []string{ + listClientsCmd.Use, + }, + } + clientsCmd.AddCommand(listClientsCmd) + rootCmd.AddCommand(clientsCmd) + + return clientsCmd +} + +func initUpgradeCmd() *cobra.Command { + const ( + upgradeChannelDev = "dev" + upgradeChannelProd = "prod" + upgradeChannelDefault = "default" + ) + upgradeChannel := map[string]string{ + upgradeChannelDefault: "https://dl.controld.dev", + upgradeChannelDev: "https://dl.controld.dev", + upgradeChannelProd: "https://dl.controld.com", + } + if isStableVersion(curVersion()) { + upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd] + } + upgradeCmd := &cobra.Command{ + Use: "upgrade", + Short: "Upgrading ctrld to latest version", + ValidArgs: []string{upgradeChannelDev, upgradeChannelProd}, + Args: cobra.MaximumNArgs(1), + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, + Run: func(cmd *cobra.Command, args []string) { + bin, err := os.Executable() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path") + } + sc := &service.Config{} + *sc = *svcConfig + sc.Executable = bin + readConfig(false) + v.Unmarshal(&cfg) + p := &prog{router: router.New(&cfg, runInCdMode())} + s, err := newService(p, sc) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + + svcInstalled := true + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + svcInstalled = false + } + oldBin := bin + oldBinSuffix + baseUrl := upgradeChannel[upgradeChannelDefault] + if len(args) > 0 { + channel := args[0] + switch channel { + case upgradeChannelProd, upgradeChannelDev: // ok + default: + mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev) + } + baseUrl = upgradeChannel[channel] + } + dlUrl := upgradeUrl(baseUrl) + mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl) + resp, err := http.Get(dlUrl) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to download binary") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + mainLog.Load().Fatal().Msgf("could not download binary: %s", http.StatusText(resp.StatusCode)) + } + mainLog.Load().Debug().Msg("Updating current binary") + if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil { + if rerr := selfupdate.RollbackError(err); rerr != nil { + mainLog.Load().Error().Err(rerr).Msg("could not rollback old binary") + } + mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") + } + + doRestart := func() bool { + if !svcInstalled { + return true + } + tasks := []task{ + {s.Stop, false}, + {s.Start, false}, + } + if doTasks(tasks) { + if dir, err := socketDir(); err == nil { + if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { + _, _ = cc.post(ifacePath, nil) + return true + } + } + } + return false + } + if svcInstalled { + mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") + } + if doRestart() { + _ = os.Remove(oldBin) + _ = os.Chmod(bin, 0755) + ver := "unknown version" + out, err := exec.Command(bin, "--version").CombinedOutput() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version") + } + if after, found := strings.CutPrefix(string(out), "ctrld version "); found { + ver = after + } + mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver) + return + } + + mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin) + if err := os.Remove(bin); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to remove new binary") + } + if err := os.Rename(oldBin, bin); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to restore old binary") + } + if doRestart() { + mainLog.Load().Notice().Msg("Restored previous binary successfully") + return + } + }, + } + rootCmd.AddCommand(upgradeCmd) + + return upgradeCmd +} + +func initServicesCmd(commands ...*cobra.Command) *cobra.Command { + serviceCmd := &cobra.Command{ + Use: "service", + Short: "Manage ctrld service", + Args: cobra.OnlyValidArgs, + } + serviceCmd.ValidArgs = make([]string, len(commands)) + for i, cmd := range commands { + serviceCmd.ValidArgs[i] = cmd.Use + serviceCmd.AddCommand(cmd) + } + rootCmd.AddCommand(serviceCmd) + + return serviceCmd } From 02ee113b9500c408fcc031cff9a870c2a749a499 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 12 Dec 2024 18:47:22 +0700 Subject: [PATCH 016/100] Add missing kea dhcp4 format when validating config Thanks Discord user cosmoxl for reporting this. --- config.go | 2 +- config_test.go | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 3f9b2f8..4302c5d 100644 --- a/config.go +++ b/config.go @@ -205,7 +205,7 @@ type ServiceConfig struct { CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"` MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` - DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` + DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp kea-dhcp4"` DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"` DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` diff --git a/config_test.go b/config_test.go index a20b33c..cd392d5 100644 --- a/config_test.go +++ b/config_test.go @@ -111,6 +111,7 @@ func TestConfigValidation(t *testing.T) { {"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false}, {"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false}, {"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true}, + {"kea dhcp4 format", configWithDhcp4KeaFormat(t), false}, } for _, tc := range tests { @@ -307,6 +308,12 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config { return cfg } +func configWithDhcp4KeaFormat(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Service.DHCPLeaseFileFormat = "kea-dhcp4" + return cfg +} + func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config { cfg := defaultConfig(t) cfg.Upstream["0"].Endpoint = "/1.1.1.1" From 8a96b8bec485dfe2e4a7a5afd49894040c6cc69f Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 12 Dec 2024 15:47:18 +0700 Subject: [PATCH 017/100] cmd/cli: adopt FilteredLevelWriter when doing internal logging Without verbose log, we use internal log writer with log level set to debug. However, this will affect other writers, like console log, since they are default to notice level. By adopting FilteredLevelWriter, we can make internal log writer run in debug level, but all others will run in default level instead. --- cmd/cli/cli.go | 4 ++-- cmd/cli/log_writer.go | 22 +++++++++++++++++----- cmd/cli/main.go | 15 ++++++++------- cmd/cli/prog.go | 4 ++-- go.mod | 2 +- go.sum | 12 +++++------- 6 files changed, 35 insertions(+), 24 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 21d2873..74919d9 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -266,10 +266,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // Log config do not have thing to validate, so it's safe to init log here, // so it's able to log information in processCDFlags. - initLogging() + logWriters := initLogging() // Initializing internal logging after global logging. - p.initInternalLogging() + p.initInternalLogging(logWriters) mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) mainLog.Load().Info().Msgf("os: %s", osVersion()) diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index f84b231..32cf196 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -3,6 +3,7 @@ package cli import ( "bytes" "errors" + "io" "os" "sync" "time" @@ -59,7 +60,7 @@ func (lw *logWriter) Write(p []byte) (int, error) { } // initInternalLogging performs internal logging if there's no log enabled. -func (p *prog) initInternalLogging() { +func (p *prog) initInternalLogging(writers []io.Writer) { if !p.needInternalLogging() { return } @@ -72,13 +73,24 @@ func (p *prog) initInternalLogging() { p.mu.Lock() lw := p.internalLogWriter p.mu.Unlock() - multi := zerolog.MultiLevelWriter(lw) + // If ctrld was run without explicit verbose level, + // run the internal logging at debug level, so we could + // have enough information for troubleshooting. + if verbose == 0 { + for i := range writers { + w := &zerolog.FilteredLevelWriter{ + Writer: zerolog.LevelWriterAdapter{Writer: writers[i]}, + Level: zerolog.NoticeLevel, + } + writers[i] = w + } + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } + writers = append(writers, lw) + multi := zerolog.MultiLevelWriter(writers...) l := mainLog.Load().Output(multi).With().Logger() mainLog.Store(&l) ctrld.ProxyLogger.Store(&l) - if verbose == 0 { - zerolog.SetGlobalLevel(zerolog.DebugLevel) - } } // needInternalLogging reports whether prog needs to run internal logging. diff --git a/cmd/cli/main.go b/cmd/cli/main.go index bafcde1..53662aa 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -101,9 +101,9 @@ func initConsoleLogging() { } // initLogging initializes global logging setup. -func initLogging() { +func initLogging() []io.Writer { zerolog.TimeFieldFormat = time.RFC3339 + ".000" - initLoggingWithBackup(true) + return initLoggingWithBackup(true) } // initLoggingWithBackup initializes log setup base on current config. @@ -112,8 +112,8 @@ func initLogging() { // This is only used in runCmd for special handling in case of logging config // change in cd mode. Without special reason, the caller should use initLogging // wrapper instead of calling this function directly. -func initLoggingWithBackup(doBackup bool) { - writers := []io.Writer{io.Discard} +func initLoggingWithBackup(doBackup bool) []io.Writer { + var writers []io.Writer if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" { // Create parent directory if necessary. if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil { @@ -151,21 +151,22 @@ func initLoggingWithBackup(doBackup bool) { switch { case silent: zerolog.SetGlobalLevel(zerolog.NoLevel) - return + return writers case verbose == 1: logLevel = "info" case verbose > 1: logLevel = "debug" } if logLevel == "" { - return + return writers } level, err := zerolog.ParseLevel(logLevel) if err != nil { mainLog.Load().Warn().Err(err).Msg("could not set log level") - return + return writers } zerolog.SetGlobalLevel(level) + return writers } func initCache() { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index c6146ea..046c0c9 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -514,13 +514,13 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if !reload { // Stop writing log to unix socket. consoleWriter.Out = os.Stdout - initLoggingWithBackup(false) + logWriters := initLoggingWithBackup(false) if p.logConn != nil { _ = p.logConn.Close() } go p.apiConfigReload() p.postRun() - p.initInternalLogging() + p.initInternalLogging(logWriters) } wg.Wait() } diff --git a/go.mod b/go.mod index 1f797e8..630ce44 100644 --- a/go.mod +++ b/go.mod @@ -102,4 +102,4 @@ require ( replace github.com/mr-karan/doggo => github.com/Windscribe/doggo v0.0.0-20220919152748-2c118fc391f8 -replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be +replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c diff --git a/go.sum b/go.sum index 5e073b9..ead3c24 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= -github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik= -github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE= +github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo= @@ -209,11 +209,10 @@ github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= @@ -279,7 +278,7 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM= github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ= @@ -477,15 +476,14 @@ golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= From 37d41bd2150c63974e63e7f1e5028c8395808743 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 12 Dec 2024 18:36:39 +0700 Subject: [PATCH 018/100] Skip public DNS for LAN query So we don't blindly send requests to public DNS even though they can not handle these queries. --- cmd/cli/dns_proxy.go | 3 +++ resolver.go | 17 ++++++++++++++++- resolver_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index a69f5b5..031b362 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -448,6 +448,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { case isSrvLookup(req.msg): upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + ctx = ctrld.LanQueryCtx(ctx) ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams) case isPrivatePtrLookup(req.msg): isLanOrPtrQuery = true @@ -457,6 +458,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return res } upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs) + ctx = ctrld.LanQueryCtx(ctx) ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams) case isLanHostnameQuery(req.msg): isLanOrPtrQuery = true @@ -467,6 +469,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } upstreams = []string{upstreamOS} upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + ctx = ctrld.LanQueryCtx(ctx) ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams) default: ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams) diff --git a/resolver.go b/resolver.go index f3b7a10..e3d319b 100644 --- a/resolver.go +++ b/resolver.go @@ -47,6 +47,14 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") // or is the Resolver used for ResolverTypeOS. var or = newResolverWithNameserver(defaultNameservers()) +// LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network. +type LanQueryCtxKey struct{} + +// LanQueryCtx returns a context.Context with LanQueryCtxKey set. +func LanQueryCtx(ctx context.Context) context.Context { + return context.WithValue(ctx, LanQueryCtxKey{}, true) +} + // defaultNameservers is like nameservers with each element formed "ip:53". func defaultNameservers() []string { ns := nameservers() @@ -191,6 +199,11 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error nss = append(nss, (*p)...) } numServers := len(nss) + len(publicServers) + // If this is a LAN query, skip public DNS. + lan, ok := ctx.Value(LanQueryCtxKey{}).(bool) + if ok && lan { + numServers -= len(publicServers) + } if numServers == 0 { return nil, errors.New("no nameservers available") } @@ -216,7 +229,9 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } } do(nss, true) - do(publicServers, false) + if !lan { + do(publicServers, false) + } logAnswer := func(server string) { if before, _, found := strings.Cut(server, ":"); found { diff --git a/resolver_test.go b/resolver_test.go index e0b5508..5fb8434 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -34,6 +34,44 @@ func Test_osResolver_Resolve(t *testing.T) { } } +func Test_osResolver_ResolveLanHostname(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + reqId := "req-id" + ctx = context.WithValue(ctx, ReqIdCtxKey{}, reqId) + ctx = LanQueryCtx(ctx) + + go func(ctx context.Context) { + defer cancel() + id, ok := ctx.Value(ReqIdCtxKey{}).(string) + if !ok || id != reqId { + t.Error("missing request id") + return + } + lan, ok := ctx.Value(LanQueryCtxKey{}).(bool) + if !ok || !lan { + t.Error("not a LAN query") + return + } + resolver := &osResolver{} + resolver.publicServers.Store(&[]string{"76.76.2.0:53"}) + m := new(dns.Msg) + m.SetQuestion("controld.com.", dns.TypeA) + m.RecursionDesired = true + _, err := resolver.Resolve(ctx, m) + if err == nil { + t.Error("os resolver succeeded unexpectedly") + return + } + }(ctx) + + select { + case <-time.After(10 * time.Second): + t.Error("os resolver hangs") + case <-ctx.Done(): + } +} + func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { ns := make([]string, 0, 2) servers := make([]*dns.Server, 0, 2) From 221917e80bebc3f565cf09401c7d93f83898fc4d Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 12 Dec 2024 18:51:47 +0700 Subject: [PATCH 019/100] Bump golang.org/x/crypto to v0.31.0 To fix CVE-2024-45337 (even though ctrld do not use SSH) --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 630ce44..e4503d6 100644 --- a/go.mod +++ b/go.mod @@ -36,8 +36,8 @@ require ( github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 golang.org/x/net v0.28.0 - golang.org/x/sync v0.8.0 - golang.org/x/sys v0.23.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.28.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.74.0 ) @@ -90,10 +90,10 @@ require ( go.uber.org/mock v0.4.0 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect - golang.org/x/crypto v0.26.0 // indirect + golang.org/x/crypto v0.31.0 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/mod v0.19.0 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/tools v0.23.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index ead3c24..7fb7e6b 100644 --- a/go.sum +++ b/go.sum @@ -342,8 +342,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -434,8 +434,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -484,8 +484,8 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= -golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -496,8 +496,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From 89f7874fc68f39f9d1b6abde333baedbdfb6c931 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 13 Dec 2024 14:25:30 +0700 Subject: [PATCH 020/100] cmd/cli: normalize log path when sending log So the correct log file that "ctrld run" process is writing logs to will be sent to server correctly. --- cmd/cli/log_writer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 32cf196..2e73391 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -122,7 +122,7 @@ func (p *prog) logContent() ([]byte, error) { if p.cfg.Service.LogPath == "" { return nil, nil } - buf, err := os.ReadFile(p.cfg.Service.LogPath) + buf, err := os.ReadFile(normalizeLogFilePath(p.cfg.Service.LogPath)) if err != nil { return nil, err } From cb49d0d9477e4f81da0c6863efb1e9e68808bf2c Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 13 Dec 2024 14:41:46 +0700 Subject: [PATCH 021/100] cmd/cli: perform leaking queries in non-cd mode --- cmd/cli/dns_proxy.go | 2 +- cmd/cli/dns_proxy_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 031b362..8b198a7 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -595,7 +595,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return res } ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) - if cdUID != "" && p.leakOnUpstreamFailure() { + if p.leakOnUpstreamFailure() { p.leakingQueryMu.Lock() if !p.leakingQueryWasRun { p.leakingQueryWasRun = true diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 9deb9ed..eae3dfa 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -75,6 +75,7 @@ func Test_canonicalName(t *testing.T) { func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) + cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false) p := &prog{cfg: cfg} p.um = newUpstreamMonitor(p.cfg) p.lanLoopGuard = newLoopGuard() From 4f623146464995db86b9aa58c6fcec3fd78f64e0 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 13 Dec 2024 18:36:45 +0700 Subject: [PATCH 022/100] cmd/cli: do API reloading if exlcude list changed --- cmd/cli/cli.go | 16 ++++++++++------ cmd/cli/prog.go | 30 +++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 74919d9..f0e927d 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -301,7 +301,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if cdUID != "" { validateCdUpstreamProtocol() - if err := processCDFlags(&cfg); err != nil { + if rc, err := processCDFlags(&cfg); err != nil { if isMobile() { appCallback.Exit(err.Error()) return @@ -315,6 +315,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } notifyExitToLogServer() cdLogger.Fatal().Err(err).Msg("failed to fetch resolver config") + } else { + p.mu.Lock() + p.rc = rc + p.mu.Unlock() } } @@ -604,7 +608,7 @@ func deactivationPinNotSet() bool { return cdDeactivationPin.Load() == defaultDeactivationPin } -func processCDFlags(cfg *ctrld.Config) error { +func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { logger := mainLog.Load().With().Str("mode", "cd").Logger() logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) @@ -622,10 +626,10 @@ func processCDFlags(cfg *ctrld.Config) error { } if err != nil { if isMobile() { - return err + return nil, err } logger.Warn().Err(err).Msg("could not fetch resolver config") - return err + return nil, err } if resolverConfig.DeactivationPin != nil { @@ -641,7 +645,7 @@ func processCDFlags(cfg *ctrld.Config) error { logger.Info().Msg("using defined custom config of Control-D resolver") if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil { setListenerDefaultValue(cfg) - return nil + return resolverConfig, nil } mainLog.Load().Err(err).Msg("disregarding invalid custom config") } @@ -688,7 +692,7 @@ func processCDFlags(cfg *ctrld.Config) error { // Set default value. setListenerDefaultValue(cfg) - return nil + return resolverConfig, nil } // setListenerDefaultValue sets the default value for cfg.Listener if none existed. diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 046c0c9..6deda4e 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -84,6 +84,7 @@ type prog struct { dnsWg sync.WaitGroup dnsWatcherClosedOnce sync.Once dnsWatcherStopCh chan struct{} + rc *controld.ResolverConfig cfg *ctrld.Config localUpstreams []string @@ -165,11 +166,13 @@ func (p *prog) runWait() { if newCfg == nil { newCfg = &ctrld.Config{} + confFile := v.ConfigFileUsed() v := viper.NewWithOptions(viper.KeyDelimiter("::")) ctrld.InitConfig(v, "ctrld") if configPath != "" { - v.SetConfigFile(configPath) + confFile = configPath } + v.SetConfigFile(confFile) if err := v.ReadInConfig(); err != nil { logger.Err(err).Msg("could not read new config") waitOldRunDone() @@ -181,10 +184,14 @@ func (p *prog) runWait() { continue } if cdUID != "" { - if err := processCDFlags(newCfg); err != nil { + if rc, err := processCDFlags(newCfg); err != nil { logger.Err(err).Msg("could not fetch ControlD config") waitOldRunDone() continue + } else { + p.mu.Lock() + p.rc = rc + p.mu.Unlock() } } } @@ -291,7 +298,24 @@ func (p *prog) apiConfigReload() { cdDeactivationPin.Store(defaultDeactivationPin) } - if resolverConfig.Ctrld.CustomConfig == "" { + p.mu.Lock() + rc := p.rc + p.rc = resolverConfig + p.mu.Unlock() + noCustomConfig := resolverConfig.Ctrld.CustomConfig == "" + noExcludeListChanged := true + if rc != nil { + slices.Sort(rc.Exclude) + slices.Sort(resolverConfig.Exclude) + noExcludeListChanged = slices.Equal(rc.Exclude, resolverConfig.Exclude) + } + if noCustomConfig && noExcludeListChanged { + return + } + + if noCustomConfig && !noExcludeListChanged { + logger.Debug().Msg("exclude list changes detected, reloading...") + p.apiReloadCh <- nil return } From 9bbccb40821dfc9b9a6fe8b07ad142b695361989 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 18 Dec 2024 15:46:36 +0700 Subject: [PATCH 023/100] cmd/cli: get default interface once --- cmd/cli/prog.go | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 6deda4e..9ff04bc 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -103,6 +103,8 @@ type prog struct { initInternalLogWriterOnce sync.Once internalLogWriter *logWriter internalLogSent time.Time + runningIface string + requiredMultiNICsConfig bool selfUninstallMu sync.Mutex refusedQueryCount int @@ -243,6 +245,11 @@ func (p *prog) runWait() { } func (p *prog) preRun() { + if iface == "auto" { + iface = defaultIfaceName() + p.requiredMultiNICsConfig = requiredMultiNICsConfig() + } + p.runningIface = iface if runtime.GOOS == "darwin" { p.onStopped = append(p.onStopped, func() { if !service.Interactive() { @@ -607,25 +614,18 @@ func (p *prog) setDNS() { if cfg.Listener == nil { return } - if iface == "" { + if p.runningIface == "" { return } - runningIface := iface + // allIfaces tracks whether we should set DNS for all physical interfaces. - allIfaces := false - if runningIface == "auto" { - runningIface = defaultIfaceName() - // If runningIface is "auto", it means user does not specify "--iface" flag. - // In this case, ctrld has to set DNS for all physical interfaces, so - // thing will still work when user switch from one to the other. - allIfaces = requiredMultiNICsConfig() - } + allIfaces := p.requiredMultiNICsConfig lc := cfg.FirstListener() if lc == nil { return } - logger := mainLog.Load().With().Str("iface", runningIface).Logger() - netIface, err := netInterface(runningIface) + logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + netIface, err := netInterface(p.runningIface) if err != nil { logger.Error().Err(err).Msg("could not get interface") return @@ -754,18 +754,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces } func (p *prog) resetDNS() { - if iface == "" { + if p.runningIface == "" { return } - runningIface := iface - allIfaces := false - if runningIface == "auto" { - runningIface = defaultIfaceName() - // See corresponding comments in (*prog).setDNS function. - allIfaces = requiredMultiNICsConfig() - } - logger := mainLog.Load().With().Str("iface", runningIface).Logger() - netIface, err := netInterface(runningIface) + // See corresponding comments in (*prog).setDNS function. + allIfaces := p.requiredMultiNICsConfig + logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() + netIface, err := netInterface(p.runningIface) if err != nil { logger.Error().Err(err).Msg("could not get interface") return From 4a92ec4d2dc4d15ce70f60a0b7ce5dad31a42a8e Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 19 Dec 2024 22:10:34 +0700 Subject: [PATCH 024/100] cmd/cli: fix race in Test_addSplitDnsRule --- cmd/cli/ad_windows_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/cli/ad_windows_test.go b/cmd/cli/ad_windows_test.go index 6fe7f41..6abd25f 100644 --- a/cmd/cli/ad_windows_test.go +++ b/cmd/cli/ad_windows_test.go @@ -64,7 +64,6 @@ func Test_addSplitDnsRule(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - t.Parallel() added := addSplitDnsRule(tc.cfg, tc.domain) assert.Equal(t, tc.added, added) }) From ff43c74d8d29287945f2edc632f8a5aac18a9c1f Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 19 Dec 2024 21:44:25 +0700 Subject: [PATCH 025/100] Bump golang.org/x/net to v0.33.0 Fix CVE-2024-45338 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index e4503d6..d9e924d 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/net v0.28.0 + golang.org/x/net v0.33.0 golang.org/x/sync v0.10.0 golang.org/x/sys v0.28.0 golang.zx2c4.com/wireguard/windows v0.5.3 diff --git a/go.sum b/go.sum index 7fb7e6b..66c6a73 100644 --- a/go.sum +++ b/go.sum @@ -413,8 +413,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= From 5a566c028a3ddc1c31a975ae14b0432eea470655 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 27 Dec 2024 16:28:56 +0700 Subject: [PATCH 026/100] cmd/cli: better error message when log file is empty While at it, also record the size of logs being sent in debug/error message. --- cmd/cli/commands.go | 31 ++++++++++++++++++++++--------- cmd/cli/control_server.go | 12 +++++++++--- cmd/cli/log_writer.go | 11 +++++++++++ go.mod | 1 + go.sum | 2 ++ 5 files changed, 45 insertions(+), 12 deletions(-) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index f5fbd5b..b174052 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -19,6 +19,7 @@ import ( "strings" "time" + "github.com/docker/go-units" "github.com/kardianos/service" "github.com/minio/selfupdate" "github.com/olekukonko/tablewriter" @@ -48,17 +49,24 @@ func initLogCmd() *cobra.Command { } defer resp.Body.Close() switch resp.StatusCode { - case http.StatusOK: - mainLog.Load().Notice().Msg("runtime logs sent successfully") case http.StatusServiceUnavailable: mainLog.Load().Warn().Msg("runtime logs could only be sent once per minute") - default: - buf, err := io.ReadAll(resp.Body) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to read response body") - } - mainLog.Load().Error().Msg("failed to send logs") - mainLog.Load().Error().Msg(string(buf)) + return + case http.StatusMovedPermanently: + mainLog.Load().Warn().Msg("runtime debugs log is not enabled") + mainLog.Load().Warn().Msg(`ctrld may be run without "--cd" flag or logging is already enabled`) + return + } + var logs logSentResponse + if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to decode sent logs result") + } + size := units.BytesSize(float64(logs.Size)) + if logs.Error == "" { + mainLog.Load().Notice().Msgf("runtime logs sent successfully (%s)", size) + } else { + mainLog.Load().Error().Msgf("failed to send logs (%s)", size) + mainLog.Load().Error().Msg(logs.Error) } }, } @@ -85,6 +93,11 @@ func initLogCmd() *cobra.Command { return case http.StatusBadRequest: mainLog.Load().Warn().Msg("runtime debugs log is not available") + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to read response body") + } + mainLog.Load().Warn().Msgf("ctrld process response:\n\n%s\n", string(buf)) return case http.StatusOK: } diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index b6deed5..7a33407 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -217,7 +217,7 @@ func (p *prog) registerControlServerHandler() { p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { data, err := p.logContent() if err != nil { - w.WriteHeader(http.StatusBadRequest) + http.Error(w, err.Error(), http.StatusBadRequest) return } if len(data) == 0 { @@ -236,7 +236,7 @@ func (p *prog) registerControlServerHandler() { } data, err := p.logContent() if err != nil { - w.WriteHeader(http.StatusBadRequest) + http.Error(w, err.Error(), http.StatusBadRequest) return } if len(data) == 0 { @@ -249,11 +249,17 @@ func (p *prog) registerControlServerHandler() { LogFile: logFile, } mainLog.Load().Debug().Msg("sending log file to ControlD server") + resp := logSentResponse{Size: len(data)} if err := controld.SendLogs(req, cdDev); err != nil { mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + resp.Error = err.Error() + w.WriteHeader(http.StatusInternalServerError) } else { mainLog.Load().Debug().Msg("sending log file successfully") + w.WriteHeader(http.StatusOK) + } + if err := json.NewEncoder(w).Encode(&resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) } p.internalLogSent = time.Now() })) diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 2e73391..7f12032 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -24,6 +24,11 @@ type logViewResponse struct { Data string `json:"data"` } +type logSentResponse struct { + Size int `json:"size"` + Error string `json:"error"` +} + // logWriter is an internal buffer to keep track of runtime log when no logging is enabled. type logWriter struct { mu sync.Mutex @@ -118,6 +123,9 @@ func (p *prog) logContent() ([]byte, error) { lw.mu.Lock() data = lw.buf.Bytes() lw.mu.Unlock() + if len(data) == 0 { + return nil, errors.New("internal log is empty") + } } else { if p.cfg.Service.LogPath == "" { return nil, nil @@ -127,6 +135,9 @@ func (p *prog) logContent() ([]byte, error) { return nil, err } data = buf + if len(data) == 0 { + return nil, errors.New("log file is empty") + } } return data, nil } diff --git a/go.mod b/go.mod index d9e924d..67fe9a2 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/ameshkov/dnsstamps v1.0.3 github.com/coreos/go-systemd/v22 v22.5.0 github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf + github.com/docker/go-units v0.5.0 github.com/frankban/quicktest v1.14.6 github.com/fsnotify/fsnotify v1.7.0 github.com/go-playground/validator/v10 v10.11.1 diff --git a/go.sum b/go.sum index 66c6a73..bcf1ee7 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa h1:h8TfIT1xc8FWbwwpmHn1J5i43Y0uZP97GqasGCzSRJk= github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= From a5c776c84627cc6e7d0cfd633e98b4ad817c0970 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 7 Jan 2025 16:56:09 +0700 Subject: [PATCH 027/100] all: change send log to use x-www-form-urlencoded --- cmd/cli/control_server.go | 24 ++++++++++++------- cmd/cli/log_writer.go | 47 +++++++++++++++++++++++-------------- internal/controld/config.go | 20 ++++++++-------- 3 files changed, 54 insertions(+), 37 deletions(-) diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 7a33407..302b902 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -2,8 +2,9 @@ package cli import ( "context" - "encoding/base64" "encoding/json" + "fmt" + "io" "net" "net/http" "os" @@ -215,17 +216,23 @@ func (p *prog) registerControlServerHandler() { w.WriteHeader(http.StatusBadRequest) })) p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { - data, err := p.logContent() + lr, err := p.logReader() if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - if len(data) == 0 { + if lr.size == 0 { w.WriteHeader(http.StatusMovedPermanently) return } + data, err := io.ReadAll(lr.r) + if err != nil { + http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError) + return + } if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError) return } })) @@ -234,22 +241,21 @@ func (p *prog) registerControlServerHandler() { w.WriteHeader(http.StatusServiceUnavailable) return } - data, err := p.logContent() + r, err := p.logReader() if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - if len(data) == 0 { + if r.size == 0 { w.WriteHeader(http.StatusMovedPermanently) return } - logFile := base64.StdEncoding.EncodeToString(data) req := &controld.LogsRequest{ - UID: cdUID, - LogFile: logFile, + UID: cdUID, + Data: r.r, } mainLog.Load().Debug().Msg("sending log file to ControlD server") - resp := logSentResponse{Size: len(data)} + resp := logSentResponse{Size: r.size} if err := controld.SendLogs(req, cdDev); err != nil { mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err) resp.Error = err.Error() diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 7f12032..a6fb7eb 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -3,6 +3,7 @@ package cli import ( "bytes" "errors" + "fmt" "io" "os" "sync" @@ -25,10 +26,15 @@ type logViewResponse struct { } type logSentResponse struct { - Size int `json:"size"` + Size int64 `json:"size"` Error string `json:"error"` } +type logReader struct { + r io.ReadCloser + size int64 +} + // logWriter is an internal buffer to keep track of runtime log when no logging is enabled. type logWriter struct { mu sync.Mutex @@ -111,8 +117,7 @@ func (p *prog) needInternalLogging() bool { return true } -func (p *prog) logContent() ([]byte, error) { - var data []byte +func (p *prog) logReader() (*logReader, error) { if p.needInternalLogging() { p.mu.Lock() lw := p.internalLogWriter @@ -121,23 +126,29 @@ func (p *prog) logContent() ([]byte, error) { return nil, errors.New("nil internal log writer") } lw.mu.Lock() - data = lw.buf.Bytes() + lr := &logReader{r: io.NopCloser(bytes.NewReader(lw.buf.Bytes()))} + lr.size = int64(lw.buf.Len()) lw.mu.Unlock() - if len(data) == 0 { + if lr.size == 0 { return nil, errors.New("internal log is empty") } - } else { - if p.cfg.Service.LogPath == "" { - return nil, nil - } - buf, err := os.ReadFile(normalizeLogFilePath(p.cfg.Service.LogPath)) - if err != nil { - return nil, err - } - data = buf - if len(data) == 0 { - return nil, errors.New("log file is empty") - } + return lr, nil } - return data, nil + if p.cfg.Service.LogPath == "" { + return nil, nil + } + f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath)) + if err != nil { + return nil, err + } + lr := &logReader{r: f} + if st, err := f.Stat(); err == nil { + lr.size = st.Size() + } else { + return nil, fmt.Errorf("f.Stat: %w", err) + } + if lr.size == 0 { + return nil, errors.New("log file is empty") + } + return lr, nil } diff --git a/internal/controld/config.go b/internal/controld/config.go index 348dc54..b1814d1 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -77,8 +77,8 @@ type UtilityOrgRequest struct { // LogsRequest contains request data for sending runtime logs to API. type LogsRequest struct { - UID string `json:"uid"` - LogFile string `json:"log_file"` + UID string `json:"uid"` + Data io.ReadCloser `json:"-"` } // FetchResolverConfig fetch Control D config for given uid. @@ -160,20 +160,20 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade } // SendLogs sends runtime log to ControlD API. -func SendLogs(req *LogsRequest, cdDev bool) error { - body, _ := json.Marshal(req) - return postLogAPI(cdDev, bytes.NewReader(body)) -} - -func postLogAPI(cdDev bool, body io.Reader) error { +func SendLogs(lr *LogsRequest, cdDev bool) error { + defer lr.Data.Close() apiUrl := logURLCom if cdDev { apiUrl = logURLDev } - req, err := http.NewRequest("POST", apiUrl, body) + req, err := http.NewRequest("POST", apiUrl, lr.Data) if err != nil { return fmt.Errorf("http.NewRequest: %w", err) } + q := req.URL.Query() + q.Set("uid", lr.UID) + req.URL.RawQuery = q.Encode() + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") transport := apiTransport(cdDev) client := http.Client{ Timeout: 10 * time.Second, @@ -181,7 +181,7 @@ func postLogAPI(cdDev bool, body io.Reader) error { } resp, err := client.Do(req) if err != nil { - return fmt.Errorf("postLogAPI client.Do: %w", err) + return fmt.Errorf("SendLogs client.Do: %w", err) } defer resp.Body.Close() d := json.NewDecoder(resp.Body) From db6e977e3a039fce87948f4456e19622edde353f Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 7 Jan 2025 21:15:12 +0700 Subject: [PATCH 028/100] Only used saved LAN servers if available --- resolver.go | 16 ++++++++++++-- resolver_test.go | 57 +++++++++++++++++++++++++++++++++--------------- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/resolver.go b/resolver.go index e3d319b..82a395e 100644 --- a/resolver.go +++ b/resolver.go @@ -78,7 +78,7 @@ func availableNameservers() []string { if _, ok := machineIPsMap[ns]; ok { continue } - if testNameserver(ns) { + if testNameServerFn(ns) { nss = append(nss, ns) } } @@ -122,7 +122,16 @@ func initializeOsResolver(servers []string) []string { or.initializedLanServers.CompareAndSwap(nil, &lanNss) } if len(lanNss) == 0 { - or.lanServers.Store(or.initializedLanServers.Load()) + var nss []string + p := or.initializedLanServers.Load() + if p != nil { + for _, ns := range *p { + if testNameServerFn(ns) { + nss = append(nss, ns) + } + } + } + or.lanServers.Store(&nss) } else { or.lanServers.Store(&lanNss) } @@ -133,6 +142,9 @@ func initializeOsResolver(servers []string) []string { return slices.Concat(lanNss, publicNss) } +// testNameserverFn sends a test query to DNS nameserver to check if the server is available. +var testNameServerFn = testNameserver + // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. func testNameserver(addr string) bool { msg := new(dns.Msg) diff --git a/resolver_test.go b/resolver_test.go index 5fb8434..7eab744 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -75,23 +75,10 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) { func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { ns := make([]string, 0, 2) servers := make([]*dns.Server, 0, 2) - successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(msg, dns.RcodeSuccess) - w.WriteMsg(m) - }) - nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc { - return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(msg, rcode) - w.WriteMsg(m) - }) - } - handlers := []dns.Handler{ nonSuccessHandlerWithRcode(dns.RcodeRefused), nonSuccessHandlerWithRcode(dns.RcodeNameError), - successHandler, + successHandler(), } for i := range handlers { pc, err := net.ListenPacket("udp", ":0") @@ -192,11 +179,15 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H } func Test_initializeOsResolver(t *testing.T) { + testNameServerFn = testNameserverTest lanServer1 := "192.168.1.1" + lanServer1WithPort := net.JoinHostPort("192.168.1.1", "53") lanServer2 := "10.0.10.69" + lanServer2WithPort := net.JoinHostPort("10.0.10.69", "53") lanServer3 := "192.168.40.1" + lanServer3WithPort := net.JoinHostPort("192.168.40.1", "53") wanServer := "1.1.1.1" - lanServers := []string{net.JoinHostPort(lanServer1, "53"), net.JoinHostPort(lanServer2, "53")} + lanServers := []string{lanServer1WithPort, lanServer2WithPort} publicServers := []string{net.JoinHostPort(wanServer, "53")} or = newResolverWithNameserver(defaultNameservers()) @@ -214,7 +205,7 @@ func Test_initializeOsResolver(t *testing.T) { p = or.initializedLanServers.Load() assert.NotNil(t, p) assert.True(t, slices.Equal(*p, lanServers)) - assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer1, "53")})) + assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer1WithPort})) assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) // New LAN servers, they are used, initialized servers not changed. @@ -222,7 +213,7 @@ func Test_initializeOsResolver(t *testing.T) { p = or.initializedLanServers.Load() assert.NotNil(t, p) assert.True(t, slices.Equal(*p, lanServers)) - assert.True(t, slices.Equal(*or.lanServers.Load(), []string{net.JoinHostPort(lanServer3, "53")})) + assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer3WithPort})) assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) // No LAN server available, initialized servers will be used. @@ -240,4 +231,36 @@ func Test_initializeOsResolver(t *testing.T) { assert.True(t, slices.Equal(*p, lanServers)) assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) assert.True(t, slices.Equal(*or.publicServers.Load(), []string{controldPublicDnsWithPort})) + + // No LAN server available, initialized servers is unavailable, nothing will be used. + nonSuccessTestServerMap[lanServer1WithPort] = true + nonSuccessTestServerMap[lanServer2WithPort] = true + initializeOsResolver([]string{wanServer}) + p = or.initializedLanServers.Load() + assert.NotNil(t, p) + assert.True(t, slices.Equal(*p, lanServers)) + assert.Empty(t, *or.lanServers.Load()) + assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) +} + +func successHandler() dns.HandlerFunc { + return func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, dns.RcodeSuccess) + w.WriteMsg(m) + } +} + +func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc { + return func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, rcode) + w.WriteMsg(m) + } +} + +var nonSuccessTestServerMap = map[string]bool{} + +func testNameserverTest(addr string) bool { + return !nonSuccessTestServerMap[addr] } From 3ea69b180cbdf23a53b232f3c9722177bd502154 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 8 Jan 2025 16:57:32 +0700 Subject: [PATCH 029/100] cmd/cli: use config timeout when checking upstream Otherwise, for slow network connection (like plane wifi), the check may fail even though the internet is available. --- cmd/cli/upstream_monitor.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index b17cb32..4d79c9f 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -93,9 +93,12 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { } msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) - + timeout := 1000 * time.Millisecond + if uc.Timeout > 0 { + timeout = time.Duration(uc.Timeout) * time.Millisecond + } check := func() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() uc.ReBootstrap() _, err := resolver.Resolve(ctx, msg) @@ -112,6 +115,8 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { mainLog.Load().Warn().Msg("stop leaking query") } return + } else { + mainLog.Load().Debug().Msgf("upstream %q is offline: %v", uc.Endpoint, err) } time.Sleep(checkUpstreamBackoffSleep) } From 6046789fa4d139bbb09e90715a90605afc6ea906 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 10 Jan 2025 01:50:03 +0700 Subject: [PATCH 030/100] cmd/cli: re-initializing OS resolver before doing check upstream Otherwise, the check will be done for old stale nameservers, causing it never succeed. --- cmd/cli/upstream_monitor.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 4d79c9f..3400b60 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -86,6 +86,10 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { p.um.mu.Unlock() }() + if uc.Type == ctrld.ResolverTypeOS { + ns := ctrld.InitializeOsResolver() + mainLog.Load().Debug().Msgf("re-initializing OS resolver with nameservers: %v", ns) + } resolver, err := ctrld.NewResolver(uc) if err != nil { mainLog.Load().Warn().Err(err).Msg("could not check upstream") From 3713cbecc3be0e1150c678fec70922dadc01b077 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 10 Jan 2025 09:43:16 +0700 Subject: [PATCH 031/100] cmd/cli: correct log writer initial size --- cmd/cli/log_writer.go | 2 +- cmd/cli/log_writer_test.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index a6fb7eb..03e7139 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -16,7 +16,7 @@ import ( const ( logWriterSize = 1024 * 1024 * 5 // 5 MB - logWriterInitialSize = 32 // 32 B + logWriterInitialSize = 32 * 1024 // 32 KB logSentInterval = time.Minute logTruncatedMarker = "...\n" ) diff --git a/cmd/cli/log_writer_test.go b/cmd/cli/log_writer_test.go index 6882ea0..92c772b 100644 --- a/cmd/cli/log_writer_test.go +++ b/cmd/cli/log_writer_test.go @@ -7,7 +7,7 @@ import ( ) func Test_logWriter_Write(t *testing.T) { - size := 64 + size := 64 * 1024 lw := &logWriter{size: size} lw.buf.Grow(lw.size) data := strings.Repeat("A", size) @@ -22,8 +22,8 @@ func Test_logWriter_Write(t *testing.T) { t.Fatalf("unexpected new buf content: %v", lw.buf.String()) } - bigData := strings.Repeat("B", 256) - expected := halfData + strings.Repeat("B", 16) + bigData := strings.Repeat("B", 256*1024) + expected := halfData + strings.Repeat("B", 16*1024) lw.Write([]byte(bigData)) if lw.buf.String() != expected { t.Fatalf("unexpected big buf content: %v", lw.buf.String()) @@ -31,7 +31,7 @@ func Test_logWriter_Write(t *testing.T) { } func Test_logWriter_ConcurrentWrite(t *testing.T) { - size := 64 + size := 64 * 1024 lw := &logWriter{size: size} n := 10 var wg sync.WaitGroup From 087c1975e56055be4a716e15603a45e93800be2a Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 10 Jan 2025 09:44:06 +0700 Subject: [PATCH 032/100] internal/controld: bump send log timeout to 300s --- internal/controld/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/controld/config.go b/internal/controld/config.go index b1814d1..fbbd9d4 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -176,7 +176,7 @@ func SendLogs(lr *LogsRequest, cdDev bool) error { req.Header.Add("Content-Type", "application/x-www-form-urlencoded") transport := apiTransport(cdDev) client := http.Client{ - Timeout: 10 * time.Second, + Timeout: 300 * time.Second, Transport: transport, } resp, err := client.Do(req) From 6fd3d1788a338a90969fc0e446898721d6866df9 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 10 Jan 2025 18:40:21 +0700 Subject: [PATCH 033/100] cmd/cli: fix memory leaked when querying wmi instance By ensuring the instance is closed when query finished. --- cmd/cli/net_windows.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index 2077b85..f46a93f 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -44,6 +44,7 @@ func validInterfaces() []string { mainLog.Load().Err(err).Msg("failed to get wmi network adapter") return nil } + defer instances.Close() var adapters []string for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) From c53a0ca1c44cfe96ebae44e1b2bbc0b5a4b05f3b Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 10 Jan 2025 18:49:46 +0700 Subject: [PATCH 034/100] cmd/cli: close log reader after reading --- cmd/cli/control_server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 302b902..52406b1 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -221,6 +221,7 @@ func (p *prog) registerControlServerHandler() { http.Error(w, err.Error(), http.StatusBadRequest) return } + defer lr.r.Close() if lr.size == 0 { w.WriteHeader(http.StatusMovedPermanently) return From 5db7d3577b046732cd835f7f5bc9b02f4ed7219c Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 10 Jan 2025 18:50:39 +0700 Subject: [PATCH 035/100] cmd/cli: handle . domain query By returning FormErr response, the same behavior with ControlD. --- cmd/cli/dns_proxy.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 8b198a7..623c7d2 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -106,11 +106,18 @@ func (p *prog) serveDNS(listenerNum string) error { go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) - if domain == selfCheckInternalTestDomain { + switch { + case domain == "": + answer := new(dns.Msg) + answer.SetRcode(m, dns.RcodeFormatError) + _ = w.WriteMsg(answer) + return + case domain == selfCheckInternalTestDomain: answer := resolveInternalDomainTestQuery(ctx, domain, m) _ = w.WriteMsg(answer) return } + if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil { p.cache.Purge() ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain) From a95d50c0afc72a7737153aaee4e4b340c9343884 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 13 Jan 2025 20:03:56 +0700 Subject: [PATCH 036/100] cmd/cli: ensure set/reset DNS is done before checking OS resolver Otherwise, new DNS settings could be reverted by dns watchers, causing the checking will be always false. --- cmd/cli/dns_proxy.go | 2 +- cmd/cli/prog.go | 2 +- cmd/cli/resolvconf.go | 2 +- cmd/cli/upstream_monitor.go | 23 ++++++++++++++++++----- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 623c7d2..631b0e3 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -552,7 +552,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { continue } if p.um.isDown(upstreams[n]) { - ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n]) + ctrld.Log(ctx, mainLog.Load().Debug(), "%s is down", upstreams[n]) continue } answer := resolve(n, upstreamConfig, req.msg) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 9ff04bc..9de6b48 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -728,7 +728,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces mainLog.Load().Debug().Msg("stop dns watchdog") return case <-ticker.C: - if p.leakingQuery.Load() { + if p.leakingQuery.Load() || p.um.isChecking(upstreamOS) { return } if dnsChanged(iface, ns) { diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 6df7be6..6bd8c2a 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -40,7 +40,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: - if p.leakingQuery.Load() { + if p.leakingQuery.Load() || p.um.isChecking(upstreamOS) { return } if !ok { diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 3400b60..86c191d 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -60,6 +60,14 @@ func (um *upstreamMonitor) isDown(upstream string) bool { return um.down[upstream] } +// isChecking reports whether the given upstream is being checked. +func (um *upstreamMonitor) isChecking(upstream string) bool { + um.mu.Lock() + defer um.mu.Unlock() + + return um.checking[upstream] +} + // reset marks an upstream as up and set failed queries counter to zero. func (um *upstreamMonitor) reset(upstream string) { um.mu.Lock() @@ -86,9 +94,10 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { p.um.mu.Unlock() }() - if uc.Type == ctrld.ResolverTypeOS { - ns := ctrld.InitializeOsResolver() - mainLog.Load().Debug().Msgf("re-initializing OS resolver with nameservers: %v", ns) + isOsResolver := uc.Type == ctrld.ResolverTypeOS + if isOsResolver { + p.resetDNS() + defer p.setDNS() } resolver, err := ctrld.NewResolver(uc) if err != nil { @@ -105,12 +114,16 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() uc.ReBootstrap() + if isOsResolver { + ctrld.InitializeOsResolver() + } _, err := resolver.Resolve(ctx, msg) return err } + mainLog.Load().Warn().Msgf("upstream %q is offline", uc.Endpoint) for { if err := check(); err == nil { - mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint) + mainLog.Load().Warn().Msgf("upstream %q is online", uc.Endpoint) p.um.reset(upstream) if p.leakingQuery.CompareAndSwap(true, false) { p.leakingQueryMu.Lock() @@ -120,7 +133,7 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { } return } else { - mainLog.Load().Debug().Msgf("upstream %q is offline: %v", uc.Endpoint, err) + mainLog.Load().Debug().Msgf("checked upstream %q failed: %v", uc.Endpoint, err) } time.Sleep(checkUpstreamBackoffSleep) } From 8bcbb9249e6c3c193b3d1aaa5ab77a9f37e8f275 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 13 Jan 2025 20:26:33 +0700 Subject: [PATCH 037/100] cmd/cli: add an internal warn level log writer So important events like upstream online/offline/failed will be preserved, and submitted to the server as necessary. --- cmd/cli/log_writer.go | 42 ++++++++++++++++++++++++++++++++++++------ cmd/cli/prog.go | 1 + 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 03e7139..c146f4e 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -16,9 +16,11 @@ import ( const ( logWriterSize = 1024 * 1024 * 5 // 5 MB + logWriterSmallSize = 1024 * 1024 * 1 // 1 MB logWriterInitialSize = 32 * 1024 // 32 KB logSentInterval = time.Minute logTruncatedMarker = "...\n" + logSeparator = "\n===\n\n" ) type logViewResponse struct { @@ -42,9 +44,19 @@ type logWriter struct { size int } -// newLogWriter creates an internal log writer with a fixed buffer size. +// newLogWriter creates an internal log writer. func newLogWriter() *logWriter { - lw := &logWriter{size: logWriterSize} + return newLogWriterWithSize(logWriterSize) +} + +// newSmallLogWriter creates an internal log writer with small buffer size. +func newSmallLogWriter() *logWriter { + return newLogWriterWithSize(logWriterSmallSize) +} + +// newLogWriterWithSize creates an internal log writer with a given buffer size. +func newLogWriterWithSize(size int) *logWriter { + lw := &logWriter{size: size} return lw } @@ -77,12 +89,13 @@ func (p *prog) initInternalLogging(writers []io.Writer) { } p.initInternalLogWriterOnce.Do(func() { mainLog.Load().Notice().Msg("internal logging enabled") - lw := newLogWriter() - p.internalLogWriter = lw + p.internalLogWriter = newLogWriter() p.internalLogSent = time.Now().Add(-logSentInterval) + p.internalWarnLogWriter = newSmallLogWriter() }) p.mu.Lock() lw := p.internalLogWriter + wlw := p.internalWarnLogWriter p.mu.Unlock() // If ctrld was run without explicit verbose level, // run the internal logging at debug level, so we could @@ -98,6 +111,10 @@ func (p *prog) initInternalLogging(writers []io.Writer) { zerolog.SetGlobalLevel(zerolog.DebugLevel) } writers = append(writers, lw) + writers = append(writers, &zerolog.FilteredLevelWriter{ + Writer: zerolog.LevelWriterAdapter{Writer: wlw}, + Level: zerolog.WarnLevel, + }) multi := zerolog.MultiLevelWriter(writers...) l := mainLog.Load().Output(multi).With().Logger() mainLog.Store(&l) @@ -121,14 +138,27 @@ func (p *prog) logReader() (*logReader, error) { if p.needInternalLogging() { p.mu.Lock() lw := p.internalLogWriter + wlw := p.internalWarnLogWriter p.mu.Unlock() if lw == nil { return nil, errors.New("nil internal log writer") } + if wlw == nil { + return nil, errors.New("nil internal warn log writer") + } + // Normal log content. lw.mu.Lock() - lr := &logReader{r: io.NopCloser(bytes.NewReader(lw.buf.Bytes()))} - lr.size = int64(lw.buf.Len()) + lwReader := bytes.NewReader(lw.buf.Bytes()) + lwSize := lw.buf.Len() lw.mu.Unlock() + // Warn log content. + wlw.mu.Lock() + wlwReader := bytes.NewReader(wlw.buf.Bytes()) + wlwSize := wlw.buf.Len() + wlw.mu.Unlock() + reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logSeparator)), wlwReader) + lr := &logReader{r: io.NopCloser(reader)} + lr.size = int64(lwSize + wlwSize) if lr.size == 0 { return nil, errors.New("internal log is empty") } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 9de6b48..2ceac7c 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -102,6 +102,7 @@ type prog struct { queryFromSelfMap sync.Map initInternalLogWriterOnce sync.Once internalLogWriter *logWriter + internalWarnLogWriter *logWriter internalLogSent time.Time runningIface string requiredMultiNICsConfig bool From 9c2fe8d21f9f376d1a731f3829f65ba2136d933e Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 14 Jan 2025 14:24:27 +0700 Subject: [PATCH 038/100] cmd/cli: set running iface for stop/uninstall commands --- cmd/cli/cli.go | 1 + cmd/cli/commands.go | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index f0e927d..6b7ac8f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1767,6 +1767,7 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunn if currentRunningIface != "" { iface = currentRunningIface } + p.runningIface = iface if isCtrldInstalled { mainLog.Load().Debug().Msg("restore system DNS settings") if status, _ := s.Status(); status == service.StatusRunning { diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index b174052..ebf3dec 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -522,6 +522,11 @@ func initStopCmd() *cobra.Command { mainLog.Load().Error().Msg(err.Error()) return } + p.runningIface = iface + if ri := runningIface(s); ri != "" { + p.runningIface = ri + } + initLogging() if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) @@ -772,6 +777,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if iface == "" { iface = "auto" } + p.runningIface = iface + if ri := runningIface(s); ri != "" { + p.runningIface = ri + } if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } From f986a575e812dc18ecea3f8f5c9fb112154bcf64 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 14 Jan 2025 22:16:03 +0700 Subject: [PATCH 039/100] cmd/cli: log upstream name if endpoint is empty --- cmd/cli/upstream_monitor.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 86c191d..512a8b6 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -120,10 +120,14 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { _, err := resolver.Resolve(ctx, msg) return err } - mainLog.Load().Warn().Msgf("upstream %q is offline", uc.Endpoint) + endpoint := uc.Endpoint + if endpoint == "" { + endpoint = uc.Name + } + mainLog.Load().Warn().Msgf("upstream %q is offline", endpoint) for { if err := check(); err == nil { - mainLog.Load().Warn().Msgf("upstream %q is online", uc.Endpoint) + mainLog.Load().Warn().Msgf("upstream %q is online", endpoint) p.um.reset(upstream) if p.leakingQuery.CompareAndSwap(true, false) { p.leakingQueryMu.Lock() @@ -133,7 +137,7 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { } return } else { - mainLog.Load().Debug().Msgf("checked upstream %q failed: %v", uc.Endpoint, err) + mainLog.Load().Debug().Msgf("checked upstream %q failed: %v", endpoint, err) } time.Sleep(checkUpstreamBackoffSleep) } From 89600f6091adc7b57b5df631959ef8a6108f60be Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 15 Jan 2025 19:51:55 +0700 Subject: [PATCH 040/100] cmd/cli: new flow for leaking queries to OS resolver The current flow involves marking OS resolver as down, which is not right at all, since ctrld depends on it for leaking queries. This commits implements new flow, which ctrld will restore DNS settings once leaking marked, allowing queries go to OS resolver until the internet connection is established. --- cmd/cli/dns_proxy.go | 19 ++++++++++++++----- cmd/cli/prog.go | 2 +- cmd/cli/resolvconf.go | 2 +- cmd/cli/upstream_monitor.go | 26 ++++---------------------- 4 files changed, 20 insertions(+), 29 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 631b0e3..4f4b980 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -419,12 +419,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) leaked := false - // If ctrld is going to leak query to OS resolver, check remote upstream in background, - // so ctrld could be back to normal operation as long as the network is back online. if len(upstreamConfigs) > 0 && p.leakingQuery.Load() { - for n, uc := range upstreamConfigs { - go p.checkUpstream(upstreams[n], uc) - } upstreamConfigs = nil leaked = true ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams) @@ -936,11 +931,25 @@ func (p *prog) performLeakingQuery() { mainLog.Load().Warn().Msg("leaking query to OS resolver") // Signal dns watchers to stop, so changes made below won't be reverted. p.leakingQuery.Store(true) + defer func() { + p.leakingQuery.Store(false) + p.leakingQueryMu.Lock() + p.leakingQueryWasRun = false + p.leakingQueryMu.Unlock() + }() + // Reset DNS, so queries are forwarded to OS resolver normally. p.resetDNS() + // Check remote upstream in background, so ctrld could be back to normal + // operation as long as the network is back online. + for name, uc := range p.cfg.Upstream { + p.checkUpstream(name, uc) + } + // After all upstream back, re-initializing OS resolver. ns := ctrld.InitializeOsResolver() mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns) p.dnsWg.Wait() p.setDNS() + mainLog.Load().Warn().Msg("stop leaking query") } // forceFetchingAPI sends signal to force syncing API config if run in cd mode, diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 2ceac7c..29c1120 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -729,7 +729,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces mainLog.Load().Debug().Msg("stop dns watchdog") return case <-ticker.C: - if p.leakingQuery.Load() || p.um.isChecking(upstreamOS) { + if p.leakingQuery.Load() { return } if dnsChanged(iface, ns) { diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 6bd8c2a..6df7be6 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -40,7 +40,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: - if p.leakingQuery.Load() || p.um.isChecking(upstreamOS) { + if p.leakingQuery.Load() { return } if !ok { diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 512a8b6..1f3484b 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -44,6 +44,10 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { // increaseFailureCount increase failed queries count for an upstream by 1. func (um *upstreamMonitor) increaseFailureCount(upstream string) { + // Do not count "upstream.os", since it must not be down for leaking queries. + if upstream == upstreamOS { + return + } um.mu.Lock() defer um.mu.Unlock() @@ -60,14 +64,6 @@ func (um *upstreamMonitor) isDown(upstream string) bool { return um.down[upstream] } -// isChecking reports whether the given upstream is being checked. -func (um *upstreamMonitor) isChecking(upstream string) bool { - um.mu.Lock() - defer um.mu.Unlock() - - return um.checking[upstream] -} - // reset marks an upstream as up and set failed queries counter to zero. func (um *upstreamMonitor) reset(upstream string) { um.mu.Lock() @@ -94,11 +90,6 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { p.um.mu.Unlock() }() - isOsResolver := uc.Type == ctrld.ResolverTypeOS - if isOsResolver { - p.resetDNS() - defer p.setDNS() - } resolver, err := ctrld.NewResolver(uc) if err != nil { mainLog.Load().Warn().Err(err).Msg("could not check upstream") @@ -114,9 +105,6 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() uc.ReBootstrap() - if isOsResolver { - ctrld.InitializeOsResolver() - } _, err := resolver.Resolve(ctx, msg) return err } @@ -129,12 +117,6 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { if err := check(); err == nil { mainLog.Load().Warn().Msgf("upstream %q is online", endpoint) p.um.reset(upstream) - if p.leakingQuery.CompareAndSwap(true, false) { - p.leakingQueryMu.Lock() - p.leakingQueryWasRun = false - p.leakingQueryMu.Unlock() - mainLog.Load().Warn().Msg("stop leaking query") - } return } else { mainLog.Load().Debug().Msgf("checked upstream %q failed: %v", endpoint, err) From 4df470b869ac71c51b3cc53b7d71215e1604a0f0 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 15 Jan 2025 15:39:35 +0700 Subject: [PATCH 041/100] cmd/cli: ensure all ifaces operation is set correctly Since ctrld process does not rely on the global variable iface anymore during runtime, ctrld client's operations must be updated to reflect this change, too. --- cmd/cli/cli.go | 20 ++++++++++++-------- cmd/cli/commands.go | 20 ++++++++++++-------- cmd/cli/control_server.go | 16 +++++++++++++--- 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 6b7ac8f..4934c5a 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1723,20 +1723,22 @@ func upgradeUrl(baseUrl string) string { } // runningIface returns the value of the iface variable used by ctrld process which is running. -func runningIface(s service.Service) string { +func runningIface(s service.Service) *ifaceResponse { if sockDir, err := socketDir(); err == nil { if cc := newSocketControlClient(context.TODO(), s, sockDir); cc != nil { resp, err := cc.post(ifacePath, nil) if err != nil { - return "" + return nil } defer resp.Body.Close() - if buf, _ := io.ReadAll(resp.Body); len(buf) > 0 { - return string(buf) + res := &ifaceResponse{} + if err := json.NewDecoder(resp.Body).Decode(res); err != nil { + return nil } + return res } } - return "" + return nil } // resetDnsNoLog performs resetting DNS with logging disable. @@ -1754,7 +1756,7 @@ func resetDnsNoLog(p *prog) { } // resetDnsTask returns a task which perform reset DNS operation. -func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunningIface string) task { +func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task { return task{func() error { if iface == "" { return nil @@ -1764,8 +1766,10 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, currentRunn // process to reset what setDNS has done properly. oldIface := iface iface = "auto" - if currentRunningIface != "" { - iface = currentRunningIface + p.requiredMultiNICsConfig = requiredMultiNICsConfig() + if ir != nil { + iface = ir.Name + p.requiredMultiNICsConfig = ir.All } p.runningIface = iface if isCtrldInstalled { diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index ebf3dec..0982647 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -197,7 +197,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled) // Get current running iface, if any. - var currentIface string + var currentIface *ifaceResponse // If pin code was set, do not allow running start command. if isCtrldRunning { @@ -522,9 +522,10 @@ func initStopCmd() *cobra.Command { mainLog.Load().Error().Msg(err.Error()) return } - p.runningIface = iface - if ri := runningIface(s); ri != "" { - p.runningIface = ri + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All } initLogging() @@ -610,7 +611,9 @@ func initRestartCmd() *cobra.Command { doValidateCdRemoteConfig(cdUID) } - iface = runningIface(s) + if ir := runningIface(s); ir != nil { + iface = ir.Name + } tasks := []task{ {s.Stop, false}, {s.Start, true}, @@ -777,9 +780,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if iface == "" { iface = "auto" } - p.runningIface = iface - if ri := runningIface(s); ri != "" { - p.runningIface = ri + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All } if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 52406b1..d1daea3 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -31,6 +31,11 @@ const ( sendLogsPath = "/logs/send" ) +type ifaceResponse struct { + Name string `json:"name"` + All bool `json:"all"` +} + type controlServer struct { server *http.Server mux *http.ServeMux @@ -205,15 +210,20 @@ func (p *prog) registerControlServerHandler() { w.WriteHeader(http.StatusBadRequest) })) p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + res := &ifaceResponse{Name: iface} // p.setDNS is only called when running as a service if !service.Interactive() { <-p.csSetDnsDone if p.csSetDnsOk { - w.Write([]byte(iface)) - return + res.Name = p.runningIface + res.All = p.requiredMultiNICsConfig } } - w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(res); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError) + return + } })) p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { lr, err := p.logReader() From e9e63b09836dbffd233394e07f311883418fd3d7 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 15 Jan 2025 23:14:18 +0700 Subject: [PATCH 042/100] cmd/cli: check root privilege for log commands --- cmd/cli/commands.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 0982647..e5f655f 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -37,6 +37,9 @@ func initLogCmd() *cobra.Command { Use: "send", Short: "Send runtime debug logs to ControlD", Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, Run: func(cmd *cobra.Command, args []string) { dir, err := socketDir() if err != nil { @@ -74,6 +77,9 @@ func initLogCmd() *cobra.Command { Use: "view", Short: "View current runtime debug logs", Args: cobra.NoArgs, + PreRun: func(cmd *cobra.Command, args []string) { + checkHasElevatedPrivilege() + }, Run: func(cmd *cobra.Command, args []string) { dir, err := socketDir() if err != nil { From 7833132917e1f877f4c31e5a3e9ab63dcd48feaa Mon Sep 17 00:00:00 2001 From: Alex Paguis Date: Wed, 15 Jan 2025 17:31:10 -0500 Subject: [PATCH 043/100] Don't automatically restore saved DNS settings when switching networks smol tweaks to nameserver test queries fix restoreDNS errors add some debugging information fix wront type in log msg set send logs command timeout to 5 mins when the runningIface is no longer up, attempt to find a new interface prefer default route, ignore non physical interfaces prefer default route, ignore non physical interfaces add max context timeout on performLeakingQuery with more debug logs --- cmd/cli/cli.go | 10 +++ cmd/cli/commands.go | 10 +++ cmd/cli/control_client.go | 4 + cmd/cli/control_server.go | 4 +- cmd/cli/dns_proxy.go | 88 +++++++++++++++++--- cmd/cli/os_darwin.go | 14 ++-- cmd/cli/os_freebsd.go | 6 ++ cmd/cli/os_linux.go | 6 ++ cmd/cli/os_windows.go | 12 ++- cmd/cli/prog.go | 165 +++++++++++++++++++++++++++++++++++++- config.go | 9 +++ resolver.go | 95 +++++++++++++++++++--- 12 files changed, 387 insertions(+), 36 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 4934c5a..39a5977 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1029,6 +1029,16 @@ func uninstall(p *prog, s service.Service) { return } p.resetDNS() + + // if present restore the original DNS settings + if netIface, err := netInterface(p.runningIface); err == nil { + if err := restoreDNS(netIface); err != nil { + mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") + } else { + mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + } + } + if router.Name() != "" { mainLog.Load().Debug().Msg("Router cleanup") } diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index e5f655f..bae0cf1 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -541,6 +541,16 @@ func initStopCmd() *cobra.Command { if doTasks([]task{{s.Stop, true}}) { p.router.Cleanup() p.resetDNS() + + // restore DNS settings + if netIface, err := netInterface(p.runningIface); err == nil { + if err := restoreDNS(netIface); err != nil { + mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface") + } else { + mainLog.Load().Debug().Msg("Restored DNS on interface successfully") + } + } + if router.WaitProcessExited() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() diff --git a/cmd/cli/control_client.go b/cmd/cli/control_client.go index 73002e8..7382d4e 100644 --- a/cmd/cli/control_client.go +++ b/cmd/cli/control_client.go @@ -25,6 +25,10 @@ func newControlClient(addr string) *controlClient { } func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) { + // for log/send, set the timeout to 5 minutes + if path == sendLogsPath { + c.c.Timeout = time.Minute * 5 + } return c.c.Post("http://unix"+path, contentTypeJson, data) } diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index d1daea3..36285e5 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -27,8 +27,8 @@ const ( deactivationPath = "/deactivation" cdPath = "/cd" ifacePath = "/iface" - viewLogsPath = "/logs/view" - sendLogsPath = "/logs/send" + viewLogsPath = "/log/view" + sendLogsPath = "/log/send" ) type ifaceResponse struct { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 4f4b980..b2c0f23 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -542,8 +542,10 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if upstreamConfig == nil { continue } + ctrld.Log(ctx, mainLog.Load().Debug(), "attempting upstream [ %s ] at index: %d, upstream at index: %s", upstreamConfig.String(), n, upstreams[n]) + if p.isLoop(upstreamConfig) { - mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint) + mainLog.Load().Warn().Msgf("dns loop detected, upstream: %s", upstreamConfig.String()) continue } if p.um.isDown(upstreams[n]) { @@ -929,6 +931,11 @@ func (p *prog) selfUninstallCoolOfPeriod() { // performLeakingQuery performs necessary works to leak queries to OS resolver. func (p *prog) performLeakingQuery() { mainLog.Load().Warn().Msg("leaking query to OS resolver") + + // Create a context with timeout for the entire operation + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Signal dns watchers to stop, so changes made below won't be reverted. p.leakingQuery.Store(true) defer func() { @@ -936,20 +943,81 @@ func (p *prog) performLeakingQuery() { p.leakingQueryMu.Lock() p.leakingQueryWasRun = false p.leakingQueryMu.Unlock() + mainLog.Load().Warn().Msg("stop leaking query") }() - // Reset DNS, so queries are forwarded to OS resolver normally. - p.resetDNS() - // Check remote upstream in background, so ctrld could be back to normal - // operation as long as the network is back online. - for name, uc := range p.cfg.Upstream { - p.checkUpstream(name, uc) + + // Create channels to coordinate operations + resetDone := make(chan struct{}) + checkDone := make(chan struct{}) + + // Reset DNS with timeout + go func() { + defer close(resetDone) + mainLog.Load().Debug().Msg("attempting to reset DNS") + p.resetDNS() + mainLog.Load().Debug().Msg("DNS reset completed") + }() + + // Wait for reset with timeout + select { + case <-resetDone: + mainLog.Load().Debug().Msg("DNS reset successful") + case <-ctx.Done(): + mainLog.Load().Error().Msg("DNS reset timed out") + return } - // After all upstream back, re-initializing OS resolver. + + // Check upstream in background with progress tracking + go func() { + defer close(checkDone) + mainLog.Load().Debug().Msg("starting upstream checks") + for name, uc := range p.cfg.Upstream { + select { + case <-ctx.Done(): + return + default: + mainLog.Load().Debug(). + Str("upstream", name). + Msg("checking upstream") + p.checkUpstream(name, uc) + } + } + mainLog.Load().Debug().Msg("upstream checks completed") + }() + + // Wait for upstream checks + select { + case <-checkDone: + mainLog.Load().Debug().Msg("upstream checks successful") + case <-ctx.Done(): + mainLog.Load().Error().Msg("upstream checks timed out") + return + } + + // Initialize OS resolver with timeout + mainLog.Load().Debug().Msg("initializing OS resolver") ns := ctrld.InitializeOsResolver() mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns) - p.dnsWg.Wait() + + // Wait for DNS operations to complete + waitCh := make(chan struct{}) + go func() { + p.dnsWg.Wait() + close(waitCh) + }() + + select { + case <-waitCh: + mainLog.Load().Debug().Msg("DNS operations completed") + case <-ctx.Done(): + mainLog.Load().Error().Msg("DNS operations timed out") + return + } + + // Set DNS with timeout + mainLog.Load().Debug().Msg("setting DNS configuration") p.setDNS() - mainLog.Load().Warn().Msg("stop leaking query") + mainLog.Load().Debug().Msg("DNS configuration set successfully") } // forceFetchingAPI sends signal to force syncing API config if run in cd mode, diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index f319056..841be76 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -70,11 +70,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error { // TODO(cuonglm): use system API func resetDNS(iface *net.Interface) error { - if ns := savedStaticNameservers(iface); len(ns) > 0 { - if err := setDNS(iface, ns); err == nil { - return nil - } - } cmd := "networksetup" args := []string{"-setdnsservers", iface.Name, "empty"} if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil { @@ -83,6 +78,15 @@ func resetDNS(iface *net.Interface) error { return nil } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + if ns := savedStaticNameservers(iface); len(ns) > 0 { + err = setDNS(iface, ns) + } + return err +} + func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index bddffca..72da485 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -76,6 +76,12 @@ func resetDNS(iface *net.Interface) error { return nil } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + return err +} + func currentDNS(_ *net.Interface) []string { return resolvconffile.NameServers("") } diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index ade5881..3f815e8 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -195,6 +195,12 @@ func resetDNS(iface *net.Interface) (err error) { }) } +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { + return err +} + func currentDNS(iface *net.Interface) []string { for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} { if ns := fn(iface.Name); len(ns) > 0 { diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index 5ff9360..990cc57 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -130,8 +130,12 @@ func resetDNS(iface *net.Interface) error { if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil { return fmt.Errorf("could not reset DNS ipv6: %w", err) } + return nil +} - // If there's static DNS saved, restoring it. +// restoreDNS restores the DNS settings of the given interface. +// this should only be executed upon turning off the ctrld service. +func restoreDNS(iface *net.Interface) (err error) { if nss := savedStaticNameservers(iface); len(nss) > 0 { v4ns := make([]string, 0, 2) v6ns := make([]string, 0, 2) @@ -148,12 +152,14 @@ func resetDNS(iface *net.Interface) error { continue } mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name) - if err := setDNS(iface, ns); err != nil { + err = setDNS(iface, ns) + + if err != nil { return err } } } - return nil + return err } func currentDNS(iface *net.Interface) []string { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 29c1120..b1fb18b 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -626,9 +626,31 @@ func (p *prog) setDNS() { return } logger := mainLog.Load().With().Str("iface", p.runningIface).Logger() - netIface, err := netInterface(p.runningIface) - if err != nil { - logger.Error().Err(err).Msg("could not get interface") + + const maxDNSRetryAttempts = 3 + const retryDelay = 1 * time.Second + var netIface *net.Interface + var err error + for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ { + netIface, err = netInterface(p.runningIface) + if err == nil { + break + } + if attempt < maxDNSRetryAttempts { + // Try to find a different working interface + newIface := findWorkingInterface(p.runningIface) + if newIface != p.runningIface { + p.runningIface = newIface + logger = mainLog.Load().With().Str("iface", p.runningIface).Logger() + logger.Info().Msg("switched to new interface") + continue + } + + logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...") + time.Sleep(retryDelay) + continue + } + logger.Error().Err(err).Msg("could not get interface after all attempts") return } if err := setupNetworkManager(); err != nil { @@ -766,6 +788,7 @@ func (p *prog) resetDNS() { logger.Error().Err(err).Msg("could not get interface") return } + if err := restoreNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not restore NetworkManager") return @@ -781,6 +804,131 @@ func (p *prog) resetDNS() { } } +// findWorkingInterface looks for a network interface with a valid IP configuration +func findWorkingInterface(currentIface string) string { + // Helper to check if IP is valid (not link-local) + isValidIP := func(ip net.IP) bool { + return ip != nil && + !ip.IsLinkLocalUnicast() && + !ip.IsLinkLocalMulticast() && + !ip.IsLoopback() && + !ip.IsUnspecified() + } + + // Helper to check if interface has valid IP configuration + hasValidIPConfig := func(iface *net.Interface) bool { + if iface == nil || iface.Flags&net.FlagUp == 0 { + return false + } + + addrs, err := iface.Addrs() + if err != nil { + mainLog.Load().Debug(). + Str("interface", iface.Name). + Err(err). + Msg("failed to get interface addresses") + return false + } + + for _, addr := range addrs { + // Check for IP network + if ipNet, ok := addr.(*net.IPNet); ok { + if isValidIP(ipNet.IP) { + return true + } + } + } + return false + } + + // Get default route interface + defaultRoute, err := netmon.DefaultRoute() + if err != nil { + mainLog.Load().Debug(). + Err(err). + Msg("failed to get default route") + } else { + mainLog.Load().Debug(). + Str("default_route_iface", defaultRoute.InterfaceName). + Msg("found default route") + } + + // Get all interfaces + ifaces, err := net.Interfaces() + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to list network interfaces") + return currentIface // Return current interface as fallback + } + + var firstWorkingIface string + var currentIfaceValid bool + + // Single pass through interfaces + for _, iface := range ifaces { + // Must be physical (has MAC address) + if len(iface.HardwareAddr) == 0 { + continue + } + // Skip interfaces that are: + // - Loopback + // - Not up + // - Point-to-point (like VPN tunnels) + if iface.Flags&net.FlagLoopback != 0 || + iface.Flags&net.FlagUp == 0 || + iface.Flags&net.FlagPointToPoint != 0 { + continue + } + + if !hasValidIPConfig(&iface) { + continue + } + + // Found working physical interface + if err == nil && defaultRoute.InterfaceName == iface.Name { + // Found interface with default route - use it immediately + mainLog.Load().Info(). + Str("old_iface", currentIface). + Str("new_iface", iface.Name). + Msg("switching to interface with default route") + return iface.Name + } + + // Keep track of first working interface as fallback + if firstWorkingIface == "" { + firstWorkingIface = iface.Name + } + + // Check if this is our current interface + if iface.Name == currentIface { + currentIfaceValid = true + } + } + + // Return interfaces in order of preference: + // 1. Current interface if it's still valid + if currentIfaceValid { + mainLog.Load().Debug(). + Str("interface", currentIface). + Msg("keeping current interface") + return currentIface + } + + // 2. First working interface found + if firstWorkingIface != "" { + mainLog.Load().Info(). + Str("old_iface", currentIface). + Str("new_iface", firstWorkingIface). + Msg("switching to first working physical interface") + return firstWorkingIface + } + + // 3. Fall back to current interface if nothing else works + mainLog.Load().Warn(). + Str("current_iface", currentIface). + Msg("no working physical interface found, keeping current") + return currentIface +} + // leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams. func (p *prog) leakOnUpstreamFailure() bool { if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil { @@ -1049,7 +1197,16 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string { func savedStaticNameservers(iface *net.Interface) []string { file := savedStaticDnsSettingsFilePath(iface) if data, _ := os.ReadFile(file); len(data) > 0 { - return strings.Split(string(data), ",") + saveValues := strings.Split(string(data), ",") + returnValues := []string{} + // check each one, if its in loopback range, remove it + for _, v := range saveValues { + if net.ParseIP(v).IsLoopback() { + continue + } + returnValues = append(returnValues, v) + } + return returnValues } return nil } diff --git a/config.go b/config.go index 4302c5d..c88404c 100644 --- a/config.go +++ b/config.go @@ -886,3 +886,12 @@ func upstreamUID() string { return hex.EncodeToString(b) } } + +// String returns a string representation of the UpstreamConfig for logging. +func (uc *UpstreamConfig) String() string { + if uc == nil { + return "" + } + return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}", + uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack) +} diff --git a/resolver.go b/resolver.go index 82a395e..3189dfb 100644 --- a/resolver.go +++ b/resolver.go @@ -147,16 +147,82 @@ var testNameServerFn = testNameserver // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. func testNameserver(addr string) bool { - msg := new(dns.Msg) - msg.SetQuestion("controld.com.", dns.TypeNS) - client := new(dns.Client) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53")) - if err != nil { - ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr) + // Skip link-local addresses without scope IDs and deprecated site-local addresses + if ip, err := netip.ParseAddr(addr); err == nil { + if ip.Is6() { + if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") { + ProxyLogger.Load().Debug(). + Str("nameserver", addr). + Msg("skipping link-local IPv6 address without scope ID") + return false + } + // Skip deprecated site-local addresses (fec0::/10) + if strings.HasPrefix(ip.String(), "fec0:") { + ProxyLogger.Load().Debug(). + Str("nameserver", addr). + Msg("skipping deprecated site-local IPv6 address") + return false + } + } } - return err == nil + + ProxyLogger.Load().Debug(). + Str("input_addr", addr). + Msg("testing nameserver") + + // Handle both IPv4 and IPv6 addresses + serverAddr := addr + host, port, err := net.SplitHostPort(addr) + if err != nil { + // No port in address, add default port 53 + serverAddr = net.JoinHostPort(addr, "53") + } else if port == "" { + // Has split markers but empty port + serverAddr = net.JoinHostPort(host, "53") + } + + ProxyLogger.Load().Debug(). + Str("server_addr", serverAddr). + Msg("using server address") + + // Test domains that are likely to exist and respond quickly + testDomains := []struct { + name string + qtype uint16 + }{ + {".", dns.TypeNS}, // Root NS query - should always work + {"controld.com.", dns.TypeA}, // Fallback to a reliable domain + {"google.com.", dns.TypeA}, // Fallback to a reliable domain + } + + client := &dns.Client{ + Timeout: 2 * time.Second, + Net: "udp", + } + + // Try each test query until one succeeds + for _, test := range testDomains { + msg := new(dns.Msg) + msg.SetQuestion(test.name, test.qtype) + msg.RecursionDesired = true + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + resp, _, err := client.ExchangeContext(ctx, msg, serverAddr) + cancel() + + if err == nil && resp != nil { + return true + } + + ProxyLogger.Load().Error(). + Err(err). + Str("nameserver", serverAddr). + Str("test_domain", test.name). + Str("query_type", dns.TypeToString[test.qtype]). + Msg("DNS availability test failed") + } + + return false } // Resolver is the interface that wraps the basic DNS operations. @@ -222,7 +288,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error ctx, cancel := context.WithCancel(ctx) defer cancel() - dnsClient := &dns.Client{Net: "udp"} + dnsClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second} ch := make(chan *osResolverResult, numServers) wg := &sync.WaitGroup{} wg.Add(numServers) @@ -264,11 +330,14 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: switch { case res.server == controldPublicDnsWithPort: - controldSuccessAnswer = res.answer // only use ControlD answer as last one. + Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", res.server) + controldSuccessAnswer = res.answer case !res.lan && publicServerAnswer == nil: - publicServerAnswer = res.answer // use public DNS answer after LAN server.. + Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", res.server) + publicServerAnswer = res.answer publicServer = res.server default: + Log(ctx, ProxyLogger.Load().Debug(), "got LAN answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil @@ -276,6 +345,8 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error case res.answer != nil: nonSuccessAnswer = res.answer nonSuccessServer = res.server + Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d", + res.server, res.answer.Rcode) } errs = append(errs, res.err) } From 841be069b7bf6e5c1e1b10a1fd84b622c912d9c0 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 16 Jan 2025 14:00:34 +0700 Subject: [PATCH 044/100] cmd/cli: only list physical interfaces when listing Since these are the interfaces that ctrld will manipulate anyway. While at it, also skipping non-working devices on MacOS, by checking if the device is present in network service order --- cmd/cli/cli.go | 2 +- cmd/cli/commands.go | 17 ++++++++++------- cmd/cli/net_darwin.go | 8 +++++--- cmd/cli/net_others.go | 2 +- cmd/cli/net_windows.go | 4 ++-- cmd/cli/prog.go | 7 +++++-- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 39a5977..7565517 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -764,7 +764,7 @@ func netInterface(ifaceName string) (*net.Interface, error) { if iface == nil { return nil, errors.New("interface not found") } - if err := patchNetIfaceName(iface); err != nil { + if _, err := patchNetIfaceName(iface); err != nil { return nil, err } return iface, err diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index bae0cf1..9845093 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -9,7 +9,6 @@ import ( "io" "net" "net/http" - "net/netip" "os" "os/exec" "path/filepath" @@ -25,7 +24,6 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "github.com/spf13/pflag" - "tailscale.com/net/netmon" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/clientinfo" @@ -903,7 +901,7 @@ func initInterfacesCmd() *cobra.Command { Short: "List network interfaces of the host", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - err := netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { fmt.Printf("Index : %d\n", i.Index) fmt.Printf("Name : %s\n", i.Name) addrs, _ := i.Addrs() @@ -914,7 +912,14 @@ func initInterfacesCmd() *cobra.Command { } fmt.Printf(" %v\n", ipaddr) } - for i, dns := range currentDNS(i.Interface) { + nss, err := currentStaticDNS(i) + if err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get DNS") + } + if len(nss) == 0 { + nss = currentDNS(i) + } + for i, dns := range nss { if i == 0 { fmt.Printf("DNS : %s\n", dns) continue @@ -922,10 +927,8 @@ func initInterfacesCmd() *cobra.Command { fmt.Printf(" : %s\n", dns) } println() + return nil }) - if err != nil { - mainLog.Load().Error().Msg(err.Error()) - } }, } interfacesCmd := &cobra.Command{ diff --git a/cmd/cli/net_darwin.go b/cmd/cli/net_darwin.go index ece1862..ec7e517 100644 --- a/cmd/cli/net_darwin.go +++ b/cmd/cli/net_darwin.go @@ -9,17 +9,19 @@ import ( "strings" ) -func patchNetIfaceName(iface *net.Interface) error { +func patchNetIfaceName(iface *net.Interface) (bool, error) { b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output() if err != nil { - return err + return false, err } + patched := false if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" { iface.Name = name mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface") + patched = true } - return nil + return patched, nil } func networkServiceName(ifaceName string, r io.Reader) string { diff --git a/cmd/cli/net_others.go b/cmd/cli/net_others.go index 5a66e82..edd89ec 100644 --- a/cmd/cli/net_others.go +++ b/cmd/cli/net_others.go @@ -4,7 +4,7 @@ package cli import "net" -func patchNetIfaceName(iface *net.Interface) error { return nil } +func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true } diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index f46a93f..5b6e6b4 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -13,8 +13,8 @@ import ( "github.com/microsoft/wmi/pkg/hardware/network/netadapter" ) -func patchNetIfaceName(iface *net.Interface) error { - return nil +func patchNetIfaceName(iface *net.Interface) (bool, error) { + return true, nil } // validInterface reports whether the *net.Interface is a valid one. diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index b1fb18b..6aa95b1 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -1119,7 +1119,7 @@ func canBeLocalUpstream(addr string) bool { func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) { validIfacesMap := validInterfacesMap() netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { - // Skip loopback/virtual interface. + // Skip loopback/virtual/down interface. if i.IsLoopback() || len(i.HardwareAddr) == 0 { return } @@ -1128,9 +1128,12 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net. return } netIface := i.Interface - if err := patchNetIfaceName(netIface); err != nil { + if patched, err := patchNetIfaceName(netIface); err != nil { mainLog.Load().Debug().Err(err).Msg("failed to patch net interface name") return + } else if !patched { + // The interface is not functional, skipping. + return } // Skip excluded interface. if netIface.Name == excludeIfaceName { From 2d9c60dea1d7b2a92da942b0138c8c9bc8814060 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 16 Jan 2025 07:24:16 +0700 Subject: [PATCH 045/100] cmd/cli: log that multiple interfaces DNS set --- cmd/cli/commands.go | 24 ++++++++++++++++++++++-- cmd/cli/control_server.go | 2 ++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 9845093..4e32a7d 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -220,8 +220,28 @@ NOTE: running "ctrld start" without any arguments will start already installed c if iface == "auto" { iface = defaultIfaceName() } - logger := mainLog.Load().With().Str("iface", iface).Logger() - logger.Debug().Msg("setting DNS successfully") + res := &ifaceResponse{} + if err := json.NewDecoder(resp.Body).Decode(res); err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get iface info") + return + } + if res.OK { + name := res.Name + if iff, err := net.InterfaceByName(name); err == nil { + _, _ = patchNetIfaceName(iff) + name = iff.Name + } + logger := mainLog.Load().With().Str("iface", name).Logger() + logger.Debug().Msg("setting DNS successfully") + if res.All { + // Log that DNS is set for other interfaces. + withEachPhysicalInterfaces( + name, + "set DNS", + func(i *net.Interface) error { return nil }, + ) + } + } } } } diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 36285e5..1ea1693 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -34,6 +34,7 @@ const ( type ifaceResponse struct { Name string `json:"name"` All bool `json:"all"` + OK bool `json:"ok"` } type controlServer struct { @@ -217,6 +218,7 @@ func (p *prog) registerControlServerHandler() { if p.csSetDnsOk { res.Name = p.runningIface res.All = p.requiredMultiNICsConfig + res.OK = true } } if err := json.NewEncoder(w).Encode(res); err != nil { From 2687a4a0180cd70bd5803d2e54e10f2b267f8b8e Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 16 Jan 2025 19:27:24 -0500 Subject: [PATCH 046/100] remove leaking timeout, fix blocking upstreams checks, leaking is per listener, OS resolvers are tested in parallel, reset is only done is os is down fix test use upstreamIS var init map, fix watcher flag attempt to detect network changes attempt to detect network changes cancel and rerun reinitializeOSResolver cancel and rerun reinitializeOSResolver cancel and rerun reinitializeOSResolver ignore invalid inferaces ignore invalid inferaces allow OS resolver upstream to fail dont wait for dnsWait group on reinit, check for active interfaces to trigger reinit fix unused var simpler active iface check, debug logs dont spam network service name patching on Mac dont wait for os resolver nameserver testing remove test for osresovlers for now async nameserver testing remove unused test --- cmd/cli/dns_proxy.go | 300 ++++++++++++++++++++++++++---------- cmd/cli/net_darwin.go | 3 +- cmd/cli/prog.go | 18 ++- cmd/cli/resolvconf.go | 2 +- cmd/cli/upstream_monitor.go | 4 - resolver.go | 102 +++++++++--- resolver_test.go | 74 --------- 7 files changed, 313 insertions(+), 190 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index b2c0f23..341a830 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -19,6 +19,7 @@ import ( "golang.org/x/sync/errgroup" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/controld" @@ -77,6 +78,12 @@ type upstreamForResult struct { } func (p *prog) serveDNS(listenerNum string) error { + // Start network monitoring + if err := p.monitorNetworkChanges(); err != nil { + mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") + // Don't return here as we still want DNS service to run + } + listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { @@ -418,11 +425,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) + upstreamMapKey := strings.Join(upstreams, "_") + leaked := false - if len(upstreamConfigs) > 0 && p.leakingQuery.Load() { - upstreamConfigs = nil - leaked = true - ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams) + if len(upstreamConfigs) > 0 { + p.leakingQueryMu.Lock() + if p.leakingQueryRunning[upstreamMapKey] { + upstreamConfigs = nil + leaked = true + ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams) + } + p.leakingQueryMu.Unlock() } if len(upstreamConfigs) == 0 { @@ -601,9 +614,15 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) if p.leakOnUpstreamFailure() { p.leakingQueryMu.Lock() - if !p.leakingQueryWasRun { - p.leakingQueryWasRun = true - go p.performLeakingQuery() + // get the map key as concact of upstreams + if !p.leakingQueryRunning[upstreamMapKey] { + p.leakingQueryRunning[upstreamMapKey] = true + // get a map of the failed upstreams + failedUpstreams := make(map[string]*ctrld.UpstreamConfig) + for n, upstream := range upstreamConfigs { + failedUpstreams[upstreams[n]] = upstream + } + go p.performLeakingQuery(failedUpstreams, upstreamMapKey) } p.leakingQueryMu.Unlock() } @@ -929,95 +948,66 @@ func (p *prog) selfUninstallCoolOfPeriod() { } // performLeakingQuery performs necessary works to leak queries to OS resolver. -func (p *prog) performLeakingQuery() { - mainLog.Load().Warn().Msg("leaking query to OS resolver") +// once we store the leakingQuery flag, we are leaking queries to OS resolver +// we then start testing all the upstreams forever, waiting for success, but in parallel +func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamConfig, upstreamMapKey string) { - // Create a context with timeout for the entire operation - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + mainLog.Load().Warn().Msgf("leaking queries for failed upstreams [%v] to OS resolver", failedUpstreams) // Signal dns watchers to stop, so changes made below won't be reverted. - p.leakingQuery.Store(true) + p.leakingQueryMu.Lock() + p.leakingQueryRunning[upstreamMapKey] = true + p.leakingQueryMu.Unlock() defer func() { - p.leakingQuery.Store(false) p.leakingQueryMu.Lock() - p.leakingQueryWasRun = false + p.leakingQueryRunning[upstreamMapKey] = false p.leakingQueryMu.Unlock() mainLog.Load().Warn().Msg("stop leaking query") }() - // Create channels to coordinate operations - resetDone := make(chan struct{}) - checkDone := make(chan struct{}) + // we only want to reset DNS when our resolver is broken + // this allows us to find the new OS resolver nameservers + if p.um.isDown(upstreamOS) { - // Reset DNS with timeout - go func() { - defer close(resetDone) - mainLog.Load().Debug().Msg("attempting to reset DNS") - p.resetDNS() - mainLog.Load().Debug().Msg("DNS reset completed") - }() + mainLog.Load().Debug().Msg("OS resolver is down, reinitializing") + p.reinitializeOSResolver() - // Wait for reset with timeout - select { - case <-resetDone: - mainLog.Load().Debug().Msg("DNS reset successful") - case <-ctx.Done(): - mainLog.Load().Error().Msg("DNS reset timed out") - return } - // Check upstream in background with progress tracking - go func() { - defer close(checkDone) - mainLog.Load().Debug().Msg("starting upstream checks") - for name, uc := range p.cfg.Upstream { - select { - case <-ctx.Done(): - return - default: - mainLog.Load().Debug(). - Str("upstream", name). - Msg("checking upstream") - p.checkUpstream(name, uc) + // Test all failed upstreams in parallel + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + upstreamCh := make(chan string, len(failedUpstreams)) + for name, uc := range failedUpstreams { + go func(name string, uc *ctrld.UpstreamConfig) { + mainLog.Load().Debug(). + Str("upstream", name). + Msg("checking upstream") + + for { + select { + case <-ctx.Done(): + return + default: + p.checkUpstream(name, uc) + mainLog.Load().Debug(). + Str("upstream", name). + Msg("upstream recovered") + upstreamCh <- name + return + } } - } - mainLog.Load().Debug().Msg("upstream checks completed") - }() - - // Wait for upstream checks - select { - case <-checkDone: - mainLog.Load().Debug().Msg("upstream checks successful") - case <-ctx.Done(): - mainLog.Load().Error().Msg("upstream checks timed out") - return + }(name, uc) } - // Initialize OS resolver with timeout - mainLog.Load().Debug().Msg("initializing OS resolver") - ns := ctrld.InitializeOsResolver() - mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns) + // Wait for any upstream to recover + name := <-upstreamCh - // Wait for DNS operations to complete - waitCh := make(chan struct{}) - go func() { - p.dnsWg.Wait() - close(waitCh) - }() + mainLog.Load().Info(). + Str("upstream", name). + Msg("stopping leak as upstream recovered") - select { - case <-waitCh: - mainLog.Load().Debug().Msg("DNS operations completed") - case <-ctx.Done(): - mainLog.Load().Error().Msg("DNS operations timed out") - return - } - - // Set DNS with timeout - mainLog.Load().Debug().Msg("setting DNS configuration") - p.setDNS() - mainLog.Load().Debug().Msg("DNS configuration set successfully") } // forceFetchingAPI sends signal to force syncing API config if run in cd mode, @@ -1190,3 +1180,157 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M answer.SetReply(m) return answer } + +// reinitializeOSResolver reinitializes the OS resolver +// by removing ctrld listenr from the interface, collecting the network nameservers +// and re-initializing the OS resolver with the nameservers +// applying listener back to the interface +func (p *prog) reinitializeOSResolver() { + // Cancel any existing operations + p.resetCtxMu.Lock() + if p.resetCancel != nil { + p.resetCancel() + } + + // Create new context for this operation + ctx, cancel := context.WithCancel(context.Background()) + p.resetCtx = ctx + p.resetCancel = cancel + p.resetCtxMu.Unlock() + + // Ensure cleanup + defer cancel() + + p.leakingQueryReset.Store(true) + defer p.leakingQueryReset.Store(false) + + select { + case <-ctx.Done(): + mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") + return + default: + mainLog.Load().Debug().Msg("attempting to reset DNS") + p.resetDNS() + mainLog.Load().Debug().Msg("DNS reset completed") + } + + select { + case <-ctx.Done(): + mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") + return + default: + mainLog.Load().Debug().Msg("initializing OS resolver") + ns := ctrld.InitializeOsResolver() + mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns) + } + + select { + case <-ctx.Done(): + mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") + return + default: + mainLog.Load().Debug().Msg("setting DNS configuration") + p.setDNS() + mainLog.Load().Debug().Msg("DNS configuration set successfully") + } +} + +// monitorNetworkChanges starts monitoring for network interface changes +func (p *prog) monitorNetworkChanges() error { + // Create network monitor + mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) + if err != nil { + return fmt.Errorf("creating network monitor: %w", err) + } + + mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { + // Get map of valid interfaces + validIfaces := validInterfacesMap() + + // Parse old and new interface states + oldIfs := parseInterfaceState(delta.Old) + newIfs := parseInterfaceState(delta.New) + + // Check for changes in valid interfaces + changed := false + activeInterfaceExists := false + + for ifaceName := range validIfaces { + + oldState, oldExists := oldIfs[strings.ToLower(ifaceName)] + newState, newExists := newIfs[strings.ToLower(ifaceName)] + + if newState != "" && newState != "down" { + activeInterfaceExists = true + } + + if oldExists != newExists || oldState != newState { + changed = true + mainLog.Load().Debug(). + Str("interface", ifaceName). + Str("old_state", oldState). + Str("new_state", newState). + Msg("Valid interface changed state") + break + } else { + mainLog.Load().Debug(). + Str("interface", ifaceName). + Str("old_state", oldState). + Str("new_state", newState). + Msg("Valid interface unchanged") + } + } + + if !changed { + mainLog.Load().Debug().Msgf("Ignoring interface change - no valid interfaces affected") + return + } + + mainLog.Load().Debug().Msgf("Network change detected: from %v to %v", delta.Old, delta.New) + if activeInterfaceExists { + p.reinitializeOSResolver() + } else { + mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") + } + }) + + mon.Start() + mainLog.Load().Debug().Msg("Network monitor started") + return nil +} + +// parseInterfaceState parses the interface state string into a map of interface name -> state +func parseInterfaceState(state *netmon.State) map[string]string { + if state == nil { + return nil + } + + result := make(map[string]string) + + // Extract ifs={...} section + stateStr := state.String() + ifsStart := strings.Index(stateStr, "ifs={") + if ifsStart == -1 { + return result + } + + ifsStr := stateStr[ifsStart+5:] + ifsEnd := strings.Index(ifsStr, "}") + if ifsEnd == -1 { + return result + } + + // Parse each interface entry + ifaces := strings.Split(ifsStr[:ifsEnd], " ") + for _, iface := range ifaces { + parts := strings.Split(iface, ":") + if len(parts) != 2 { + continue + } + name := strings.ToLower(parts[0]) + state := parts[1] + result[name] = state + } + + return result +} diff --git a/cmd/cli/net_darwin.go b/cmd/cli/net_darwin.go index ec7e517..6233161 100644 --- a/cmd/cli/net_darwin.go +++ b/cmd/cli/net_darwin.go @@ -17,9 +17,8 @@ func patchNetIfaceName(iface *net.Interface) (bool, error) { patched := false if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" { - iface.Name = name - mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface") patched = true + iface.Name = name } return patched, nil } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 6aa95b1..a68dad2 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -115,9 +115,13 @@ type prog struct { loopMu sync.Mutex loop map[string]bool - leakingQueryMu sync.Mutex - leakingQueryWasRun bool - leakingQuery atomic.Bool + leakingQueryMu sync.Mutex + leakingQueryRunning map[string]bool + leakingQueryReset atomic.Bool + + resetCtx context.Context + resetCancel context.CancelFunc + resetCtxMu sync.Mutex started chan struct{} onStartedDone chan struct{} @@ -420,6 +424,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) + p.leakingQueryRunning = make(map[string]bool) p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() p.cacheFlushDomainsMap = nil @@ -737,12 +742,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces if !requiredMultiNICsConfig() { return } + logger := mainLog.Load().With().Str("iface", iface.Name).Logger() + logger.Debug().Msg("start DNS settings watchdog") - mainLog.Load().Debug().Msg("start DNS settings watchdog") ns := nameservers slices.Sort(ns) ticker := time.NewTicker(p.dnsWatchdogDuration()) - logger := mainLog.Load().With().Str("iface", iface.Name).Logger() + for { select { case <-p.dnsWatcherStopCh: @@ -751,7 +757,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces mainLog.Load().Debug().Msg("stop dns watchdog") return case <-ticker.C: - if p.leakingQuery.Load() { + if p.leakingQueryReset.Load() { return } if dnsChanged(iface, ns) { diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 6df7be6..367ffe7 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -40,7 +40,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: - if p.leakingQuery.Load() { + if p.leakingQueryReset.Load() { return } if !ok { diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 1f3484b..e37db4d 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -44,10 +44,6 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { // increaseFailureCount increase failed queries count for an upstream by 1. func (um *upstreamMonitor) increaseFailureCount(upstream string) { - // Do not count "upstream.os", since it must not be down for leaking queries. - if upstream == upstreamOS { - return - } um.mu.Lock() defer um.mu.Unlock() diff --git a/resolver.go b/resolver.go index 3189dfb..0097fe0 100644 --- a/resolver.go +++ b/resolver.go @@ -78,9 +78,7 @@ func availableNameservers() []string { if _, ok := machineIPsMap[ns]; ok { continue } - if testNameServerFn(ns) { - nss = append(nss, ns) - } + nss = append(nss, ns) } return nss } @@ -100,11 +98,9 @@ func InitializeOsResolver() []string { // - First available LAN servers are saved and store. // - Later calls, if no LAN servers available, the saved servers above will be used. func initializeOsResolver(servers []string) []string { - var ( - lanNss []string - publicNss []string - ) + var lanNss, publicNss []string + // First categorize servers for _, ns := range servers { addr, err := netip.ParseAddr(ns) if err != nil { @@ -117,28 +113,84 @@ func initializeOsResolver(servers []string) []string { publicNss = append(publicNss, server) } } + + // Store initial servers immediately if len(lanNss) > 0 { - // Saved first initialized LAN servers. or.initializedLanServers.CompareAndSwap(nil, &lanNss) - } - if len(lanNss) == 0 { - var nss []string - p := or.initializedLanServers.Load() - if p != nil { - for _, ns := range *p { - if testNameServerFn(ns) { - nss = append(nss, ns) - } - } - } - or.lanServers.Store(&nss) - } else { or.lanServers.Store(&lanNss) } + if len(publicNss) == 0 { - publicNss = append(publicNss, controldPublicDnsWithPort) + publicNss = []string{controldPublicDnsWithPort} } or.publicServers.Store(&publicNss) + + // Test servers in background and remove failures + go func() { + // Test servers in parallel but maintain order + type result struct { + index int + server string + valid bool + } + + testServers := func(servers []string) []string { + if len(servers) == 0 { + return nil + } + + results := make(chan result, len(servers)) + var wg sync.WaitGroup + + for i, server := range servers { + wg.Add(1) + go func(idx int, s string) { + defer wg.Done() + results <- result{ + index: idx, + server: s, + valid: testNameServerFn(s), + } + }(i, server) + } + + go func() { + wg.Wait() + close(results) + }() + + // Collect results maintaining original order + validServers := make([]string, 0, len(servers)) + ordered := make([]result, 0, len(servers)) + for r := range results { + ordered = append(ordered, r) + } + slices.SortFunc(ordered, func(a, b result) int { + return a.index - b.index + }) + for _, r := range ordered { + if r.valid { + validServers = append(validServers, r.server) + } else { + ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing") + } + } + return validServers + } + + // Test and update LAN servers + if validLanNss := testServers(lanNss); len(validLanNss) > 0 { + or.lanServers.Store(&validLanNss) + } + + // Test and update public servers + validPublicNss := testServers(publicNss) + if len(validPublicNss) == 0 { + validPublicNss = []string{controldPublicDnsWithPort} + } + or.publicServers.Store(&validPublicNss) + }() + return slices.Concat(lanNss, publicNss) } @@ -192,7 +244,6 @@ func testNameserver(addr string) bool { }{ {".", dns.TypeNS}, // Root NS query - should always work {"controld.com.", dns.TypeA}, // Fallback to a reliable domain - {"google.com.", dns.TypeA}, // Fallback to a reliable domain } client := &dns.Client{ @@ -330,10 +381,8 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: switch { case res.server == controldPublicDnsWithPort: - Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", res.server) controldSuccessAnswer = res.answer case !res.lan && publicServerAnswer == nil: - Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", res.server) publicServerAnswer = res.answer publicServer = res.server default: @@ -351,14 +400,17 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error errs = append(errs, res.err) } if publicServerAnswer != nil { + Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", publicServer) logAnswer(publicServer) return publicServerAnswer, nil } if controldSuccessAnswer != nil { + Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort) logAnswer(controldPublicDnsWithPort) return controldSuccessAnswer, nil } if nonSuccessAnswer != nil { + Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer) logAnswer(nonSuccessServer) return nonSuccessAnswer, nil } diff --git a/resolver_test.go b/resolver_test.go index 7eab744..de8cca0 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -3,13 +3,10 @@ package ctrld import ( "context" "net" - "slices" "sync" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/miekg/dns" ) @@ -178,71 +175,6 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H return server, addr, nil } -func Test_initializeOsResolver(t *testing.T) { - testNameServerFn = testNameserverTest - lanServer1 := "192.168.1.1" - lanServer1WithPort := net.JoinHostPort("192.168.1.1", "53") - lanServer2 := "10.0.10.69" - lanServer2WithPort := net.JoinHostPort("10.0.10.69", "53") - lanServer3 := "192.168.40.1" - lanServer3WithPort := net.JoinHostPort("192.168.40.1", "53") - wanServer := "1.1.1.1" - lanServers := []string{lanServer1WithPort, lanServer2WithPort} - publicServers := []string{net.JoinHostPort(wanServer, "53")} - - or = newResolverWithNameserver(defaultNameservers()) - - // First initialization, initialized servers are saved. - initializeOsResolver([]string{lanServer1, lanServer2, wanServer}) - p := or.initializedLanServers.Load() - assert.NotNil(t, p) - assert.True(t, slices.Equal(*p, lanServers)) - assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) - assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - - // No new LAN servers, but lanServer2 gone, initialized servers not changed. - initializeOsResolver([]string{lanServer1, wanServer}) - p = or.initializedLanServers.Load() - assert.NotNil(t, p) - assert.True(t, slices.Equal(*p, lanServers)) - assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer1WithPort})) - assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - - // New LAN servers, they are used, initialized servers not changed. - initializeOsResolver([]string{lanServer3, wanServer}) - p = or.initializedLanServers.Load() - assert.NotNil(t, p) - assert.True(t, slices.Equal(*p, lanServers)) - assert.True(t, slices.Equal(*or.lanServers.Load(), []string{lanServer3WithPort})) - assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - - // No LAN server available, initialized servers will be used. - initializeOsResolver([]string{wanServer}) - p = or.initializedLanServers.Load() - assert.NotNil(t, p) - assert.True(t, slices.Equal(*p, lanServers)) - assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) - assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) - - // No Public server, ControlD Public DNS will be used. - initializeOsResolver([]string{}) - p = or.initializedLanServers.Load() - assert.NotNil(t, p) - assert.True(t, slices.Equal(*p, lanServers)) - assert.True(t, slices.Equal(*or.lanServers.Load(), lanServers)) - assert.True(t, slices.Equal(*or.publicServers.Load(), []string{controldPublicDnsWithPort})) - - // No LAN server available, initialized servers is unavailable, nothing will be used. - nonSuccessTestServerMap[lanServer1WithPort] = true - nonSuccessTestServerMap[lanServer2WithPort] = true - initializeOsResolver([]string{wanServer}) - p = or.initializedLanServers.Load() - assert.NotNil(t, p) - assert.True(t, slices.Equal(*p, lanServers)) - assert.Empty(t, *or.lanServers.Load()) - assert.True(t, slices.Equal(*or.publicServers.Load(), publicServers)) -} - func successHandler() dns.HandlerFunc { return func(w dns.ResponseWriter, msg *dns.Msg) { m := new(dns.Msg) @@ -258,9 +190,3 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc { w.WriteMsg(m) } } - -var nonSuccessTestServerMap = map[string]bool{} - -func testNameserverTest(addr string) bool { - return !nonSuccessTestServerMap[addr] -} From 9718ab8579898a5992a7dc583c95ea5e72833227 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 17 Jan 2025 14:09:51 +0700 Subject: [PATCH 047/100] cmd/cli: fix getting interface name when disabled on Windows By getting the name property directly from adapter instance, instead of using net.InterfaceByIndex function, which could return an error when the adapter is disabled. --- cmd/cli/net_windows.go | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index 5b6e6b4..fe075a3 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -41,7 +41,7 @@ func validInterfaces() []string { q := query.NewWmiQuery("MSFT_NetAdapter") instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) if err != nil { - mainLog.Load().Err(err).Msg("failed to get wmi network adapter") + mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter") return nil } defer instances.Close() @@ -49,7 +49,7 @@ func validInterfaces() []string { for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) if err != nil { - mainLog.Load().Err(err).Msg("failed to get network adapter") + mainLog.Load().Warn().Err(err).Msg("failed to get network adapter") continue } // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) @@ -58,23 +58,18 @@ func validInterfaces() []string { // if this is a physical adapter or FALSE if this is not a physical adapter." physical, err := adapter.GetPropertyConnectorPresent() if err != nil { - mainLog.Load().Err(err).Msg("failed to get network adapter connector present property") + mainLog.Load().Warn().Err(err).Msg("failed to get network adapter connector present property") continue } if !physical { continue } - ifaceIdx, err := adapter.GetInterfaceIndex() + name, err := adapter.GetPropertyName() if err != nil { - mainLog.Load().Err(err).Msg("failed to get interface index") + mainLog.Load().Warn().Err(err).Msg("failed to get interface name") continue } - iff, err := net.InterfaceByIndex(int(ifaceIdx)) - if err != nil { - mainLog.Load().Err(err).Msg("failed to get interface") - continue - } - adapters = append(adapters, iff.Name) + adapters = append(adapters, name) } return adapters } From 7ed6733fb731e6ec9c4f616b50afb91c5c427456 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 21 Jan 2025 00:21:16 +0700 Subject: [PATCH 048/100] cmd/cli: better error if internal log is not available --- cmd/cli/commands.go | 10 ++++++---- cmd/cli/log_writer.go | 3 ++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 4e32a7d..8713ea5 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -31,6 +31,10 @@ import ( ) func initLogCmd() *cobra.Command { + warnRuntimeLoggingNotEnabled := func() { + mainLog.Load().Warn().Msg("runtime debug logging is not enabled") + mainLog.Load().Warn().Msg(`ctrld may be running without "--cd" flag or logging is already enabled`) + } logSendCmd := &cobra.Command{ Use: "send", Short: "Send runtime debug logs to ControlD", @@ -54,8 +58,7 @@ func initLogCmd() *cobra.Command { mainLog.Load().Warn().Msg("runtime logs could only be sent once per minute") return case http.StatusMovedPermanently: - mainLog.Load().Warn().Msg("runtime debugs log is not enabled") - mainLog.Load().Warn().Msg(`ctrld may be run without "--cd" flag or logging is already enabled`) + warnRuntimeLoggingNotEnabled() return } var logs logSentResponse @@ -92,8 +95,7 @@ func initLogCmd() *cobra.Command { switch resp.StatusCode { case http.StatusMovedPermanently: - mainLog.Load().Warn().Msg("runtime debugs log is not enabled") - mainLog.Load().Warn().Msg(`ctrld may be run without "--cd" flag or logging is already enabled`) + warnRuntimeLoggingNotEnabled() return case http.StatusBadRequest: mainLog.Load().Warn().Msg("runtime debugs log is not available") diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index c146f4e..92e0e63 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "strings" "sync" "time" @@ -165,7 +166,7 @@ func (p *prog) logReader() (*logReader, error) { return lr, nil } if p.cfg.Service.LogPath == "" { - return nil, nil + return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil } f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath)) if err != nil { From 69e0aab73e5b07626ffb3da896f87d2d153a9489 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 23 Jan 2025 19:09:10 +0700 Subject: [PATCH 049/100] cmd/cli: use wmi to get AD domain Since using syscall.NetGetJoinInformation won't return the full domain name. Discovered while investigating issue with SRV ldap check. --- cmd/cli/ad_windows.go | 25 ++++++++++++++++--------- go.mod | 1 + go.sum | 2 ++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index 475ba09..3f9fa17 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -1,11 +1,13 @@ package cli import ( + "io" + "log" + "os" "strings" - "syscall" - "unsafe" - "golang.org/x/sys/windows" + "github.com/microsoft/wmi/pkg/base/host" + hh "github.com/microsoft/wmi/pkg/hardware/host" "github.com/Control-D-Inc/ctrld" ) @@ -50,15 +52,20 @@ func addSplitDnsRule(cfg *ctrld.Config, domain string) bool { // getActiveDirectoryDomain returns AD domain name of this computer. func getActiveDirectoryDomain() (string, error) { - var domain *uint16 - var status uint32 - err := syscall.NetGetJoinInformation(nil, &domain, &status) + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + whost := host.NewWmiLocalHost() + cs, err := hh.GetComputerSystem(whost) if err != nil { return "", err } - defer syscall.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) - if status == syscall.NetSetupDomainName { - return windows.UTF16PtrToString(domain), nil + defer cs.Close() + pod, err := cs.GetPropertyPartOfDomain() + if err != nil { + return "", err + } + if pod { + return cs.GetPropertyDomain() } return "", nil } diff --git a/go.mod b/go.mod index 67fe9a2..a86557e 100644 --- a/go.mod +++ b/go.mod @@ -58,6 +58,7 @@ require ( github.com/golang/protobuf v1.5.4 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd // indirect + github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jsimonetti/rtnetlink v1.4.0 // indirect diff --git a/go.sum b/go.sum index bcf1ee7..3eb268a 100644 --- a/go.sum +++ b/go.sum @@ -166,6 +166,8 @@ github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlG github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= From 20759017e6aa9d8c3cf1ed8c2115c712787fbfcd Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 24 Jan 2025 00:38:53 +0700 Subject: [PATCH 050/100] all: use local resolver for ADDC For normal OS resolver, ctrld does not use local addresses as nameserver to avoid possible looping. However, on AD environment with local DNS running, AD queries must be sent to the local DNS server for proper resolving. --- cmd/cli/ad_others.go | 5 +++++ cmd/cli/dns_proxy.go | 19 +++++++++++++++++++ cmd/cli/prog.go | 5 +++++ config.go | 2 +- resolver.go | 19 +++++++++++++++++-- 5 files changed, 47 insertions(+), 3 deletions(-) diff --git a/cmd/cli/ad_others.go b/cmd/cli/ad_others.go index 6a7417f..b23476f 100644 --- a/cmd/cli/ad_others.go +++ b/cmd/cli/ad_others.go @@ -8,3 +8,8 @@ import ( // addExtraSplitDnsRule adds split DNS rule if present. func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false } + +// getActiveDirectoryDomain returns AD domain name of this computer. +func getActiveDirectoryDomain() (string, error) { + return "", nil +} diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 341a830..5396642 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -51,6 +51,12 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } +var localUpstreamConfig = &ctrld.UpstreamConfig{ + Name: "Local resolver", + Type: ctrld.ResolverTypeLocal, + Timeout: 2000, +} + // proxyRequest contains data for proxying a DNS query to upstream. type proxyRequest struct { msg *dns.Msg @@ -443,6 +449,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { upstreams = []string{upstreamOS} } + if p.isAdDomainQuery(req.msg) { + upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} + upstreams = []string{upstreamOS} + } + res := &proxyResponse{} // LAN/PTR lookup flow: @@ -651,6 +662,14 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U return upstreamConfigs } +func (p *prog) isAdDomainQuery(msg *dns.Msg) bool { + if p.adDomain == "" { + return false + } + cDomainName := canonicalName(msg.Question[0].Name) + return dns.IsSubDomain(p.adDomain, cDomainName) +} + // canonicalName returns canonical name from FQDN with "." trimmed. func canonicalName(fqdn string) string { q := strings.TrimSpace(fqdn) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index a68dad2..46d4d18 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -106,6 +106,7 @@ type prog struct { internalLogSent time.Time runningIface string requiredMultiNICsConfig bool + adDomain string selfUninstallMu sync.Mutex refusedQueryCount int @@ -441,6 +442,10 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } } + if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() { + mainLog.Load().Debug().Msgf("active directory domain: %s", domain) + p.adDomain = domain + } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) diff --git a/config.go b/config.go index c88404c..099f75b 100644 --- a/config.go +++ b/config.go @@ -384,7 +384,7 @@ func (uc *UpstreamConfig) IsDiscoverable() bool { return *uc.Discoverable } switch uc.Type { - case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate: + case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate, ResolverTypeLocal: if ip, err := netip.ParseAddr(uc.Domain); err == nil { return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) } diff --git a/resolver.go b/resolver.go index 0097fe0..7dc76b0 100644 --- a/resolver.go +++ b/resolver.go @@ -30,8 +30,10 @@ const ( ResolverTypeOS = "os" // ResolverTypeLegacy specifies legacy resolver. ResolverTypeLegacy = "legacy" - // ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only. + // ResolverTypePrivate is like ResolverTypeOS, but use for private resolver only. ResolverTypePrivate = "private" + // ResolverTypeLocal is like ResolverTypeOS, but use for local resolver only. + ResolverTypeLocal = "local" // ResolverTypeSDNS specifies resolver with information encoded using DNS Stamps. // See: https://dnscrypt.info/stamps-specifications/ ResolverTypeSDNS = "sdns" @@ -47,6 +49,16 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") // or is the Resolver used for ResolverTypeOS. var or = newResolverWithNameserver(defaultNameservers()) +var localResolver = newLocalResolver() + +func newLocalResolver() Resolver { + var nss []string + for _, addr := range Rfc1918Addresses() { + nss = append(nss, net.JoinHostPort(addr, "53")) + } + return NewResolverWithNameserver(nss) +} + // LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network. type LanQueryCtxKey struct{} @@ -89,7 +101,8 @@ func availableNameservers() []string { // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. func InitializeOsResolver() []string { - return initializeOsResolver(availableNameservers()) + ns := initializeOsResolver(availableNameservers()) + return ns } // initializeOsResolver performs logic for choosing OS resolver nameserver. @@ -301,6 +314,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { return &legacyResolver{uc: uc}, nil case ResolverTypePrivate: return NewPrivateResolver(), nil + case ResolverTypeLocal: + return localResolver, nil } return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ) } From 0fbfd160c9940be59f8fdcfc3ecc7d26b8cb8ecf Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 24 Jan 2025 01:39:17 +0700 Subject: [PATCH 051/100] cmd/cli: log interfaces state after dns set The data will be useful for troubleshooting later. --- cmd/cli/dns_proxy.go | 1 + cmd/cli/prog.go | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 5396642..d7eb28a 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1251,6 +1251,7 @@ func (p *prog) reinitializeOSResolver() { mainLog.Load().Debug().Msg("setting DNS configuration") p.setDNS() mainLog.Load().Debug().Msg("DNS configuration set successfully") + p.logInterfacesState() } } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 46d4d18..331f42a 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -273,6 +273,7 @@ func (p *prog) postRun() { p.setDNS() p.csSetDnsDone <- struct{}{} close(p.csSetDnsDone) + p.logInterfacesState() } } @@ -815,6 +816,28 @@ func (p *prog) resetDNS() { } } +func (p *prog) logInterfacesState() { + withEachPhysicalInterfaces("", "", func(i *net.Interface) error { + addrs, err := i.Addrs() + if err != nil { + mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses") + } + nss, err := currentStaticDNS(i) + if err != nil { + mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS") + } + if len(nss) == 0 { + nss = currentDNS(i) + } + mainLog.Load().Debug(). + Any("addrs", addrs). + Strs("nameservers", nss). + Int("index", i.Index). + Msgf("interface state: %s", i.Name) + return nil + }) +} + // findWorkingInterface looks for a network interface with a valid IP configuration func findWorkingInterface(currentIface string) string { // Helper to check if IP is valid (not link-local) From ce3281e70dfb9bedf41269b85fe1955ba756c9b6 Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 25 Jan 2025 01:26:48 -0500 Subject: [PATCH 052/100] much more debugging, improved nameserver detection, no more testing nameservers fix logging fix logging try to enable nameserver logs try to enable nameserver logs handle flags in interface state changes debugging debugging debugging fix state detection, AD status fix fix debugging line more dc info always log state changes remove unused method windows AD IP discovery windows AD IP discovery windows AD IP discovery --- cmd/cli/dns_proxy.go | 61 +++++--- go.mod | 1 + go.sum | 4 + nameservers_windows.go | 338 +++++++++++++++++++++++++++++++++++++++-- resolver.go | 292 +++++++++++++++++++---------------- 5 files changed, 538 insertions(+), 158 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index d7eb28a..0d67e88 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -450,6 +450,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } if p.isAdDomainQuery(req.msg) { + ctrld.Log(ctx, mainLog.Load().Debug(), + "AD domain query detected for %s in domain %s", + req.msg.Question[0].Name, p.adDomain) upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig} upstreams = []string{upstreamOS} } @@ -566,14 +569,20 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if upstreamConfig == nil { continue } - ctrld.Log(ctx, mainLog.Load().Debug(), "attempting upstream [ %s ] at index: %d, upstream at index: %s", upstreamConfig.String(), n, upstreams[n]) + logger := mainLog.Load().Debug(). + Str("upstream", upstreamConfig.String()). + Str("query", req.msg.Question[0].Name). + Bool("is_ad_query", p.isAdDomainQuery(req.msg)). + Bool("is_lan_query", isLanOrPtrQuery) if p.isLoop(upstreamConfig) { - mainLog.Load().Warn().Msgf("dns loop detected, upstream: %s", upstreamConfig.String()) + logger.Msg("DNS loop detected") continue } if p.um.isDown(upstreams[n]) { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s is down", upstreams[n]) + logger. + Bool("is_os_resolver", upstreams[n] == upstreamOS). + Msg("Upstream is down") continue } answer := resolve(n, upstreamConfig, req.msg) @@ -1257,7 +1266,6 @@ func (p *prog) reinitializeOSResolver() { // monitorNetworkChanges starts monitoring for network interface changes func (p *prog) monitorNetworkChanges() error { - // Create network monitor mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) if err != nil { return fmt.Errorf("creating network monitor: %w", err) @@ -1267,6 +1275,12 @@ func (p *prog) monitorNetworkChanges() error { // Get map of valid interfaces validIfaces := validInterfacesMap() + // log the delta for debugging + mainLog.Load().Debug(). + Interface("old_state", delta.Old). + Interface("new_state", delta.New). + Msg("Network change detected") + // Parse old and new interface states oldIfs := parseInterfaceState(delta.Old) newIfs := parseInterfaceState(delta.New) @@ -1276,14 +1290,14 @@ func (p *prog) monitorNetworkChanges() error { activeInterfaceExists := false for ifaceName := range validIfaces { - oldState, oldExists := oldIfs[strings.ToLower(ifaceName)] newState, newExists := newIfs[strings.ToLower(ifaceName)] - if newState != "" && newState != "down" { + if newState != "" && !strings.Contains(newState, "down") { activeInterfaceExists = true } + // Compare states directly if oldExists != newExists || oldState != newState { changed = true mainLog.Load().Debug(). @@ -1302,11 +1316,10 @@ func (p *prog) monitorNetworkChanges() error { } if !changed { - mainLog.Load().Debug().Msgf("Ignoring interface change - no valid interfaces affected") + mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") return } - mainLog.Load().Debug().Msgf("Network change detected: from %v to %v", delta.Old, delta.New) if activeInterfaceExists { p.reinitializeOSResolver() } else { @@ -1326,9 +1339,10 @@ func parseInterfaceState(state *netmon.State) map[string]string { } result := make(map[string]string) - - // Extract ifs={...} section + stateStr := state.String() + + // Extract interface information ifsStart := strings.Index(stateStr, "ifs={") if ifsStart == -1 { return result @@ -1340,17 +1354,28 @@ func parseInterfaceState(state *netmon.State) map[string]string { return result } - // Parse each interface entry - ifaces := strings.Split(ifsStr[:ifsEnd], " ") - for _, iface := range ifaces { - parts := strings.Split(iface, ":") + // Get the content between ifs={ } + ifsContent := strings.TrimSpace(ifsStr[:ifsEnd]) + + // Split on "] " to get each interface entry + entries := strings.Split(ifsContent, "] ") + + for _, entry := range entries { + if entry == "" { + continue + } + + // Split on ":[" + parts := strings.Split(entry, ":[") if len(parts) != 2 { continue } - name := strings.ToLower(parts[0]) - state := parts[1] - result[name] = state + + name := strings.TrimSpace(parts[0]) + state := "[" + strings.TrimSuffix(parts[1], "]") + "]" + + result[strings.ToLower(name)] = state } return result -} +} \ No newline at end of file diff --git a/go.mod b/go.mod index a86557e..8e9a8f7 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( require ( aead.dev/minisign v0.2.0 // indirect + github.com/StackExchange/wmi v1.2.1 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index 3eb268a..fcf2ac7 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= +github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= +github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE= github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= @@ -93,6 +95,7 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= @@ -449,6 +452,7 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/nameservers_windows.go b/nameservers_windows.go index 150f252..a8c5191 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -1,44 +1,364 @@ package ctrld import ( + "context" + "fmt" + "net" + "strings" "syscall" + "time" + "unsafe" + "io" + "os" + "github.com/rs/zerolog" + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "github.com/StackExchange/wmi" ) +const ( + maxRetries = 3 + retryDelay = 500 * time.Millisecond + defaultTimeout = 5 * time.Second + minDNSServers = 1 // Minimum number of DNS servers we want to find + NetSetupUnknown uint32 = 0 + NetSetupWorkgroup uint32 = 1 + NetSetupDomain uint32 = 2 + NetSetupCloudDomain uint32 = 3 + DS_FORCE_REDISCOVERY = 0x00000001 + DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 + DS_BACKGROUND_ONLY = 0x00000100 + DS_IP_REQUIRED = 0x00000200 + DS_IS_DNS_NAME = 0x00020000 + DS_RETURN_DNS_NAME = 0x40000000 +) + +type DomainControllerInfo struct { + DomainControllerName *uint16 + DomainControllerAddress *uint16 + DomainControllerAddressType uint32 + DomainGuid windows.GUID + DomainName *uint16 + DnsForestName *uint16 + Flags uint32 + DcSiteName *uint16 + ClientSiteName *uint16 +} + func dnsFns() []dnsFn { return []dnsFn{dnsFromAdapter} } func dnsFromAdapter() []string { - aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix) - if err != nil { - return nil + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + + var ns []string + var err error + + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() } + + for i := 0; i < maxRetries; i++ { + if ctx.Err() != nil { + Log(context.Background(), logger.Debug(), + "dnsFromAdapter lookup cancelled or timed out, attempt %d", i) + return nil + } + + ns, err = getDNSServers(ctx) + if err == nil && len(ns) >= minDNSServers { + if i > 0 { + Log(context.Background(), logger.Debug(), + "Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns)) + } + return ns + } + + // Log the specific failure reason + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get DNS servers, attempt %d: %v", i+1, err) + } else { + Log(context.Background(), logger.Debug(), + "Got insufficient DNS servers, retrying, found %d servers", len(ns)) + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(retryDelay): + } + } + + Log(context.Background(), logger.Debug(), + "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries) + return ns // Return whatever we got, even if insufficient +} + +func getDNSServers(ctx context.Context) ([]string, error) { + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + // Check context before making the call + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // Get DNS servers from adapters (existing method) + flags := winipcfg.GAAFlagIncludeGateways | + winipcfg.GAAFlagIncludePrefix + + aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags) + if err != nil { + return nil, fmt.Errorf("getting adapters: %w", err) + } + + Log(context.Background(), logger.Debug(), + "Found network adapters, count=%d", len(aas)) + + // Try to get domain controller info if domain-joined + var dcServers []string + isDomain := checkDomainJoined() + if isDomain { + + domainName, err := getLocalADDomain() + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get local AD domain: %v", err) + + } else { + + // Load netapi32.dll + netapi32 := windows.NewLazySystemDLL("netapi32.dll") + dsDcName := netapi32.NewProc("DsGetDcNameW") + + var info *DomainControllerInfo + + flags := uint32(DS_RETURN_DNS_NAME | + DS_IP_REQUIRED | + DS_IS_DNS_NAME) + + // Convert domain name to UTF16 pointer + domainUTF16, err := windows.UTF16PtrFromString(domainName) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to convert domain name to UTF16: %v", err) + } else { + Log(context.Background(), logger.Debug(), + "Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags) + + // Call DsGetDcNameW with domain name + ret, _, err := dsDcName.Call( + 0, // ComputerName - can be NULL + uintptr(unsafe.Pointer(domainUTF16)), // DomainName + 0, // DomainGuid - not needed + 0, // SiteName - not needed + uintptr(flags), // Flags + uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output + + if ret != 0 { + switch ret { + case 1355: // ERROR_NO_SUCH_DOMAIN + Log(context.Background(), logger.Debug(), + "Domain not found: %s (%d)", domainName, ret) + case 1311: // ERROR_NO_LOGON_SERVERS + Log(context.Background(), logger.Debug(), + "No logon servers available for domain: %s (%d)", domainName, ret) + case 1004: // ERROR_DC_NOT_FOUND + Log(context.Background(), logger.Debug(), + "Domain controller not found for domain: %s (%d)", domainName, ret) + case 1722: // RPC_S_SERVER_UNAVAILABLE + Log(context.Background(), logger.Debug(), + "RPC server unavailable for domain: %s (%d)", domainName, ret) + default: + Log(context.Background(), logger.Debug(), + "Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err) + } + } else if info != nil { + defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info))) + + // Get DC address + if info.DomainControllerAddress != nil { + dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) + dcAddr = strings.TrimPrefix(dcAddr, "\\\\") + + Log(context.Background(), logger.Debug(), + "Found domain controller address: %s", dcAddr) + + // Try to resolve DC + if ip := net.ParseIP(dcAddr); ip != nil { + dcServers = append(dcServers, ip.String()) + Log(context.Background(), logger.Debug(), + "Added domain controller DNS servers: %v", dcServers) + } + } else { + Log(context.Background(), logger.Debug(), + "No domain controller address found") + } + } + } + + } + } + + // Continue with existing adapter DNS collection ns := make([]string, 0, len(aas)*2) seen := make(map[string]bool) addressMap := make(map[string]struct{}) + + // Collect all local IPs for _, aa := range aas { + if aa.OperStatus != winipcfg.IfOperStatusUp { + Log(context.Background(), logger.Debug(), + "Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus) + continue + } + + Log(context.Background(), logger.Debug(), + "Processing adapter %s", aa.FriendlyName()) + for a := aa.FirstUnicastAddress; a != nil; a = a.Next { - addressMap[a.Address.IP().String()] = struct{}{} + ip := a.Address.IP().String() + addressMap[ip] = struct{}{} + Log(context.Background(), logger.Debug(), + "Added local IP %s from adapter %s", ip, aa.FriendlyName()) } } + + // Collect DNS servers for _, aa := range aas { + if aa.OperStatus != winipcfg.IfOperStatusUp { + continue + } + for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() - if ip == nil || ip.IsLoopback() || seen[ip.String()] { + if ip == nil { + Log(context.Background(), logger.Debug(), + "Skipping nil IP from adapter %s", aa.FriendlyName()) continue } - if _, ok := addressMap[ip.String()]; ok { + + ipStr := ip.String() + logger := logger.Debug(). + Str("ip", ipStr). + Str("adapter", aa.FriendlyName()) + + if ip.IsLoopback() { + logger.Msg("Skipping loopback IP") continue } - seen[ip.String()] = true - ns = append(ns, ip.String()) + + if seen[ipStr] { + logger.Msg("Skipping duplicate IP") + continue + } + + if _, ok := addressMap[ipStr]; ok { + logger.Msg("Skipping local interface IP") + continue + } + + seen[ipStr] = true + ns = append(ns, ipStr) + logger.Msg("Added DNS server") } } - return ns + + // Add DC servers if they're not already in the list + for _, dcServer := range dcServers { + if !seen[dcServer] { + seen[dcServer] = true + ns = append(ns, dcServer) + Log(context.Background(), logger.Debug(), + "Added additional domain controller DNS server: %s", dcServer) + } + } + + if len(ns) == 0 { + return nil, fmt.Errorf("no valid DNS servers found") + } + + Log(context.Background(), logger.Debug(), + "DNS server discovery completed, count=%d, servers=%v (including %d DC servers)", + len(ns), ns, len(dcServers)) + return ns, nil } func nameserversFromResolvconf() []string { return nil } + +// checkDomainJoined checks if the machine is joined to an Active Directory domain +// Returns whether it's domain joined and the domain name if available +func checkDomainJoined() bool { + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + var domain *uint16 + var status uint32 + + err := windows.NetGetJoinInformation(nil, &domain, &status) + if err != nil { + Log(context.Background(), logger.Debug(), + "Failed to get domain join status: %v", err) + return false + } + defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain))) + + domainName := windows.UTF16PtrToString(domain) + Log(context.Background(), logger.Debug(), + "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status) + + // Consider both traditional and cloud domains as valid domain joins + isDomain := status == NetSetupDomain || status == NetSetupCloudDomain + Log(context.Background(), logger.Debug(), + "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", + status, + status == NetSetupDomain, + status == NetSetupCloudDomain, + isDomain) + + return isDomain +} + +// Win32_ComputerSystem is the minimal struct for WMI query +type Win32_ComputerSystem struct { + Domain string +} + +// getLocalADDomain tries to detect the AD domain in two ways: +// 1) USERDNSDOMAIN env var (often set in AD logon sessions) +// 2) WMI Win32_ComputerSystem.Domain +func getLocalADDomain() (string, error) { + // 1) Check environment variable + envDomain := os.Getenv("USERDNSDOMAIN") + if envDomain != "" { + return strings.TrimSpace(envDomain), nil + } + + // 2) Check WMI (requires Windows + admin privileges or sufficient access) + var result []Win32_ComputerSystem + err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result) + if err != nil { + return "", fmt.Errorf("WMI query failed: %v", err) + } + if len(result) == 0 { + return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") + } + + domain := strings.TrimSpace(result[0].Domain) + if domain == "" { + return "", fmt.Errorf("machine does not appear to have a domain set") + } + return domain, nil +} diff --git a/resolver.go b/resolver.go index 7dc76b0..e82b763 100644 --- a/resolver.go +++ b/resolver.go @@ -11,7 +11,9 @@ import ( "sync" "sync/atomic" "time" + "io" + "github.com/rs/zerolog" "github.com/miekg/dns" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" @@ -83,15 +85,40 @@ func availableNameservers() []string { // Ignore local addresses to prevent loop. regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) - for _, v := range slices.Concat(regularIPs, loopbackIPs) { - machineIPsMap[v.String()] = struct{}{} + + + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() } - for _, ns := range nameservers() { + Log(context.Background(), logger.Debug(), + "Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs) + + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + machineIPsMap[ipStr] = struct{}{} + Log(context.Background(), logger.Debug(), + "Added local IP to OS resolverexclusion map: %s", ipStr) + } + + systemNameservers := nameservers() + Log(context.Background(), logger.Debug(), + "Got system nameservers: %v", systemNameservers) + + for _, ns := range systemNameservers { if _, ok := machineIPsMap[ns]; ok { + Log(context.Background(), logger.Debug(), + "Skipping local nameserver: %s", ns) continue } nss = append(nss, ns) + Log(context.Background(), logger.Debug(), + "Added non-local nameserver: %s", ns) } + + Log(context.Background(), logger.Debug(), + "Final available nameservers: %v", nss) return nss } @@ -138,156 +165,159 @@ func initializeOsResolver(servers []string) []string { } or.publicServers.Store(&publicNss) - // Test servers in background and remove failures - go func() { - // Test servers in parallel but maintain order - type result struct { - index int - server string - valid bool - } + // no longer testing servers in the background + // if DCHP nameservers are not working, this is outside of our control - testServers := func(servers []string) []string { - if len(servers) == 0 { - return nil - } + // // Test servers in background and remove failures + // go func() { + // // Test servers in parallel but maintain order + // type result struct { + // index int + // server string + // valid bool + // } - results := make(chan result, len(servers)) - var wg sync.WaitGroup + // testServers := func(servers []string) []string { + // if len(servers) == 0 { + // return nil + // } - for i, server := range servers { - wg.Add(1) - go func(idx int, s string) { - defer wg.Done() - results <- result{ - index: idx, - server: s, - valid: testNameServerFn(s), - } - }(i, server) - } + // results := make(chan result, len(servers)) + // var wg sync.WaitGroup - go func() { - wg.Wait() - close(results) - }() + // for i, server := range servers { + // wg.Add(1) + // go func(idx int, s string) { + // defer wg.Done() + // results <- result{ + // index: idx, + // server: s, + // valid: testNameServerFn(s), + // } + // }(i, server) + // } - // Collect results maintaining original order - validServers := make([]string, 0, len(servers)) - ordered := make([]result, 0, len(servers)) - for r := range results { - ordered = append(ordered, r) - } - slices.SortFunc(ordered, func(a, b result) int { - return a.index - b.index - }) - for _, r := range ordered { - if r.valid { - validServers = append(validServers, r.server) - } else { - ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing") - } - } - return validServers - } + // go func() { + // wg.Wait() + // close(results) + // }() - // Test and update LAN servers - if validLanNss := testServers(lanNss); len(validLanNss) > 0 { - or.lanServers.Store(&validLanNss) - } + // // Collect results maintaining original order + // validServers := make([]string, 0, len(servers)) + // ordered := make([]result, 0, len(servers)) + // for r := range results { + // ordered = append(ordered, r) + // } + // slices.SortFunc(ordered, func(a, b result) int { + // return a.index - b.index + // }) + // for _, r := range ordered { + // if r.valid { + // validServers = append(validServers, r.server) + // } else { + // ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing") + // } + // } + // return validServers + // } - // Test and update public servers - validPublicNss := testServers(publicNss) - if len(validPublicNss) == 0 { - validPublicNss = []string{controldPublicDnsWithPort} - } - or.publicServers.Store(&validPublicNss) - }() + // // Test and update LAN servers + // if validLanNss := testServers(lanNss); len(validLanNss) > 0 { + // or.lanServers.Store(&validLanNss) + // } + + // // Test and update public servers + // validPublicNss := testServers(publicNss) + // if len(validPublicNss) == 0 { + // validPublicNss = []string{controldPublicDnsWithPort} + // } + // or.publicServers.Store(&validPublicNss) + // }() return slices.Concat(lanNss, publicNss) } -// testNameserverFn sends a test query to DNS nameserver to check if the server is available. -var testNameServerFn = testNameserver +// // testNameserverFn sends a test query to DNS nameserver to check if the server is available. +// var testNameServerFn = testNameserver -// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. -func testNameserver(addr string) bool { - // Skip link-local addresses without scope IDs and deprecated site-local addresses - if ip, err := netip.ParseAddr(addr); err == nil { - if ip.Is6() { - if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") { - ProxyLogger.Load().Debug(). - Str("nameserver", addr). - Msg("skipping link-local IPv6 address without scope ID") - return false - } - // Skip deprecated site-local addresses (fec0::/10) - if strings.HasPrefix(ip.String(), "fec0:") { - ProxyLogger.Load().Debug(). - Str("nameserver", addr). - Msg("skipping deprecated site-local IPv6 address") - return false - } - } - } +// // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. +// func testNameserver(addr string) bool { +// // Skip link-local addresses without scope IDs and deprecated site-local addresses +// if ip, err := netip.ParseAddr(addr); err == nil { +// if ip.Is6() { +// if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") { +// ProxyLogger.Load().Debug(). +// Str("nameserver", addr). +// Msg("skipping link-local IPv6 address without scope ID") +// return false +// } +// // Skip deprecated site-local addresses (fec0::/10) +// if strings.HasPrefix(ip.String(), "fec0:") { +// ProxyLogger.Load().Debug(). +// Str("nameserver", addr). +// Msg("skipping deprecated site-local IPv6 address") +// return false +// } +// } +// } - ProxyLogger.Load().Debug(). - Str("input_addr", addr). - Msg("testing nameserver") +// ProxyLogger.Load().Debug(). +// Str("input_addr", addr). +// Msg("testing nameserver") - // Handle both IPv4 and IPv6 addresses - serverAddr := addr - host, port, err := net.SplitHostPort(addr) - if err != nil { - // No port in address, add default port 53 - serverAddr = net.JoinHostPort(addr, "53") - } else if port == "" { - // Has split markers but empty port - serverAddr = net.JoinHostPort(host, "53") - } +// // Handle both IPv4 and IPv6 addresses +// serverAddr := addr +// host, port, err := net.SplitHostPort(addr) +// if err != nil { +// // No port in address, add default port 53 +// serverAddr = net.JoinHostPort(addr, "53") +// } else if port == "" { +// // Has split markers but empty port +// serverAddr = net.JoinHostPort(host, "53") +// } - ProxyLogger.Load().Debug(). - Str("server_addr", serverAddr). - Msg("using server address") +// ProxyLogger.Load().Debug(). +// Str("server_addr", serverAddr). +// Msg("using server address") - // Test domains that are likely to exist and respond quickly - testDomains := []struct { - name string - qtype uint16 - }{ - {".", dns.TypeNS}, // Root NS query - should always work - {"controld.com.", dns.TypeA}, // Fallback to a reliable domain - } +// // Test domains that are likely to exist and respond quickly +// testDomains := []struct { +// name string +// qtype uint16 +// }{ +// {".", dns.TypeNS}, // Root NS query - should always work +// {"controld.com.", dns.TypeA}, // Fallback to a reliable domain +// } - client := &dns.Client{ - Timeout: 2 * time.Second, - Net: "udp", - } +// client := &dns.Client{ +// Timeout: 2 * time.Second, +// Net: "udp", +// } - // Try each test query until one succeeds - for _, test := range testDomains { - msg := new(dns.Msg) - msg.SetQuestion(test.name, test.qtype) - msg.RecursionDesired = true +// // Try each test query until one succeeds +// for _, test := range testDomains { +// msg := new(dns.Msg) +// msg.SetQuestion(test.name, test.qtype) +// msg.RecursionDesired = true - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - resp, _, err := client.ExchangeContext(ctx, msg, serverAddr) - cancel() +// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) +// resp, _, err := client.ExchangeContext(ctx, msg, serverAddr) +// cancel() - if err == nil && resp != nil { - return true - } +// if err == nil && resp != nil { +// return true +// } - ProxyLogger.Load().Error(). - Err(err). - Str("nameserver", serverAddr). - Str("test_domain", test.name). - Str("query_type", dns.TypeToString[test.qtype]). - Msg("DNS availability test failed") - } +// ProxyLogger.Load().Error(). +// Err(err). +// Str("nameserver", serverAddr). +// Str("test_domain", test.name). +// Str("query_type", dns.TypeToString[test.qtype]). +// Msg("DNS availability test failed") +// } - return false -} +// return false +// } // Resolver is the interface that wraps the basic DNS operations. // From e573a490c98bc4a6e8ecb0b0d7741bc1192bb784 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 29 Jan 2025 14:09:53 -0500 Subject: [PATCH 053/100] ignore non physical ifaces in validInterfaces method on Windows debugging skip type 24 in nameserver detection skip type 24 in nameserver detection remove interface type check from valid interfaces for now skip non hardware interfaces in DNS nameserver lookup ignore win api log output set retries to 5 and 1s backoff reset DNS when upgrading to make sure we get the proper OS nameservers on start init running iface for upgrade update windows service options for auto restarts on failure make upgrade use the actual stop and start commands fix the windows service retry logic fix the windows service retry logic task debugging more task debugging windows service name fix windows service name fix fix start command args fix restart delay dont recover from non crash failures fix upgrade flow --- cmd/cli/cli.go | 20 ++-- cmd/cli/commands.go | 70 +++++++------ cmd/cli/dns_proxy.go | 16 +-- cmd/cli/net_windows.go | 22 +++- cmd/cli/prog.go | 4 +- cmd/cli/service.go | 7 +- cmd/cli/service_others.go | 2 + cmd/cli/service_windows.go | 50 ++++++++++ go.mod | 2 +- go.sum | 2 + nameservers_windows.go | 200 ++++++++++++++++++++++++++++--------- 11 files changed, 296 insertions(+), 99 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 7565517..223f14e 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -126,7 +126,7 @@ func initCLI() { rootCmd.CompletionOptions.HiddenDefaultCmd = true initRunCmd() - startCmd := initStartCmd() + startCmd, startCmdAlias := initStartCmd() stopCmd := initStopCmd() restartCmd := initRestartCmd() reloadCmd := initReloadCmd(restartCmd) @@ -135,7 +135,7 @@ func initCLI() { interfacesCmd := initInterfacesCmd() initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd) initClientsCmd() - initUpgradeCmd() + initUpgradeCmd(startCmdAlias) initLogCmd() } @@ -243,6 +243,10 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if err := s.Run(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to start service") } + // Configure Windows service failure actions + if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { + mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) + } }() } writeDefaultConfig := !noConfigStart && configBase64 == "" @@ -1016,8 +1020,8 @@ func uninstall(p *prog, s service.Service) { return } tasks := []task{ - {s.Stop, false}, - {s.Uninstall, true}, + {s.Stop, false, "Stop"}, + {s.Uninstall, true, "Uninstall"}, } initLogging() if doTasks(tasks) { @@ -1688,6 +1692,10 @@ func runInCdMode() bool { // curCdUID returns the current ControlD UID used by running ctrld process. func curCdUID() string { if s, _ := newService(&prog{}, svcConfig); s != nil { + // Configure Windows service failure actions + if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { + mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) + } if dir, _ := socketDir(); dir != "" { cc := newSocketControlClient(context.TODO(), s, dir) if cc != nil { @@ -1791,7 +1799,7 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceRe } iface = oldIface return nil - }, false} + }, false, "Reset DNS"} } // doValidateCdRemoteConfig fetches and validates custom config for cdUID. @@ -1840,7 +1848,7 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool { p.resetDNS() - tasks := []task{{s.Uninstall, true}} + tasks := []task{{s.Uninstall, true, "Uninstall"}} if doTasks(tasks) { logger.Info().Msg("uninstalled service") if doStop { diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 8713ea5..f3555e5 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command { return runCmd } -func initStartCmd() *cobra.Command { +func initStartCmd() (*cobra.Command, *cobra.Command) { startCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { checkHasElevatedPrivilege() @@ -310,7 +310,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c initLogging() tasks := []task{ - {s.Stop, false}, + {s.Stop, false, "Stop"}, resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. @@ -321,9 +321,12 @@ NOTE: running "ctrld start" without any arguments will start already installed c return nil }) return nil - }, false}, - {s.Start, true}, - {noticeWritingControlDConfig, false}, + }, false, "Save current DNS"}, + {func() error { + return ConfigureWindowsServiceFailureActions(ctrldServiceName) + }, false, "Configure Windows service failure actions"}, + {s.Start, true, "Start"}, + {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, } mainLog.Load().Notice().Msg("Starting existing ctrld service") if doTasks(tasks) { @@ -387,9 +390,9 @@ NOTE: running "ctrld start" without any arguments will start already installed c } tasks := []task{ - {s.Stop, false}, - {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, - {func() error { return ensureUninstall(s) }, false}, + {s.Stop, false, "Stop"}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"}, + {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { // Save current DNS so we can restore later. @@ -400,12 +403,15 @@ NOTE: running "ctrld start" without any arguments will start already installed c return nil }) return nil - }, false}, - {s.Install, false}, - {s.Start, true}, + }, false, "Save current DNS"}, + {s.Install, false, "Install"}, + {func() error { + return ConfigureWindowsServiceFailureActions(ctrldServiceName) + }, false, "Configure Windows service failure actions"}, + {s.Start, true, "Start"}, // Note that startCmd do not actually write ControlD config, but the config file was // generated after s.Start, so we notice users here for consistent with nextdns mode. - {noticeWritingControlDConfig, false}, + {noticeWritingControlDConfig, false, "Notice writing ControlD config"}, } mainLog.Load().Notice().Msg("Starting service") if doTasks(tasks) { @@ -528,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) rootCmd.AddCommand(startCmdAlias) - return startCmd + return startCmd, startCmdAlias } func initStopCmd() *cobra.Command { @@ -558,7 +564,7 @@ func initStopCmd() *cobra.Command { if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } - if doTasks([]task{{s.Stop, true}}) { + if doTasks([]task{{s.Stop, true, "Stop"}}) { p.router.Cleanup() p.resetDNS() @@ -651,8 +657,8 @@ func initRestartCmd() *cobra.Command { iface = ir.Name } tasks := []task{ - {s.Stop, false}, - {s.Start, true}, + {s.Stop, false, "Stop"}, + {s.Start, true, "Start"}, } if doTasks(tasks) { dir, err := socketDir() @@ -1043,7 +1049,7 @@ func initClientsCmd() *cobra.Command { return clientsCmd } -func initUpgradeCmd() *cobra.Command { +func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command { const ( upgradeChannelDev = "dev" upgradeChannelProd = "prod" @@ -1115,23 +1121,23 @@ func initUpgradeCmd() *cobra.Command { mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") } + // we run the actual commands to make sure all the logic we want is executed doRestart := func() bool { - if !svcInstalled { - return true + + // run the start command so that we reinit the service + // this is to fix the non restarting options on windows for existing clients + // we have to reset os.Args, since other commands use it. + curCdUID := curCdUID() + startArgs := []string{} + os.Args = []string{"ctrld", "start"} + if curCdUID != "" { + startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID)) + os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID)) } - tasks := []task{ - {s.Stop, false}, - {s.Start, false}, - } - if doTasks(tasks) { - if dir, err := socketDir(); err == nil { - if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { - _, _ = cc.post(ifacePath, nil) - return true - } - } - } - return false + startCmd.Run(startCmd, startArgs) + + return true + } if svcInstalled { mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 0d67e88..ac808db 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1339,9 +1339,9 @@ func parseInterfaceState(state *netmon.State) map[string]string { } result := make(map[string]string) - + stateStr := state.String() - + // Extract interface information ifsStart := strings.Index(stateStr, "ifs={") if ifsStart == -1 { @@ -1356,26 +1356,26 @@ func parseInterfaceState(state *netmon.State) map[string]string { // Get the content between ifs={ } ifsContent := strings.TrimSpace(ifsStr[:ifsEnd]) - + // Split on "] " to get each interface entry entries := strings.Split(ifsContent, "] ") - + for _, entry := range entries { if entry == "" { continue } - + // Split on ":[" parts := strings.Split(entry, ":[") if len(parts) != 2 { continue } - + name := strings.TrimSpace(parts[0]) state := "[" + strings.TrimSuffix(parts[1], "]") + "]" - + result[strings.ToLower(name)] = state } return result -} \ No newline at end of file +} diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index fe075a3..6290a1c 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -52,23 +52,39 @@ func validInterfaces() []string { mainLog.Load().Warn().Err(err).Msg("failed to get network adapter") continue } + + name, err := adapter.GetPropertyName() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to get interface name") + continue + } + // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) // // "Indicates if a connector is present on the network adapter. This value is set to TRUE // if this is a physical adapter or FALSE if this is not a physical adapter." physical, err := adapter.GetPropertyConnectorPresent() if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get network adapter connector present property") + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property") continue } if !physical { + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter") continue } - name, err := adapter.GetPropertyName() + + // Check if it's a hardware interface. Checking only for connector present is not enough + // because some interfaces are not physical but have a connector. + hardware, err := adapter.GetPropertyHardwareInterface() if err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to get interface name") + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property") continue } + if !hardware { + mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface") + continue + } + adapters = append(adapters, name) } return adapters diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 331f42a..f8147eb 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -45,6 +45,7 @@ const ( upstreamOS = upstreamPrefix + "os" upstreamPrivate = upstreamPrefix + "private" dnsWatchdogDefaultInterval = 20 * time.Second + ctrldServiceName = "ctrld" ) // ControlSocketName returns name for control unix socket. @@ -61,8 +62,9 @@ var logf = func(format string, args ...any) { } var svcConfig = &service.Config{ - Name: "ctrld", + Name: ctrldServiceName, DisplayName: "Control-D Helper Service", + Description: "A highly configurable, multi-protocol DNS forwarding proxy", Option: service.KeyValue{}, } diff --git a/cmd/cli/service.go b/cmd/cli/service.go index e4edfaf..82f144c 100644 --- a/cmd/cli/service.go +++ b/cmd/cli/service.go @@ -156,17 +156,18 @@ func (l *launchd) Status() (service.Status, error) { type task struct { f func() error abortOnError bool + Name string } func doTasks(tasks []task) bool { - var prevErr error for _, task := range tasks { + mainLog.Load().Debug().Msgf("Running task %s", task.Name) if err := task.f(); err != nil { if task.abortOnError { - mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error()) + mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err) return false } - prevErr = err + mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err) } } return true diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index 2303e30..056903c 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -16,3 +16,5 @@ func openLogFile(path string, flags int) (*os.File, error) { // hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running. func hasLocalDnsServerRunning() bool { return false } + +func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil } diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index af4f317..4d3d281 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -2,11 +2,14 @@ package cli import ( "os" + "runtime" "strings" "syscall" + "time" "unsafe" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/mgr" ) func hasElevatedPrivilege() (bool, error) { @@ -30,6 +33,53 @@ func hasElevatedPrivilege() (bool, error) { return token.IsMember(sid) } +// ConfigureWindowsServiceFailureActions checks if the given service +// has the correct failure actions configured, and updates them if not. +func ConfigureWindowsServiceFailureActions(serviceName string) error { + if runtime.GOOS != "windows" { + return nil // no-op on non-Windows + } + + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err != nil { + return err + } + defer s.Close() + + // restart 3 times with a delay of 2 seconds + actions := []mgr.RecoveryAction{ + {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds + } + + // Set the recovery actions (3 restarts, reset period = 120). + err = s.SetRecoveryActions(actions, 120) + if err != nil { + return err + } + + // Ensure that failure actions are NOT triggered on user-initiated stops. + var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG + failureActionsFlag.FailureActionsOnNonCrashFailures = 0 + + if err := windows.ChangeServiceConfig2( + s.Handle, + windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG, + (*byte)(unsafe.Pointer(&failureActionsFlag)), + ); err != nil { + return err + } + + return nil +} + func openLogFile(path string, mode int) (*os.File, error) { if len(path) == 0 { return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND} diff --git a/go.mod b/go.mod index 8e9a8f7..e570bae 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/vishvananda/netlink v1.2.1-beta.2 golang.org/x/net v0.33.0 golang.org/x/sync v0.10.0 - golang.org/x/sys v0.28.0 + golang.org/x/sys v0.29.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.74.0 ) diff --git a/go.sum b/go.sum index fcf2ac7..f2d5ff9 100644 --- a/go.sum +++ b/go.sum @@ -494,6 +494,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/nameservers_windows.go b/nameservers_windows.go index a8c5191..c71e065 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -3,35 +3,41 @@ package ctrld import ( "context" "fmt" + "io" + "log" "net" + "os" "strings" "syscall" "time" "unsafe" - "io" - "os" + "github.com/StackExchange/wmi" + "github.com/microsoft/wmi/pkg/base/host" + "github.com/microsoft/wmi/pkg/base/instance" + "github.com/microsoft/wmi/pkg/base/query" + "github.com/microsoft/wmi/pkg/constant" + "github.com/microsoft/wmi/pkg/hardware/network/netadapter" "github.com/rs/zerolog" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "github.com/StackExchange/wmi" ) const ( - maxRetries = 3 - retryDelay = 500 * time.Millisecond - defaultTimeout = 5 * time.Second - minDNSServers = 1 // Minimum number of DNS servers we want to find - NetSetupUnknown uint32 = 0 - NetSetupWorkgroup uint32 = 1 - NetSetupDomain uint32 = 2 - NetSetupCloudDomain uint32 = 3 - DS_FORCE_REDISCOVERY = 0x00000001 - DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 - DS_BACKGROUND_ONLY = 0x00000100 - DS_IP_REQUIRED = 0x00000200 - DS_IS_DNS_NAME = 0x00020000 - DS_RETURN_DNS_NAME = 0x40000000 + maxRetries = 5 + retryDelay = 1 * time.Second + defaultTimeout = 5 * time.Second + minDNSServers = 1 // Minimum number of DNS servers we want to find + NetSetupUnknown uint32 = 0 + NetSetupWorkgroup uint32 = 1 + NetSetupDomain uint32 = 2 + NetSetupCloudDomain uint32 = 3 + DS_FORCE_REDISCOVERY = 0x00000001 + DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010 + DS_BACKGROUND_ONLY = 0x00000100 + DS_IP_REQUIRED = 0x00000200 + DS_IS_DNS_NAME = 0x00020000 + DS_RETURN_DNS_NAME = 0x40000000 ) type DomainControllerInfo struct { @@ -132,7 +138,7 @@ func getDNSServers(ctx context.Context) ([]string, error) { if err != nil { Log(context.Background(), logger.Debug(), "Failed to get local AD domain: %v", err) - + } else { // Load netapi32.dll @@ -141,9 +147,9 @@ func getDNSServers(ctx context.Context) ([]string, error) { var info *DomainControllerInfo - flags := uint32(DS_RETURN_DNS_NAME | - DS_IP_REQUIRED | - DS_IS_DNS_NAME) + flags := uint32(DS_RETURN_DNS_NAME | + DS_IP_REQUIRED | + DS_IS_DNS_NAME) // Convert domain name to UTF16 pointer domainUTF16, err := windows.UTF16PtrFromString(domainName) @@ -221,6 +227,14 @@ func getDNSServers(ctx context.Context) ([]string, error) { continue } + // Skip if software loopback or other non-physical types + // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows + if aa.IfType == winipcfg.IfTypeSoftwareLoopback { + Log(context.Background(), logger.Debug(), + "Skipping %s (software loopback)", aa.FriendlyName()) + continue + } + Log(context.Background(), logger.Debug(), "Processing adapter %s", aa.FriendlyName()) @@ -232,12 +246,29 @@ func getDNSServers(ctx context.Context) ([]string, error) { } } + validInterfacesMap := validInterfaces() + // Collect DNS servers for _, aa := range aas { if aa.OperStatus != winipcfg.IfOperStatusUp { continue } + // Skip if software loopback or other non-physical types + // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows + if aa.IfType == winipcfg.IfTypeSoftwareLoopback { + Log(context.Background(), logger.Debug(), + "Skipping %s (software loopback)", aa.FriendlyName()) + continue + } + + // if not in the validInterfacesMap, skip + if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok { + Log(context.Background(), logger.Debug(), + "Skipping %s (not in validInterfacesMap)", aa.FriendlyName()) + continue + } + for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next { ip := dns.Address.IP() if ip == nil { @@ -322,8 +353,8 @@ func checkDomainJoined() bool { // Consider both traditional and cloud domains as valid domain joins isDomain := status == NetSetupDomain || status == NetSetupCloudDomain Log(context.Background(), logger.Debug(), - "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", - status, + "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", + status, status == NetSetupDomain, status == NetSetupCloudDomain, isDomain) @@ -333,32 +364,111 @@ func checkDomainJoined() bool { // Win32_ComputerSystem is the minimal struct for WMI query type Win32_ComputerSystem struct { - Domain string + Domain string } // getLocalADDomain tries to detect the AD domain in two ways: -// 1) USERDNSDOMAIN env var (often set in AD logon sessions) -// 2) WMI Win32_ComputerSystem.Domain +// 1. USERDNSDOMAIN env var (often set in AD logon sessions) +// 2. WMI Win32_ComputerSystem.Domain func getLocalADDomain() (string, error) { - // 1) Check environment variable - envDomain := os.Getenv("USERDNSDOMAIN") - if envDomain != "" { - return strings.TrimSpace(envDomain), nil - } + // 1) Check environment variable + envDomain := os.Getenv("USERDNSDOMAIN") + if envDomain != "" { + return strings.TrimSpace(envDomain), nil + } - // 2) Check WMI (requires Windows + admin privileges or sufficient access) - var result []Win32_ComputerSystem - err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result) - if err != nil { - return "", fmt.Errorf("WMI query failed: %v", err) - } - if len(result) == 0 { - return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") - } + // 2) Check WMI (requires Windows + admin privileges or sufficient access) + var result []Win32_ComputerSystem + err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result) + if err != nil { + return "", fmt.Errorf("WMI query failed: %v", err) + } + if len(result) == 0 { + return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") + } + + domain := strings.TrimSpace(result[0].Domain) + if domain == "" { + return "", fmt.Errorf("machine does not appear to have a domain set") + } + return domain, nil +} + +// validInterfaces returns a list of all physical interfaces. +// this is a duplicate of what is in net_windows.go, we should +// clean this up so there is only one version +func validInterfaces() map[string]struct{} { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + //load the logger + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("MSFT_NetAdapter") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get wmi network adapter: %v", err) + return nil + } + defer instances.Close() + var adapters []string + for _, i := range instances { + adapter, err := netadapter.NewNetworkAdapter(i) + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get network adapter: %v", err) + continue + } + + name, err := adapter.GetPropertyName() + if err != nil { + Log(context.Background(), logger.Warn(), + "failed to get interface name: %v", err) + continue + } + + // From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85) + // + // "Indicates if a connector is present on the network adapter. This value is set to TRUE + // if this is a physical adapter or FALSE if this is not a physical adapter." + physical, err := adapter.GetPropertyConnectorPresent() + if err != nil { + Log(context.Background(), logger.Debug(), + "failed to get network adapter connector present property: %v", err) + continue + } + if !physical { + Log(context.Background(), logger.Debug(), + "skipping non-physical adapter: %s", name) + continue + } + + // Check if it's a hardware interface. Checking only for connector present is not enough + // because some interfaces are not physical but have a connector. + hardware, err := adapter.GetPropertyHardwareInterface() + if err != nil { + Log(context.Background(), logger.Debug(), + "failed to get network adapter hardware interface property: %v", err) + continue + } + if !hardware { + Log(context.Background(), logger.Debug(), + "skipping non-hardware interface: %s", name) + continue + } + + adapters = append(adapters, name) + } + + m := make(map[string]struct{}) + for _, ifaceName := range adapters { + m[ifaceName] = struct{}{} + } + return m - domain := strings.TrimSpace(result[0].Domain) - if domain == "" { - return "", fmt.Errorf("machine does not appear to have a domain set") - } - return domain, nil } From f7a6dbe39b10a124efea4931be00f3cbf91693a2 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 30 Jan 2025 05:09:51 -0500 Subject: [PATCH 054/100] fix upgrade flow set service on new run, fix duplicate args set service on new run, fix duplicate args revert startCmd in upgrade flow due to pin compat issues make restart reset DNS like upgrade, add debugging to uninstall method debugging debugging debugging debugging debugging WMI remove stackexchange lib, use ms wmi pkg debugging debugging set correct class fix os reolver init issues fix netadapter class use os resolver instead of fetching default nameservers while already running remove debug lines fix lookup IP fix lookup IP fix lookup IP fix lookup IP fix dns namserver retries when not needed --- cmd/cli/ad_windows.go | 4 +- cmd/cli/cli.go | 20 ++-- cmd/cli/commands.go | 138 ++++++++++++++++++++++----- cmd/cli/net_windows.go | 4 +- cmd/cli/prog.go | 1 + cmd/cli/service_windows.go | 16 +++- go.mod | 1 - go.sum | 6 -- nameservers_windows.go | 103 ++++++++++---------- resolver.go | 188 ++++--------------------------------- 10 files changed, 221 insertions(+), 260 deletions(-) diff --git a/cmd/cli/ad_windows.go b/cmd/cli/ad_windows.go index 3f9fa17..66180a9 100644 --- a/cmd/cli/ad_windows.go +++ b/cmd/cli/ad_windows.go @@ -56,10 +56,12 @@ func getActiveDirectoryDomain() (string, error) { defer log.SetOutput(os.Stderr) whost := host.NewWmiLocalHost() cs, err := hh.GetComputerSystem(whost) + if cs != nil { + defer cs.Close() + } if err != nil { return "", err } - defer cs.Close() pod, err := cs.GetPropertyPartOfDomain() if err != nil { return "", err diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 223f14e..af5bb75 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -126,7 +126,7 @@ func initCLI() { rootCmd.CompletionOptions.HiddenDefaultCmd = true initRunCmd() - startCmd, startCmdAlias := initStartCmd() + startCmd := initStartCmd() stopCmd := initStopCmd() restartCmd := initRestartCmd() reloadCmd := initReloadCmd(restartCmd) @@ -135,7 +135,7 @@ func initCLI() { interfacesCmd := initInterfacesCmd() initServicesCmd(startCmd, stopCmd, restartCmd, reloadCmd, statusCmd, uninstallCmd, interfacesCmd) initClientsCmd() - initUpgradeCmd(startCmdAlias) + initUpgradeCmd() initLogCmd() } @@ -243,10 +243,6 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if err := s.Run(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to start service") } - // Configure Windows service failure actions - if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { - mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) - } }() } writeDefaultConfig := !noConfigStart && configBase64 == "" @@ -394,6 +390,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } } } + // Configure Windows service failure actions + _ = ConfigureWindowsServiceFailureActions(ctrldServiceName) }) p.onStopped = append(p.onStopped, func() { for _, lc := range p.cfg.Listener { @@ -1615,22 +1613,27 @@ var errRequiredDeactivationPin = errors.New("deactivation pin is required to sto // checkDeactivationPin validates if the deactivation pin matches one in ControlD config. func checkDeactivationPin(s service.Service, stopCh chan struct{}) error { + mainLog.Load().Debug().Msg("Checking deactivation pin") dir, err := socketDir() if err != nil { mainLog.Load().Err(err).Msg("could not check deactivation pin") return err } + mainLog.Load().Debug().Msg("Creating control client") var cc *controlClient if s == nil { cc = newSocketControlClientMobile(dir, stopCh) } else { cc = newSocketControlClient(context.TODO(), s, dir) } + mainLog.Load().Debug().Msg("Control client done") if cc == nil { return nil // ctrld is not running. } data, _ := json.Marshal(&deactivationRequest{Pin: deactivationPin}) - resp, _ := cc.post(deactivationPath, bytes.NewReader(data)) + mainLog.Load().Debug().Msg("Posting deactivation request") + resp, err := cc.post(deactivationPath, bytes.NewReader(data)) + mainLog.Load().Debug().Msg("Posting deactivation request done") if resp != nil { switch resp.StatusCode { case http.StatusBadRequest: @@ -1694,7 +1697,7 @@ func curCdUID() string { if s, _ := newService(&prog{}, svcConfig); s != nil { // Configure Windows service failure actions if err := ConfigureWindowsServiceFailureActions(ctrldServiceName); err != nil { - mainLog.Load().Error().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) + mainLog.Load().Debug().Err(err).Msgf("failed to configure Windows service %s failure actions", ctrldServiceName) } if dir, _ := socketDir(); dir != "" { cc := newSocketControlClient(context.TODO(), s, dir) @@ -1777,6 +1780,7 @@ func resetDnsNoLog(p *prog) { func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task { return task{func() error { if iface == "" { + mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask") return nil } // Always reset DNS first, ensuring DNS setting is in a good state. diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index f3555e5..70d4467 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -164,7 +164,7 @@ func initRunCmd() *cobra.Command { return runCmd } -func initStartCmd() (*cobra.Command, *cobra.Command) { +func initStartCmd() *cobra.Command { startCmd := &cobra.Command{ PreRun: func(cmd *cobra.Command, args []string) { checkHasElevatedPrivilege() @@ -391,7 +391,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c tasks := []task{ {s.Stop, false, "Stop"}, - {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Generate NextDNS config"}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"}, {func() error { return ensureUninstall(s) }, false, "Ensure uninstall"}, resetDnsTask(p, s, isCtrldInstalled, currentIface), {func() error { @@ -534,7 +534,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) rootCmd.AddCommand(startCmdAlias) - return startCmd, startCmdAlias + return startCmd } func initStopCmd() *cobra.Command { @@ -647,6 +647,15 @@ func initRestartCmd() *cobra.Command { mainLog.Load().Warn().Msg("service not installed") return } + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } + initLogging() if cdMode { @@ -656,11 +665,53 @@ func initRestartCmd() *cobra.Command { if ir := runningIface(s); ir != nil { iface = ir.Name } - tasks := []task{ - {s.Stop, false, "Stop"}, - {s.Start, true, "Start"}, + + doRestart := func() bool { + tasks := []task{ + {s.Stop, true, "Stop"}, + {func() error { + p.router.Cleanup() + p.resetDNS() + return nil + }, false, "Cleanup"}, + {func() error { + time.Sleep(time.Second * 1) + return nil + }, false, "Waiting for service to stop"}, + } + if doTasks(tasks) { + + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + loop: + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + break loop + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } + } else { + return false + } + + tasks = []task{ + {s.Start, true, "Start"}, + } + + return doTasks(tasks) + } - if doTasks(tasks) { + + if doRestart() { dir, err := socketDir() if err != nil { mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") @@ -668,11 +719,13 @@ func initRestartCmd() *cobra.Command { } cc := newSocketControlClient(context.TODO(), s, dir) if cc == nil { - mainLog.Load().Notice().Msg("Service was not restarted") + mainLog.Load().Error().Msg("Could not complete service restart") os.Exit(1) } _, _ = cc.post(ifacePath, nil) mainLog.Load().Notice().Msg("Service restarted") + } else { + mainLog.Load().Error().Msg("Service restart failed") } }, } @@ -1049,7 +1102,7 @@ func initClientsCmd() *cobra.Command { return clientsCmd } -func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command { +func initUpgradeCmd() *cobra.Command { const ( upgradeChannelDev = "dev" upgradeChannelProd = "prod" @@ -1087,6 +1140,14 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command { mainLog.Load().Error().Msg(err.Error()) return } + if iface == "" { + iface = "auto" + } + p.preRun() + if ir := runningIface(s); ir != nil { + p.runningIface = ir.Name + p.requiredMultiNICsConfig = ir.All + } svcInstalled := true if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { @@ -1121,23 +1182,56 @@ func initUpgradeCmd(startCmd *cobra.Command) *cobra.Command { mainLog.Load().Fatal().Err(err).Msg("failed to update current binary") } - // we run the actual commands to make sure all the logic we want is executed doRestart := func() bool { - - // run the start command so that we reinit the service - // this is to fix the non restarting options on windows for existing clients - // we have to reset os.Args, since other commands use it. - curCdUID := curCdUID() - startArgs := []string{} - os.Args = []string{"ctrld", "start"} - if curCdUID != "" { - startArgs = append(startArgs, fmt.Sprintf("--cd=%s", curCdUID)) - os.Args = append(os.Args, fmt.Sprintf("--cd=%s", curCdUID)) + if !svcInstalled { + return true } - startCmd.Run(startCmd, startArgs) + tasks := []task{ + {s.Stop, true, "Stop"}, + {func() error { + p.router.Cleanup() + p.resetDNS() + return nil + }, false, "Cleanup"}, + {func() error { + time.Sleep(time.Second * 1) + return nil + }, false, "Waiting for service to stop"}, + } + if doTasks(tasks) { - return true + if router.WaitProcessExited() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + loop: + for { + select { + case <-ctx.Done(): + mainLog.Load().Error().Msg("timeout while waiting for service to stop") + break loop + default: + } + time.Sleep(time.Second) + if status, _ := s.Status(); status == service.StatusStopped { + break + } + } + } + } + + tasks = []task{ + {s.Start, true, "Start"}, + } + if doTasks(tasks) { + if dir, err := socketDir(); err == nil { + if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil { + _, _ = cc.post(ifacePath, nil) + return true + } + } + } + return false } if svcInstalled { mainLog.Load().Debug().Msg("Restarting ctrld service using new binary") diff --git a/cmd/cli/net_windows.go b/cmd/cli/net_windows.go index 6290a1c..bed06b5 100644 --- a/cmd/cli/net_windows.go +++ b/cmd/cli/net_windows.go @@ -40,11 +40,13 @@ func validInterfaces() []string { whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if instances != nil { + defer instances.Close() + } if err != nil { mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter") return nil } - defer instances.Close() var adapters []string for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index f8147eb..8390680 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -792,6 +792,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces func (p *prog) resetDNS() { if p.runningIface == "" { + mainLog.Load().Debug().Msg("no running interface, skipping resetDNS") return } // See corresponding comments in (*prog).setDNS function. diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index 4d3d281..6e3bd82 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -52,7 +52,21 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error { } defer s.Close() - // restart 3 times with a delay of 2 seconds + // 1. Retrieve the current config + cfg, err := s.Config() + if err != nil { + return err + } + + // 2. Update the Description + cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy" + + // 3. Apply the updated config + if err := s.UpdateConfig(cfg); err != nil { + return err + } + + // Then proceed with existing actions, e.g. setting failure actions actions := []mgr.RecoveryAction{ {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds diff --git a/go.mod b/go.mod index e570bae..635261f 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,6 @@ require ( require ( aead.dev/minisign v0.2.0 // indirect - github.com/StackExchange/wmi v1.2.1 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index f2d5ff9..2ac97af 100644 --- a/go.sum +++ b/go.sum @@ -42,8 +42,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= -github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDOSA= -github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8= github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE= github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= @@ -95,7 +93,6 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= @@ -452,7 +449,6 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -492,8 +488,6 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= diff --git a/nameservers_windows.go b/nameservers_windows.go index c71e065..54fb8b6 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -12,7 +12,6 @@ import ( "time" "unsafe" - "github.com/StackExchange/wmi" "github.com/microsoft/wmi/pkg/base/host" "github.com/microsoft/wmi/pkg/base/instance" "github.com/microsoft/wmi/pkg/base/query" @@ -24,9 +23,9 @@ import ( ) const ( - maxRetries = 5 - retryDelay = 1 * time.Second - defaultTimeout = 5 * time.Second + maxDNSAdapterRetries = 5 + retryDelayDNSAdapter = 1 * time.Second + defaultDNSAdapterTimeout = 10 * time.Second minDNSServers = 1 // Minimum number of DNS servers we want to find NetSetupUnknown uint32 = 0 NetSetupWorkgroup uint32 = 1 @@ -57,19 +56,18 @@ func dnsFns() []dnsFn { } func dnsFromAdapter() []string { - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout) defer cancel() var ns []string var err error - //load the logger logger := zerolog.New(io.Discard) if ProxyLogger.Load() != nil { logger = *ProxyLogger.Load() } - for i := 0; i < maxRetries; i++ { + for i := 0; i < maxDNSAdapterRetries; i++ { if ctx.Err() != nil { Log(context.Background(), logger.Debug(), "dnsFromAdapter lookup cancelled or timed out, attempt %d", i) @@ -80,12 +78,18 @@ func dnsFromAdapter() []string { if err == nil && len(ns) >= minDNSServers { if i > 0 { Log(context.Background(), logger.Debug(), - "Successfully got DNS servers after %d attempts, found %d servers", i+1, len(ns)) + "Successfully got DNS servers after %d attempts, found %d servers", + i+1, len(ns)) } return ns } - // Log the specific failure reason + // if osResolver is not initialized, this is likely a command line run + // and ctrld is already on the interface, abort retries + if or == nil { + return ns + } + if err != nil { Log(context.Background(), logger.Debug(), "Failed to get DNS servers, attempt %d: %v", i+1, err) @@ -97,17 +101,16 @@ func dnsFromAdapter() []string { select { case <-ctx.Done(): return nil - case <-time.After(retryDelay): + case <-time.After(retryDelayDNSAdapter): } } Log(context.Background(), logger.Debug(), - "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxRetries) - return ns // Return whatever we got, even if insufficient + "Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries) + return ns } func getDNSServers(ctx context.Context) ([]string, error) { - //load the logger logger := zerolog.New(io.Discard) if ProxyLogger.Load() != nil { logger = *ProxyLogger.Load() @@ -133,25 +136,18 @@ func getDNSServers(ctx context.Context) ([]string, error) { var dcServers []string isDomain := checkDomainJoined() if isDomain { - domainName, err := getLocalADDomain() if err != nil { Log(context.Background(), logger.Debug(), "Failed to get local AD domain: %v", err) - } else { - // Load netapi32.dll netapi32 := windows.NewLazySystemDLL("netapi32.dll") dsDcName := netapi32.NewProc("DsGetDcNameW") var info *DomainControllerInfo + flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME) - flags := uint32(DS_RETURN_DNS_NAME | - DS_IP_REQUIRED | - DS_IS_DNS_NAME) - - // Convert domain name to UTF16 pointer domainUTF16, err := windows.UTF16PtrFromString(domainName) if err != nil { Log(context.Background(), logger.Debug(), @@ -190,15 +186,12 @@ func getDNSServers(ctx context.Context) ([]string, error) { } else if info != nil { defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info))) - // Get DC address if info.DomainControllerAddress != nil { dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress) dcAddr = strings.TrimPrefix(dcAddr, "\\\\") - Log(context.Background(), logger.Debug(), "Found domain controller address: %s", dcAddr) - // Try to resolve DC if ip := net.ParseIP(dcAddr); ip != nil { dcServers = append(dcServers, ip.String()) Log(context.Background(), logger.Debug(), @@ -210,7 +203,6 @@ func getDNSServers(ctx context.Context) ([]string, error) { } } } - } } @@ -278,28 +270,26 @@ func getDNSServers(ctx context.Context) ([]string, error) { } ipStr := ip.String() - logger := logger.Debug(). + l := logger.Debug(). Str("ip", ipStr). Str("adapter", aa.FriendlyName()) if ip.IsLoopback() { - logger.Msg("Skipping loopback IP") + l.Msg("Skipping loopback IP") continue } - if seen[ipStr] { - logger.Msg("Skipping duplicate IP") + l.Msg("Skipping duplicate IP") continue } - if _, ok := addressMap[ipStr]; ok { - logger.Msg("Skipping local interface IP") + l.Msg("Skipping local interface IP") continue } seen[ipStr] = true ns = append(ns, ipStr) - logger.Msg("Added DNS server") + l.Msg("Added DNS server") } } @@ -330,7 +320,6 @@ func nameserversFromResolvconf() []string { // checkDomainJoined checks if the machine is joined to an Active Directory domain // Returns whether it's domain joined and the domain name if available func checkDomainJoined() bool { - //load the logger logger := zerolog.New(io.Discard) if ProxyLogger.Load() != nil { logger = *ProxyLogger.Load() @@ -348,9 +337,10 @@ func checkDomainJoined() bool { domainName := windows.UTF16PtrToString(domain) Log(context.Background(), logger.Debug(), - "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", domainName, status) + "Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)", + domainName, status) - // Consider both traditional and cloud domains as valid domain joins + // Consider domain or cloud domain as domain-joined isDomain := status == NetSetupDomain || status == NetSetupCloudDomain Log(context.Background(), logger.Debug(), "Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v", @@ -362,36 +352,44 @@ func checkDomainJoined() bool { return isDomain } -// Win32_ComputerSystem is the minimal struct for WMI query -type Win32_ComputerSystem struct { - Domain string -} - -// getLocalADDomain tries to detect the AD domain in two ways: -// 1. USERDNSDOMAIN env var (often set in AD logon sessions) -// 2. WMI Win32_ComputerSystem.Domain +// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*) +// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call. func getLocalADDomain() (string, error) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) // 1) Check environment variable envDomain := os.Getenv("USERDNSDOMAIN") if envDomain != "" { return strings.TrimSpace(envDomain), nil } - // 2) Check WMI (requires Windows + admin privileges or sufficient access) - var result []Win32_ComputerSystem - err := wmi.Query("SELECT Domain FROM Win32_ComputerSystem", &result) + // 2) Query WMI via the microsoft/wmi library + whost := host.NewWmiLocalHost() + q := query.NewWmiQuery("Win32_ComputerSystem") + instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q) + if instances != nil { + defer instances.Close() + } if err != nil { return "", fmt.Errorf("WMI query failed: %v", err) } - if len(result) == 0 { + + // If no results, return an error + if len(instances) == 0 { return "", fmt.Errorf("no rows returned from Win32_ComputerSystem") } - domain := strings.TrimSpace(result[0].Domain) - if domain == "" { + // We only care about the first row + domainVal, err := instances[0].GetProperty("Domain") + if err != nil { + return "", fmt.Errorf("machine does not appear to have a domain set: %v", err) + } + + domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal)) + if domainName == "" { return "", fmt.Errorf("machine does not appear to have a domain set") } - return domain, nil + return domainName, nil } // validInterfaces returns a list of all physical interfaces. @@ -410,12 +408,14 @@ func validInterfaces() map[string]struct{} { whost := host.NewWmiLocalHost() q := query.NewWmiQuery("MSFT_NetAdapter") instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q) + if instances != nil { + defer instances.Close() + } if err != nil { Log(context.Background(), logger.Warn(), "failed to get wmi network adapter: %v", err) return nil } - defer instances.Close() var adapters []string for _, i := range instances { adapter, err := netadapter.NewNetworkAdapter(i) @@ -470,5 +470,4 @@ func validInterfaces() map[string]struct{} { m[ifaceName] = struct{}{} } return m - } diff --git a/resolver.go b/resolver.go index e82b763..49b81af 100644 --- a/resolver.go +++ b/resolver.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "net/netip" "slices" @@ -11,10 +12,9 @@ import ( "sync" "sync/atomic" "time" - "io" - "github.com/rs/zerolog" "github.com/miekg/dns" + "github.com/rs/zerolog" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" ) @@ -48,11 +48,13 @@ const ( var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") -// or is the Resolver used for ResolverTypeOS. -var or = newResolverWithNameserver(defaultNameservers()) - var localResolver = newLocalResolver() +var ( + resolverMutex sync.Mutex + or *osResolver +) + func newLocalResolver() Resolver { var nss []string for _, addr := range Rfc1918Addresses() { @@ -86,7 +88,6 @@ func availableNameservers() []string { regularIPs, loopbackIPs, _ := netmon.LocalAddresses() machineIPsMap := make(map[string]struct{}, len(regularIPs)) - //load the logger logger := zerolog.New(io.Discard) if ProxyLogger.Load() != nil { @@ -129,6 +130,9 @@ func availableNameservers() []string { // calling this function. func InitializeOsResolver() []string { ns := initializeOsResolver(availableNameservers()) + resolverMutex.Lock() + defer resolverMutex.Unlock() + or = newResolverWithNameserver(ns) return ns } @@ -138,6 +142,7 @@ func InitializeOsResolver() []string { // - First available LAN servers are saved and store. // - Later calls, if no LAN servers available, the saved servers above will be used. func initializeOsResolver(servers []string) []string { + var lanNss, publicNss []string // First categorize servers @@ -154,171 +159,13 @@ func initializeOsResolver(servers []string) []string { } } - // Store initial servers immediately - if len(lanNss) > 0 { - or.initializedLanServers.CompareAndSwap(nil, &lanNss) - or.lanServers.Store(&lanNss) - } - if len(publicNss) == 0 { publicNss = []string{controldPublicDnsWithPort} } - or.publicServers.Store(&publicNss) - - // no longer testing servers in the background - // if DCHP nameservers are not working, this is outside of our control - - // // Test servers in background and remove failures - // go func() { - // // Test servers in parallel but maintain order - // type result struct { - // index int - // server string - // valid bool - // } - - // testServers := func(servers []string) []string { - // if len(servers) == 0 { - // return nil - // } - - // results := make(chan result, len(servers)) - // var wg sync.WaitGroup - - // for i, server := range servers { - // wg.Add(1) - // go func(idx int, s string) { - // defer wg.Done() - // results <- result{ - // index: idx, - // server: s, - // valid: testNameServerFn(s), - // } - // }(i, server) - // } - - // go func() { - // wg.Wait() - // close(results) - // }() - - // // Collect results maintaining original order - // validServers := make([]string, 0, len(servers)) - // ordered := make([]result, 0, len(servers)) - // for r := range results { - // ordered = append(ordered, r) - // } - // slices.SortFunc(ordered, func(a, b result) int { - // return a.index - b.index - // }) - // for _, r := range ordered { - // if r.valid { - // validServers = append(validServers, r.server) - // } else { - // ProxyLogger.Load().Debug().Str("nameserver", r.server).Msg("nameserver failed validation testing") - // } - // } - // return validServers - // } - - // // Test and update LAN servers - // if validLanNss := testServers(lanNss); len(validLanNss) > 0 { - // or.lanServers.Store(&validLanNss) - // } - - // // Test and update public servers - // validPublicNss := testServers(publicNss) - // if len(validPublicNss) == 0 { - // validPublicNss = []string{controldPublicDnsWithPort} - // } - // or.publicServers.Store(&validPublicNss) - // }() return slices.Concat(lanNss, publicNss) } -// // testNameserverFn sends a test query to DNS nameserver to check if the server is available. -// var testNameServerFn = testNameserver - -// // testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available. -// func testNameserver(addr string) bool { -// // Skip link-local addresses without scope IDs and deprecated site-local addresses -// if ip, err := netip.ParseAddr(addr); err == nil { -// if ip.Is6() { -// if ip.IsLinkLocalUnicast() && !strings.Contains(addr, "%") { -// ProxyLogger.Load().Debug(). -// Str("nameserver", addr). -// Msg("skipping link-local IPv6 address without scope ID") -// return false -// } -// // Skip deprecated site-local addresses (fec0::/10) -// if strings.HasPrefix(ip.String(), "fec0:") { -// ProxyLogger.Load().Debug(). -// Str("nameserver", addr). -// Msg("skipping deprecated site-local IPv6 address") -// return false -// } -// } -// } - -// ProxyLogger.Load().Debug(). -// Str("input_addr", addr). -// Msg("testing nameserver") - -// // Handle both IPv4 and IPv6 addresses -// serverAddr := addr -// host, port, err := net.SplitHostPort(addr) -// if err != nil { -// // No port in address, add default port 53 -// serverAddr = net.JoinHostPort(addr, "53") -// } else if port == "" { -// // Has split markers but empty port -// serverAddr = net.JoinHostPort(host, "53") -// } - -// ProxyLogger.Load().Debug(). -// Str("server_addr", serverAddr). -// Msg("using server address") - -// // Test domains that are likely to exist and respond quickly -// testDomains := []struct { -// name string -// qtype uint16 -// }{ -// {".", dns.TypeNS}, // Root NS query - should always work -// {"controld.com.", dns.TypeA}, // Fallback to a reliable domain -// } - -// client := &dns.Client{ -// Timeout: 2 * time.Second, -// Net: "udp", -// } - -// // Try each test query until one succeeds -// for _, test := range testDomains { -// msg := new(dns.Msg) -// msg.SetQuestion(test.name, test.qtype) -// msg.RecursionDesired = true - -// ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) -// resp, _, err := client.ExchangeContext(ctx, msg, serverAddr) -// cancel() - -// if err == nil && resp != nil { -// return true -// } - -// ProxyLogger.Load().Error(). -// Err(err). -// Str("nameserver", serverAddr). -// Str("test_domain", test.name). -// Str("query_type", dns.TypeToString[test.qtype]). -// Msg("DNS availability test failed") -// } - -// return false -// } - // Resolver is the interface that wraps the basic DNS operations. // // Resolve resolves the DNS query, return the result and the corresponding error. @@ -339,6 +186,9 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeDOQ: return &doqResolver{uc: uc}, nil case ResolverTypeOS: + if or == nil { + or = newResolverWithNameserver(defaultNameservers()) + } return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil @@ -351,9 +201,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { } type osResolver struct { - initializedLanServers atomic.Pointer[[]string] - lanServers atomic.Pointer[[]string] - publicServers atomic.Pointer[[]string] + lanServers atomic.Pointer[[]string] + publicServers atomic.Pointer[[]string] } type osResolverResult struct { @@ -504,7 +353,10 @@ func LookupIP(domain string) []string { } func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) { - nss := defaultNameservers() + if or == nil { + or = newResolverWithNameserver(defaultNameservers()) + } + nss := *or.lanServers.Load() if withBootstrapDNS { nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...) } From 028475a1938c76f68728fc5b1a032c47740fff18 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 31 Jan 2025 14:50:38 -0500 Subject: [PATCH 055/100] fix os.Resolve method to prefer LAN answers fix os.Resolve method to prefer LAN answers early return for stop cmd when not installed or stopped increase service restart delay to 5s --- cmd/cli/commands.go | 11 ++++++++++ cmd/cli/prog.go | 4 +++- cmd/cli/service_windows.go | 6 +++--- resolver.go | 44 +++++++++++++++++++++++--------------- 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 70d4467..8396c19 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -561,6 +561,17 @@ func initStopCmd() *cobra.Command { } initLogging() + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is already stopped") + return + } + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 8390680..fd49764 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -578,7 +578,9 @@ func (p *prog) metricsEnabled() bool { func (p *prog) Stop(s service.Service) error { p.stopDnsWatchers() mainLog.Load().Debug().Msg("dns watchers stopped") - mainLog.Load().Info().Msg("Service stopped") + defer func() { + mainLog.Load().Info().Msg("Service stopped") + }() close(p.stopCh) if err := p.deAllocateIP(); err != nil { mainLog.Load().Error().Err(err).Msg("de-allocate ip failed") diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index 6e3bd82..c4df5a5 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -68,9 +68,9 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error { // Then proceed with existing actions, e.g. setting failure actions actions := []mgr.RecoveryAction{ - {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds - {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds - {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds } // Set the recovery actions (3 restarts, reset period = 120). diff --git a/resolver.go b/resolver.go index 49b81af..c26560b 100644 --- a/resolver.go +++ b/resolver.go @@ -8,7 +8,6 @@ import ( "net" "net/netip" "slices" - "strings" "sync" "sync/atomic" "time" @@ -212,6 +211,11 @@ type osResolverResult struct { lan bool } +type publicResponse struct { + answer *dns.Msg + server string +} + // Resolve resolves DNS queries using pre-configured nameservers. // Query is sent to all nameservers concurrently, and the first // success response will be returned. @@ -257,33 +261,37 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } logAnswer := func(server string) { - if before, _, found := strings.Cut(server, ":"); found { - server = before + host, _, err := net.SplitHostPort(server) + if err != nil { + // If splitting fails, fallback to the original server string + host = server } - Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", server) + Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host) } var ( nonSuccessAnswer *dns.Msg nonSuccessServer string controldSuccessAnswer *dns.Msg - publicServerAnswer *dns.Msg - publicServer string + publicResponses []publicResponse ) errs := make([]error, 0, numServers) for res := range ch { switch { case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: switch { - case res.server == controldPublicDnsWithPort: - controldSuccessAnswer = res.answer - case !res.lan && publicServerAnswer == nil: - publicServerAnswer = res.answer - publicServer = res.server - default: - Log(ctx, ProxyLogger.Load().Debug(), "got LAN answer from: %s", res.server) + case res.lan: + // Always prefer LAN responses immediately + Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil + case res.server == controldPublicDnsWithPort: + controldSuccessAnswer = res.answer + case !res.lan: + publicResponses = append(publicResponses, publicResponse{ + answer: res.answer, + server: res.server, + }) } case res.answer != nil: nonSuccessAnswer = res.answer @@ -293,10 +301,12 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } errs = append(errs, res.err) } - if publicServerAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", publicServer) - logAnswer(publicServer) - return publicServerAnswer, nil + + if len(publicResponses) > 0 { + resp := publicResponses[0] + Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server) + logAnswer(resp.server) + return resp.answer, nil } if controldSuccessAnswer != nil { Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort) From 1560455ca3b325eec8c8c6f6f29e3029d0ec9a81 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Sun, 2 Feb 2025 11:35:49 +0700 Subject: [PATCH 056/100] Use all available nameservers in lookupIP Some systems may be configured with public DNS only, so relying solely on LAN servers could make the lookup process failed unexpectedly. --- config_internal_test.go | 4 ++++ resolver.go | 1 + 2 files changed, 5 insertions(+) diff --git a/config_internal_test.go b/config_internal_test.go index 6823686..44b7e2f 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -2,12 +2,16 @@ package ctrld import ( "net/url" + "os" "testing" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" ) func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { + l := zerolog.New(os.Stdout) + ProxyLogger.Store(&l) uc := &UpstreamConfig{ Name: "test", Type: ResolverTypeDOH, diff --git a/resolver.go b/resolver.go index c26560b..f036967 100644 --- a/resolver.go +++ b/resolver.go @@ -367,6 +367,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) or = newResolverWithNameserver(defaultNameservers()) } nss := *or.lanServers.Load() + nss = append(nss, *or.publicServers.Load()...) if withBootstrapDNS { nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...) } From 168eaf538b50f94e5b0e0e5bc6d9fb8b9b2d2b61 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 3 Feb 2025 21:19:03 -0500 Subject: [PATCH 057/100] increase OSresolver timeout, fix debug log statements flush dns cache, manually hit captive portal on MacOS fix real ip in debug log treat all upstreams as down upon network change delay upstream checks when leaking queries on network changes --- cmd/cli/dns_proxy.go | 96 +++++++++++++++++++++++++++++++++++++++----- resolver.go | 2 +- 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index ac808db..25e3e53 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -7,7 +7,9 @@ import ( "errors" "fmt" "net" + "net/http" "net/netip" + "os/exec" "runtime" "slices" "strconv" @@ -42,7 +44,7 @@ const ( var osUpstreamConfig = &ctrld.UpstreamConfig{ Name: "OS resolver", Type: ctrld.ResolverTypeOS, - Timeout: 2000, + Timeout: 3000, } var privateUpstreamConfig = &ctrld.UpstreamConfig{ @@ -436,10 +438,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { leaked := false if len(upstreamConfigs) > 0 { p.leakingQueryMu.Lock() - if p.leakingQueryRunning[upstreamMapKey] { + if p.leakingQueryRunning[upstreamMapKey] || p.leakingQueryRunning["all"] { upstreamConfigs = nil leaked = true - ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams) + if p.leakingQueryRunning["all"] { + ctrld.Log(ctx, mainLog.Load().Debug(), "all upstreams marked down for network change, leaking query to OS resolver") + } else { + ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams) + } } p.leakingQueryMu.Unlock() } @@ -576,13 +582,13 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { Bool("is_lan_query", isLanOrPtrQuery) if p.isLoop(upstreamConfig) { - logger.Msg("DNS loop detected") + ctrld.Log(ctx, logger, "DNS loop detected") continue } if p.um.isDown(upstreams[n]) { logger. - Bool("is_os_resolver", upstreams[n] == upstreamOS). - Msg("Upstream is down") + Bool("is_os_resolver", upstreams[n] == upstreamOS) + ctrld.Log(ctx, logger, "Upstream is down") continue } answer := resolve(n, upstreamConfig, req.msg) @@ -995,10 +1001,11 @@ func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamCon // we only want to reset DNS when our resolver is broken // this allows us to find the new OS resolver nameservers - if p.um.isDown(upstreamOS) { + // we skip the all upstream lock key to prevent duplicate calls + if p.um.isDown(upstreamOS) && upstreamMapKey != "all" { mainLog.Load().Debug().Msg("OS resolver is down, reinitializing") - p.reinitializeOSResolver() + p.reinitializeOSResolver(false) } @@ -1006,6 +1013,15 @@ func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamCon ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // if a network change, delay upstream checks by 1s + // this is to ensure we actually leak queries to OS resolver + // We have observed some captive portals leak queries to public upstreams + // This can cause the captive portal on MacOS to not trigger a popup + if upstreamMapKey != "all" { + mainLog.Load().Debug().Msg("network change leaking queries, delaying upstream checks by 1s") + time.Sleep(1 * time.Second) + } + upstreamCh := make(chan string, len(failedUpstreams)) for name, uc := range failedUpstreams { go func(name string, uc *ctrld.UpstreamConfig) { @@ -1213,7 +1229,7 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M // by removing ctrld listenr from the interface, collecting the network nameservers // and re-initializing the OS resolver with the nameservers // applying listener back to the interface -func (p *prog) reinitializeOSResolver() { +func (p *prog) reinitializeOSResolver(networkChange bool) { // Cancel any existing operations p.resetCtxMu.Lock() if p.resetCancel != nil { @@ -1232,6 +1248,21 @@ func (p *prog) reinitializeOSResolver() { p.leakingQueryReset.Store(true) defer p.leakingQueryReset.Store(false) + defer func() { + // start leaking queries immediately + if networkChange { + // set all upstreams to fialed and provide to performLeakingQuery + failedUpstreams := make(map[string]*ctrld.UpstreamConfig) + for _, upstream := range p.cfg.Upstream { + failedUpstreams[upstream.Name] = upstream + } + go p.performLeakingQuery(failedUpstreams, "all") + } + if err := FlushDNSCache(); err != nil { + mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache") + } + }() + select { case <-ctx.Done(): mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") @@ -1264,6 +1295,51 @@ func (p *prog) reinitializeOSResolver() { } } +func triggerCaptiveCheck() { + // Wait for a short period to ensure DNS reinitialization is complete. + time.Sleep(2 * time.Second) + + // if not Mac OS, return + if runtime.GOOS != "darwin" { + return + } + + // Trigger a lookup for captive.apple.com. + // This can be done either via a DNS query or an HTTP GET. + // Here we use a simple HTTP GET which is what macOS CaptiveNetworkAssistant uses. + client := &http.Client{ + Timeout: 5 * time.Second, + } + resp, err := client.Get("http://captive.apple.com/generate_204") + if err != nil { + mainLog.Load().Debug().Msg("failed to trigger captive portal check") + return + } + resp.Body.Close() + mainLog.Load().Debug().Msg("triggered captive portal check by querying captive.apple.com") +} + +// FlushDNSCache flushes the DNS cache on macOS. +func FlushDNSCache() error { + // if not Mac OS, return + if runtime.GOOS != "darwin" { + return nil + } + + // Flush the DNS cache via mDNSResponder. + // This is typically needed on modern macOS systems. + if err := exec.Command("sudo", "killall", "-HUP", "mDNSResponder").Run(); err != nil { + return fmt.Errorf("failed to flush mDNSResponder: %w", err) + } + + // Optionally, flush the directory services cache. + if err := exec.Command("sudo", "dscacheutil", "-flushcache").Run(); err != nil { + return fmt.Errorf("failed to flush dscacheutil: %w", err) + } + + return nil +} + // monitorNetworkChanges starts monitoring for network interface changes func (p *prog) monitorNetworkChanges() error { mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) @@ -1321,7 +1397,7 @@ func (p *prog) monitorNetworkChanges() error { } if activeInterfaceExists { - p.reinitializeOSResolver() + p.reinitializeOSResolver(true) } else { mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") } diff --git a/resolver.go b/resolver.go index f036967..34a6cdd 100644 --- a/resolver.go +++ b/resolver.go @@ -237,7 +237,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error ctx, cancel := context.WithCancel(ctx) defer cancel() - dnsClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second} + dnsClient := &dns.Client{Net: "udp", Timeout: 3 * time.Second} ch := make(chan *osResolverResult, numServers) wg := &sync.WaitGroup{} wg.Add(numServers) From f57972ead767064228406ac7da3f634c37b25271 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 4 Feb 2025 13:27:15 +0700 Subject: [PATCH 058/100] cmd/cli: make runtime log format better By using more friendly markers to indicate the end of each log section, so it's easier to read/parse for both human and machine. --- cmd/cli/log_writer.go | 9 +++++---- cmd/cli/log_writer_test.go | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 92e0e63..1bcfe3d 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -20,8 +20,9 @@ const ( logWriterSmallSize = 1024 * 1024 * 1 // 1 MB logWriterInitialSize = 32 * 1024 // 32 KB logSentInterval = time.Minute - logTruncatedMarker = "...\n" - logSeparator = "\n===\n\n" + logStartEndMarker = "\n\n=== START_END ===\n\n" + logLogEndMarker = "\n\n=== LOG_END ===\n\n" + logWarnEndMarker = "\n\n=== WARN END ===\n\n" ) type logViewResponse struct { @@ -74,7 +75,7 @@ func (lw *logWriter) Write(p []byte) (int, error) { } lw.buf.Reset() lw.buf.Write(buf) - lw.buf.WriteString(logTruncatedMarker) // indicate that the log was truncated. + lw.buf.WriteString(logStartEndMarker) // indicate that the log was truncated. } // If p is bigger than buffer size, truncate p by half until its size is smaller. for len(p)+lw.buf.Len() > lw.size { @@ -157,7 +158,7 @@ func (p *prog) logReader() (*logReader, error) { wlwReader := bytes.NewReader(wlw.buf.Bytes()) wlwSize := wlw.buf.Len() wlw.mu.Unlock() - reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logSeparator)), wlwReader) + reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logLogEndMarker)), wlwReader) lr := &logReader{r: io.NopCloser(reader)} lr.size = int64(lwSize + wlwSize) if lr.size == 0 { diff --git a/cmd/cli/log_writer_test.go b/cmd/cli/log_writer_test.go index 92c772b..bd48785 100644 --- a/cmd/cli/log_writer_test.go +++ b/cmd/cli/log_writer_test.go @@ -16,7 +16,7 @@ func Test_logWriter_Write(t *testing.T) { t.Fatalf("unexpected buf content: %v", lw.buf.String()) } newData := "B" - halfData := strings.Repeat("A", len(data)/2) + logTruncatedMarker + halfData := strings.Repeat("A", len(data)/2) + logStartEndMarker lw.Write([]byte(newData)) if lw.buf.String() != halfData+newData { t.Fatalf("unexpected new buf content: %v", lw.buf.String()) From eb27d1482be212a20bfc1fe173d8883a1418d659 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 4 Feb 2025 13:33:34 +0700 Subject: [PATCH 059/100] cmd/cli: use warn level for network changes logging So these events will be recorded separately from normal runtime log, making troubleshooting later more easily. While at it, only update ctrld.ProxyLogger for runCmd, it's the only one which needs to log the query when proxying requests. --- cmd/cli/cli.go | 2 ++ cmd/cli/dns_proxy.go | 14 +++++++------- cmd/cli/main.go | 2 -- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index af5bb75..f59d4bb 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -267,6 +267,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // Log config do not have thing to validate, so it's safe to init log here, // so it's able to log information in processCDFlags. logWriters := initLogging() + // TODO: find a better way. + ctrld.ProxyLogger.Store(mainLog.Load()) // Initializing internal logging after global logging. p.initInternalLogging(logWriters) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 25e3e53..099bae2 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1251,7 +1251,7 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { defer func() { // start leaking queries immediately if networkChange { - // set all upstreams to fialed and provide to performLeakingQuery + // set all upstreams to failed and provide to performLeakingQuery failedUpstreams := make(map[string]*ctrld.UpstreamConfig) for _, upstream := range p.cfg.Upstream { failedUpstreams[upstream.Name] = upstream @@ -1280,7 +1280,7 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { default: mainLog.Load().Debug().Msg("initializing OS resolver") ns := ctrld.InitializeOsResolver() - mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns) + mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) } select { @@ -1352,7 +1352,7 @@ func (p *prog) monitorNetworkChanges() error { validIfaces := validInterfacesMap() // log the delta for debugging - mainLog.Load().Debug(). + mainLog.Load().Warn(). Interface("old_state", delta.Old). Interface("new_state", delta.New). Msg("Network change detected") @@ -1376,14 +1376,14 @@ func (p *prog) monitorNetworkChanges() error { // Compare states directly if oldExists != newExists || oldState != newState { changed = true - mainLog.Load().Debug(). + mainLog.Load().Warn(). Str("interface", ifaceName). Str("old_state", oldState). Str("new_state", newState). Msg("Valid interface changed state") break } else { - mainLog.Load().Debug(). + mainLog.Load().Warn(). Str("interface", ifaceName). Str("old_state", oldState). Str("new_state", newState). @@ -1392,14 +1392,14 @@ func (p *prog) monitorNetworkChanges() error { } if !changed { - mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") + mainLog.Load().Warn().Msg("Ignoring interface change - no valid interfaces affected") return } if activeInterfaceExists { p.reinitializeOSResolver(true) } else { - mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") + mainLog.Load().Warn().Msg("No active interfaces found, skipping reinitialization") } }) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 53662aa..7041318 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -143,8 +143,6 @@ func initLoggingWithBackup(doBackup bool) []io.Writer { multi := zerolog.MultiLevelWriter(writers...) l := mainLog.Load().Output(multi).With().Logger() mainLog.Store(&l) - // TODO: find a better way. - ctrld.ProxyLogger.Store(&l) zerolog.SetGlobalLevel(zerolog.NoticeLevel) logLevel := cfg.Service.LogLevel From 57ef7170808f251d251d622761df0937fcde85f3 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 4 Feb 2025 13:36:48 +0700 Subject: [PATCH 060/100] cmd/cli: improve error message returned by FlushDNSCache By recording both the error and output of external commands. While at it: - Removing un-necessary usages of sudo, since ctrld already running with root privilege. - Removing un-used function triggerCaptiveCheck. --- cmd/cli/dns_proxy.go | 35 +++++------------------------------ 1 file changed, 5 insertions(+), 30 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 099bae2..853e77a 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "net" - "net/http" "net/netip" "os/exec" "runtime" @@ -1295,46 +1294,22 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { } } -func triggerCaptiveCheck() { - // Wait for a short period to ensure DNS reinitialization is complete. - time.Sleep(2 * time.Second) - - // if not Mac OS, return - if runtime.GOOS != "darwin" { - return - } - - // Trigger a lookup for captive.apple.com. - // This can be done either via a DNS query or an HTTP GET. - // Here we use a simple HTTP GET which is what macOS CaptiveNetworkAssistant uses. - client := &http.Client{ - Timeout: 5 * time.Second, - } - resp, err := client.Get("http://captive.apple.com/generate_204") - if err != nil { - mainLog.Load().Debug().Msg("failed to trigger captive portal check") - return - } - resp.Body.Close() - mainLog.Load().Debug().Msg("triggered captive portal check by querying captive.apple.com") -} - // FlushDNSCache flushes the DNS cache on macOS. func FlushDNSCache() error { - // if not Mac OS, return + // if not macOS, return if runtime.GOOS != "darwin" { return nil } // Flush the DNS cache via mDNSResponder. // This is typically needed on modern macOS systems. - if err := exec.Command("sudo", "killall", "-HUP", "mDNSResponder").Run(); err != nil { - return fmt.Errorf("failed to flush mDNSResponder: %w", err) + if out, err := exec.Command("killall", "-HUP", "mDNSResponder").CombinedOutput(); err != nil { + return fmt.Errorf("failed to flush mDNSResponder: %w, output: %s", err, string(out)) } // Optionally, flush the directory services cache. - if err := exec.Command("sudo", "dscacheutil", "-flushcache").Run(); err != nil { - return fmt.Errorf("failed to flush dscacheutil: %w", err) + if out, err := exec.Command("dscacheutil", "-flushcache").CombinedOutput(); err != nil { + return fmt.Errorf("failed to flush dscacheutil: %w, output: %s", err, string(out)) } return nil From 595071b6089a2843279598fe42fb09a1d461d2aa Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 4 Feb 2025 17:58:05 +0700 Subject: [PATCH 061/100] all: update client info table on network changes So the client metadata will be updated correctly when the device roaming between networks. --- cmd/cli/dns_proxy.go | 19 ++++++++-- cmd/cli/prog.go | 59 ++++++++++++++++++++---------- internal/clientinfo/client_info.go | 40 ++++++++++++++++---- 3 files changed, 87 insertions(+), 31 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 853e77a..18ac373 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -84,9 +84,9 @@ type upstreamForResult struct { srcAddr string } -func (p *prog) serveDNS(listenerNum string) error { +func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { // Start network monitoring - if err := p.monitorNetworkChanges(); err != nil { + if err := p.monitorNetworkChanges(mainCtx); err != nil { mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring") // Don't return here as we still want DNS service to run } @@ -1316,7 +1316,7 @@ func FlushDNSCache() error { } // monitorNetworkChanges starts monitoring for network interface changes -func (p *prog) monitorNetworkChanges() error { +func (p *prog) monitorNetworkChanges(ctx context.Context) error { mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: ")) if err != nil { return fmt.Errorf("creating network monitor: %w", err) @@ -1336,6 +1336,19 @@ func (p *prog) monitorNetworkChanges() error { oldIfs := parseInterfaceState(delta.Old) newIfs := parseInterfaceState(delta.New) + // Client info discover only run on non-mobile platforms. + if !isMobile() { + // If this is major change, re-init client info table if its self IP changes. + if delta.Monitor.IsMajorChangeFrom(delta.Old, delta.New) { + selfIP := defaultRouteIP() + if currentSelfIP := p.ciTable.SelfIP(); currentSelfIP != selfIP && selfIP != "" { + p.stopClientInfoDiscover() + p.setupClientInfoDiscover(selfIP) + p.runClientInfoDiscover(ctx) + } + } + } + // Check for changes in valid interfaces changed := false activeInterfaceExists := false diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index fd49764..4c9270c 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -72,6 +72,7 @@ var useSystemdResolved = false type prog struct { mu sync.Mutex + wg sync.WaitGroup waitCh chan struct{} stopCh chan struct{} reloadCh chan struct{} // For Windows. @@ -451,7 +452,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } var wg sync.WaitGroup - wg.Add(len(p.cfg.Listener)) + p.wg = wg + p.wg.Add(len(p.cfg.Listener)) for _, nc := range p.cfg.Network { for _, cidr := range nc.Cidrs { @@ -477,12 +479,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } p.setupUpstream(p.cfg) - p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers) - if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) - format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) - p.ciTable.AddLeaseFile(leaseFile, format) - } + p.setupClientInfoDiscover(defaultRouteIP()) } // context for managing spawn goroutines. @@ -491,12 +488,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { // Newer versions of android and iOS denies permission which breaks connectivity. if !isMobile() && !reload { - wg.Add(1) - go func() { - defer wg.Done() - p.ciTable.Init() - p.ciTable.RefreshLoop(ctx) - }() + p.runClientInfoDiscover(ctx) go p.watchLinkState(ctx) } @@ -511,7 +503,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum); err != nil { + if err := p.serveDNS(ctx, listenerNum); err != nil { mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } }(listenerNum) @@ -519,7 +511,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { go func() { defer func() { cancelFunc() - wg.Done() + p.wg.Done() }() select { case <-p.stopCh: @@ -540,19 +532,19 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { close(p.onStartedDone) - wg.Add(1) + p.wg.Add(1) go func() { - defer wg.Done() + defer p.wg.Done() // Check for possible DNS loop. p.checkDnsLoop() // Start check DNS loop ticker. p.checkDnsLoopTicker(ctx) }() - wg.Add(1) + p.wg.Add(1) // Prometheus exporter goroutine. go func() { - defer wg.Done() + defer p.wg.Done() p.runMetricsServer(ctx, reloadCh) }() @@ -567,7 +559,34 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.postRun() p.initInternalLogging(logWriters) } - wg.Wait() + p.wg.Wait() +} + +// setupClientInfoDiscover performs necessary works for running client info discover. +func (p *prog) setupClientInfoDiscover(selfIP string) { + p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers) + if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { + mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) + p.ciTable.AddLeaseFile(leaseFile, format) + } +} + +// runClientInfoDiscover runs the client info discover in background. +func (p *prog) runClientInfoDiscover(ctx context.Context) { + p.wg.Add(1) + go func() { + defer p.wg.Done() + p.ciTable.Init() + p.ciTable.RefreshLoop(ctx) + }() +} + +// stopClientInfoDiscover stops the current client info discover goroutine. +// It blocks until the goroutine terminated. +func (p *prog) stopClientInfoDiscover() { + p.ciTable.Stop() + mainLog.Load().Debug().Msg("stopped client info discover") } // metricsEnabled reports whether prometheus exporter is enabled/disabled. diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 780334b..04ec4c3 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -77,6 +77,7 @@ type Table struct { hostnameResolvers []HostnameResolver refreshers []refresher initOnce sync.Once + stopOnce sync.Once refreshInterval int dhcp *dhcp @@ -90,6 +91,7 @@ type Table struct { vni *virtualNetworkIface svcCfg ctrld.ServiceConfig quitCh chan struct{} + stopCh chan struct{} selfIP string cdUID string ptrNameservers []string @@ -103,6 +105,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { return &Table{ svcCfg: cfg.Service, quitCh: make(chan struct{}), + stopCh: make(chan struct{}), selfIP: selfIP, cdUID: cdUID, ptrNameservers: ns, @@ -120,24 +123,47 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) { // RefreshLoop runs all the refresher to update new client info data. func (t *Table) RefreshLoop(ctx context.Context) { timer := time.NewTicker(time.Second * time.Duration(t.refreshInterval)) - defer timer.Stop() + defer func() { + timer.Stop() + close(t.quitCh) + }() for { select { case <-timer.C: - for _, r := range t.refreshers { - _ = r.refresh() - } + t.Refresh() + case <-t.stopCh: + return case <-ctx.Done(): - close(t.quitCh) return } } } +// Init initializes all client info discovers. func (t *Table) Init() { t.initOnce.Do(t.init) } +// Refresh forces all discovers to retrieve new data. +func (t *Table) Refresh() { + for _, r := range t.refreshers { + _ = r.refresh() + } +} + +// Stop stops all the discovers. +// It blocks until all the discovers done. +func (t *Table) Stop() { + t.stopOnce.Do(func() { + close(t.stopCh) + }) + <-t.quitCh +} + +func (t *Table) SelfIP() string { + return t.selfIP +} + func (t *Table) init() { // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { @@ -381,9 +407,7 @@ func (t *Table) lookupHostnameAll(ip, mac string) []*hostnameEntry { // ListClients returns list of clients discovered by ctrld. func (t *Table) ListClients() []*Client { - for _, r := range t.refreshers { - _ = r.refresh() - } + t.Refresh() ipMap := make(map[string]*Client) il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni} for _, ir := range il { From 2d3779ec27ce59e2b3fc07c4aa039a59111b06ed Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 4 Feb 2025 18:38:48 -0500 Subject: [PATCH 062/100] fix MacOS nameserver detection, fix not installed errors for commands copy fix get valid ifaces in nameservers_bsd nameservers on MacOS can be found in resolv.conf reliably nameservers on MacOS can be found in resolv.conf reliably exclude local IPs from MacOS resolve conf check use scutil for MacOS, simplify reinit logic to prevent duplicate calls add more dns server fetching options never skip OS resolver in IsDown check split dsb and darwin nameserver methods, add delay for setting DNS on interface on network change. increase delay to 5s but only on MacOS --- cmd/cli/commands.go | 56 ++++++++++ cmd/cli/dns_proxy.go | 86 ++++++--------- cmd/cli/log_writer.go | 4 +- cmd/cli/prog.go | 4 +- nameservers_bsd.go | 19 +--- nameservers_darwin.go | 243 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 337 insertions(+), 75 deletions(-) create mode 100644 nameservers_darwin.go diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 8396c19..3dae547 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -43,6 +43,20 @@ func initLogCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -82,6 +96,20 @@ func initLogCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -765,6 +793,20 @@ func initReloadCmd(restartCmd *cobra.Command) *cobra.Command { Short: "Reload the ctrld service", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") @@ -1045,6 +1087,20 @@ func initClientsCmd() *cobra.Command { checkHasElevatedPrivilege() }, Run: func(cmd *cobra.Command, args []string) { + + p := &prog{router: router.New(&cfg, false)} + s, _ := newService(p, svcConfig) + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is not running") + return + } + dir, err := socketDir() if err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 18ac373..646bafb 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -585,10 +585,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { continue } if p.um.isDown(upstreams[n]) { - logger. - Bool("is_os_resolver", upstreams[n] == upstreamOS) - ctrld.Log(ctx, logger, "Upstream is down") - continue + // never skip the OS resolver, since we usually query this resolver when we + // have no other upstreams to query + if upstreams[n] != upstreamOS { + logger. + Bool("is_os_resolver", upstreams[n] == upstreamOS) + ctrld.Log(ctx, logger, "Upstream is down") + continue + } } answer := resolve(n, upstreamConfig, req.msg) if answer == nil { @@ -1231,67 +1235,43 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M func (p *prog) reinitializeOSResolver(networkChange bool) { // Cancel any existing operations p.resetCtxMu.Lock() - if p.resetCancel != nil { - p.resetCancel() - } - - // Create new context for this operation - ctx, cancel := context.WithCancel(context.Background()) - p.resetCtx = ctx - p.resetCancel = cancel - p.resetCtxMu.Unlock() - - // Ensure cleanup - defer cancel() + defer p.resetCtxMu.Unlock() p.leakingQueryReset.Store(true) defer p.leakingQueryReset.Store(false) - defer func() { - // start leaking queries immediately - if networkChange { - // set all upstreams to failed and provide to performLeakingQuery - failedUpstreams := make(map[string]*ctrld.UpstreamConfig) - for _, upstream := range p.cfg.Upstream { - failedUpstreams[upstream.Name] = upstream - } - go p.performLeakingQuery(failedUpstreams, "all") + mainLog.Load().Debug().Msg("attempting to reset DNS") + p.resetDNS() + mainLog.Load().Debug().Msg("DNS reset completed") + + mainLog.Load().Debug().Msg("initializing OS resolver") + ns := ctrld.InitializeOsResolver() + mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) + + // start leaking queries immediately// start leaking queries immediately + if networkChange { + // set all upstreams to failed and provide to performLeakingQuery + failedUpstreams := make(map[string]*ctrld.UpstreamConfig) + for _, upstream := range p.cfg.Upstream { + failedUpstreams[upstream.Name] = upstream } + go p.performLeakingQuery(failedUpstreams, "all") + if err := FlushDNSCache(); err != nil { mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache") } - }() - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("attempting to reset DNS") - p.resetDNS() - mainLog.Load().Debug().Msg("DNS reset completed") + if runtime.GOOS == "darwin" { + // delay putting back the ctrld listener to allow for captive portal to trigger + time.Sleep(5 * time.Second) + } } - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("initializing OS resolver") - ns := ctrld.InitializeOsResolver() - mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) - } + mainLog.Load().Debug().Msg("setting DNS configuration") + p.setDNS() + mainLog.Load().Debug().Msg("DNS configuration set successfully") + p.logInterfacesState() - select { - case <-ctx.Done(): - mainLog.Load().Debug().Msg("DNS reset cancelled by new network change") - return - default: - mainLog.Load().Debug().Msg("setting DNS configuration") - p.setDNS() - mainLog.Load().Debug().Msg("DNS configuration set successfully") - p.logInterfacesState() - } } // FlushDNSCache flushes the DNS cache on macOS. diff --git a/cmd/cli/log_writer.go b/cmd/cli/log_writer.go index 1bcfe3d..339d984 100644 --- a/cmd/cli/log_writer.go +++ b/cmd/cli/log_writer.go @@ -20,9 +20,9 @@ const ( logWriterSmallSize = 1024 * 1024 * 1 // 1 MB logWriterInitialSize = 32 * 1024 // 32 KB logSentInterval = time.Minute - logStartEndMarker = "\n\n=== START_END ===\n\n" + logStartEndMarker = "\n\n=== INIT_END ===\n\n" logLogEndMarker = "\n\n=== LOG_END ===\n\n" - logWarnEndMarker = "\n\n=== WARN END ===\n\n" + logWarnEndMarker = "\n\n=== WARN_END ===\n\n" ) type logViewResponse struct { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 4c9270c..41dc2c4 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -123,9 +123,7 @@ type prog struct { leakingQueryRunning map[string]bool leakingQueryReset atomic.Bool - resetCtx context.Context - resetCancel context.CancelFunc - resetCtxMu sync.Mutex + resetCtxMu sync.Mutex started chan struct{} onStartedDone chan struct{} diff --git a/nameservers_bsd.go b/nameservers_bsd.go index 2beebd0..b835060 100644 --- a/nameservers_bsd.go +++ b/nameservers_bsd.go @@ -1,19 +1,16 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || openbsd package ctrld import ( "net" - "os/exec" - "runtime" - "strings" "syscall" "golang.org/x/net/route" ) func dnsFns() []dnsFn { - return []dnsFn{dnsFromRIB, dnsFromIPConfig} + return []dnsFn{dnsFromRIB} } func dnsFromRIB() []string { @@ -49,18 +46,6 @@ func dnsFromRIB() []string { return dns } -func dnsFromIPConfig() []string { - if runtime.GOOS != "darwin" { - return nil - } - cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server") - out, _ := cmd.Output() - if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil { - return []string{ip.String()} - } - return nil -} - func toNetIP(addr route.Addr) net.IP { switch t := addr.(type) { case *route.Inet4Addr: diff --git a/nameservers_darwin.go b/nameservers_darwin.go new file mode 100644 index 0000000..bec6ce4 --- /dev/null +++ b/nameservers_darwin.go @@ -0,0 +1,243 @@ +//go:build darwin + +package ctrld + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "os/exec" + "regexp" + "slices" + "strings" + "time" + + "github.com/rs/zerolog" + "tailscale.com/net/netmon" +) + +func dnsFns() []dnsFn { + return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers} +} + +// dnsFromResolvConf reads nameservers from /etc/resolv.conf +func dnsFromResolvConf() []string { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var dns []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + file, err := os.Open("/etc/resolv.conf") + if err != nil { + Log(context.Background(), logger.Error(), "failed to open /etc/resolv.conf (attempt %d/%d)", attempt+1, maxRetries) + continue + } + defer file.Close() + + var localDNS []string + seen := make(map[string]bool) + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "nameserver" { + continue + } + if ip := net.ParseIP(fields[1]); ip != nil { + // skip loopback IPs + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + continue + } + } + if !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + + if err := scanner.Err(); err != nil { + Log(context.Background(), logger.Error(), "error reading /etc/resolv.conf (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + // If we successfully read the file and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return dns +} + +func getDNSFromScutil() []string { + logger := zerolog.New(io.Discard) + if ProxyLogger.Load() != nil { + logger = *ProxyLogger.Load() + } + + const ( + maxRetries = 10 + retryInterval = 100 * time.Millisecond + ) + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var nameservers []string + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryInterval) + } + + cmd := exec.Command("scutil", "--dns") + output, err := cmd.Output() + if err != nil { + Log(context.Background(), logger.Error(), "failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + var localDNS []string + seen := make(map[string]bool) + + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "nameserver[") { + parts := strings.Split(line, ":") + if len(parts) == 2 { + ns := strings.TrimSpace(parts[1]) + if ip := net.ParseIP(ns); ip != nil { + // skip loopback IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + ipStr := v.String() + if ip.String() == ipStr { + isLocal = true + break + } + } + if !isLocal && !seen[ip.String()] { + seen[ip.String()] = true + localDNS = append(localDNS, ip.String()) + } + } + } + } + } + + if err := scanner.Err(); err != nil { + Log(context.Background(), logger.Error(), "error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err) + continue + } + + // If we successfully read the output and found nameservers, return them + if len(localDNS) > 0 { + return localDNS + } + } + + return nameservers +} + +func getDHCPNameservers(iface string) ([]string, error) { + // Run the ipconfig command for the given interface. + cmd := exec.Command("ipconfig", "getpacket", iface) + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("error running ipconfig: %v", err) + } + + // Look for a line like: + // domain_name_servers = 192.168.1.1 8.8.8.8; + re := regexp.MustCompile(`domain_name_servers\s*=\s*(.*);`) + matches := re.FindStringSubmatch(string(output)) + if len(matches) < 2 { + return nil, fmt.Errorf("no DHCP nameservers found") + } + + // Split the nameservers by whitespace. + nameservers := strings.Fields(matches[1]) + return nameservers, nil +} + +func getAllDHCPNameservers() []string { + interfaces, err := net.Interfaces() + if err != nil { + return nil + } + + regularIPs, loopbackIPs, _ := netmon.LocalAddresses() + + var allNameservers []string + seen := make(map[string]bool) + + for _, iface := range interfaces { + // Skip interfaces that are: + // - down + // - loopback + // - not physical (virtual) + // - point-to-point (like VPN interfaces) + // - without MAC address (non-physical) + if iface.Flags&net.FlagUp == 0 || + iface.Flags&net.FlagLoopback != 0 || + iface.Flags&net.FlagPointToPoint != 0 || + (iface.Flags&net.FlagBroadcast == 0 && + iface.Flags&net.FlagMulticast == 0) || + len(iface.HardwareAddr) == 0 || + strings.HasPrefix(iface.Name, "utun") || + strings.HasPrefix(iface.Name, "llw") || + strings.HasPrefix(iface.Name, "awdl") { + continue + } + + // Verify it's a valid MAC address (should be 6 bytes for IEEE 802 MAC-48) + if len(iface.HardwareAddr) != 6 { + continue + } + + nameservers, err := getDHCPNameservers(iface.Name) + if err != nil { + continue + } + + // Add unique nameservers to the result, skipping local IPs + for _, ns := range nameservers { + if ip := net.ParseIP(ns); ip != nil { + // skip loopback and local IPs + isLocal := false + for _, v := range slices.Concat(regularIPs, loopbackIPs) { + if ip.String() == v.String() { + isLocal = true + break + } + } + if !isLocal && !seen[ns] { + seen[ns] = true + allNameservers = append(allNameservers, ns) + } + } + } + } + + return allNameservers +} From 47d7ace3a7647ffabc9062ec3ad3c060a732a258 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 5 Feb 2025 13:34:46 +0700 Subject: [PATCH 063/100] Simplify dnsFromResolvConf By using existed package instead of hand written one. While at it, also simplifying the logger getter, since the ProxyLogger is guaranted to be non-nil. --- nameservers_darwin.go | 38 ++++++-------------------------------- 1 file changed, 6 insertions(+), 32 deletions(-) diff --git a/nameservers_darwin.go b/nameservers_darwin.go index bec6ce4..d536d78 100644 --- a/nameservers_darwin.go +++ b/nameservers_darwin.go @@ -7,17 +7,16 @@ import ( "bytes" "context" "fmt" - "io" "net" - "os" "os/exec" "regexp" "slices" "strings" "time" - "github.com/rs/zerolog" "tailscale.com/net/netmon" + + "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) func dnsFns() []dnsFn { @@ -26,11 +25,6 @@ func dnsFns() []dnsFn { // dnsFromResolvConf reads nameservers from /etc/resolv.conf func dnsFromResolvConf() []string { - logger := zerolog.New(io.Discard) - if ProxyLogger.Load() != nil { - logger = *ProxyLogger.Load() - } - const ( maxRetries = 10 retryInterval = 100 * time.Millisecond @@ -44,24 +38,12 @@ func dnsFromResolvConf() []string { time.Sleep(retryInterval) } - file, err := os.Open("/etc/resolv.conf") - if err != nil { - Log(context.Background(), logger.Error(), "failed to open /etc/resolv.conf (attempt %d/%d)", attempt+1, maxRetries) - continue - } - defer file.Close() - + nss := resolvconffile.NameServers("") var localDNS []string seen := make(map[string]bool) - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) < 2 || fields[0] != "nameserver" { - continue - } - if ip := net.ParseIP(fields[1]); ip != nil { + for _, ns := range nss { + if ip := net.ParseIP(ns); ip != nil { // skip loopback IPs for _, v := range slices.Concat(regularIPs, loopbackIPs) { ipStr := v.String() @@ -76,11 +58,6 @@ func dnsFromResolvConf() []string { } } - if err := scanner.Err(); err != nil { - Log(context.Background(), logger.Error(), "error reading /etc/resolv.conf (attempt %d/%d): %v", attempt+1, maxRetries, err) - continue - } - // If we successfully read the file and found nameservers, return them if len(localDNS) > 0 { return localDNS @@ -91,10 +68,7 @@ func dnsFromResolvConf() []string { } func getDNSFromScutil() []string { - logger := zerolog.New(io.Discard) - if ProxyLogger.Load() != nil { - logger = *ProxyLogger.Load() - } + logger := *ProxyLogger.Load() const ( maxRetries = 10 From 60686f55ff3e70f00fa22a45c5ec4ca04c975581 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 5 Feb 2025 16:37:18 +0700 Subject: [PATCH 064/100] cmd/cli: set ProxyLogger correctly for interactive commands The ProxyLogger must only be set after mainLog is fully initialized. However, it's being set before the final initialization of mainlog, causing it still refers to stale old pointer. To fix this, introduce a new function to discard ProxyLogger explicitly, and use this function to init logging for all interactive commands. --- cmd/cli/cli.go | 4 +--- cmd/cli/commands.go | 8 ++++---- cmd/cli/main.go | 10 ++++++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index f59d4bb..49adca3 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -267,8 +267,6 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { // Log config do not have thing to validate, so it's safe to init log here, // so it's able to log information in processCDFlags. logWriters := initLogging() - // TODO: find a better way. - ctrld.ProxyLogger.Store(mainLog.Load()) // Initializing internal logging after global logging. p.initInternalLogging(logWriters) @@ -1023,7 +1021,7 @@ func uninstall(p *prog, s service.Service) { {s.Stop, false, "Stop"}, {s.Uninstall, true, "Uninstall"}, } - initLogging() + initInteractiveLogging() if doTasks(tasks) { if err := p.router.ConfigureService(svcConfig); err != nil { mainLog.Load().Fatal().Err(err).Msg("could not configure service") diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 3dae547..43ec485 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -336,7 +336,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) } - initLogging() + initInteractiveLogging() tasks := []task{ {s.Stop, false, "Stop"}, resetDnsTask(p, s, isCtrldInstalled, currentIface), @@ -399,7 +399,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) } - initLogging() + initInteractiveLogging() if nextdns != "" { removeNextDNSFromArgs(sc) @@ -588,7 +588,7 @@ func initStopCmd() *cobra.Command { p.requiredMultiNICsConfig = ir.All } - initLogging() + initInteractiveLogging() status, err := s.Status() if errors.Is(err, service.ErrNotInstalled) { @@ -695,7 +695,7 @@ func initRestartCmd() *cobra.Command { p.requiredMultiNICsConfig = ir.All } - initLogging() + initInteractiveLogging() if cdMode { doValidateCdRemoteConfig(cdUID) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 7041318..819797a 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -106,6 +106,14 @@ func initLogging() []io.Writer { return initLoggingWithBackup(true) } +// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded +// to be used for all interactive commands. +func initInteractiveLogging() { + initLogging() + l := zerolog.New(io.Discard) + ctrld.ProxyLogger.Store(&l) +} + // initLoggingWithBackup initializes log setup base on current config. // If doBackup is true, backup old log file with ".1" suffix. // @@ -143,6 +151,8 @@ func initLoggingWithBackup(doBackup bool) []io.Writer { multi := zerolog.MultiLevelWriter(writers...) l := mainLog.Load().Output(multi).With().Logger() mainLog.Store(&l) + // TODO: find a better way. + ctrld.ProxyLogger.Store(&l) zerolog.SetGlobalLevel(zerolog.NoticeLevel) logLevel := cfg.Service.LogLevel From cf6d16b43989f47291ae1e6772ad9d5feb843491 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 5 Feb 2025 01:41:16 -0500 Subject: [PATCH 065/100] set new dialer on every request debugging debugging debugging debugging use default route interface IP for OS resolver queries remove retries fix resolv.conf clobbering on MacOS, set custom local addr for os resolver queries remove the client info discovery logic on network change, this was overkill just for the IP, and was causing service failure after switching networks many times rapidly handle ipv6 local addresses guard ciTable from nil pointer debugging failure count --- cmd/cli/dns_proxy.go | 88 +++++++++++++++----- cmd/cli/prog.go | 1 + cmd/cli/resolvconf.go | 109 +++++++++++++++++++++++-- cmd/cli/upstream_monitor.go | 14 +++- config.go | 2 +- internal/clientinfo/client_info.go | 11 +++ resolver.go | 127 ++++++++++++++++++++++++++++- 7 files changed, 317 insertions(+), 35 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 646bafb..01e1673 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -99,6 +99,7 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { } handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + mainLog.Load().Debug().Msgf("serveDNS handler called") p.sema.acquire() defer p.sema.release() if len(m.Question) == 0 { @@ -1238,7 +1239,10 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { defer p.resetCtxMu.Unlock() p.leakingQueryReset.Store(true) - defer p.leakingQueryReset.Store(false) + defer func() { + time.Sleep(time.Second) + p.leakingQueryReset.Store(false) + }() mainLog.Load().Debug().Msg("attempting to reset DNS") p.resetDNS() @@ -1260,7 +1264,6 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { if err := FlushDNSCache(); err != nil { mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache") } - if runtime.GOOS == "darwin" { // delay putting back the ctrld listener to allow for captive portal to trigger time.Sleep(5 * time.Second) @@ -1316,21 +1319,9 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { oldIfs := parseInterfaceState(delta.Old) newIfs := parseInterfaceState(delta.New) - // Client info discover only run on non-mobile platforms. - if !isMobile() { - // If this is major change, re-init client info table if its self IP changes. - if delta.Monitor.IsMajorChangeFrom(delta.Old, delta.New) { - selfIP := defaultRouteIP() - if currentSelfIP := p.ciTable.SelfIP(); currentSelfIP != selfIP && selfIP != "" { - p.stopClientInfoDiscover() - p.setupClientInfoDiscover(selfIP) - p.runClientInfoDiscover(ctx) - } - } - } - // Check for changes in valid interfaces changed := false + var changedIface, changedIfaceState string activeInterfaceExists := false for ifaceName := range validIfaces { @@ -1343,7 +1334,14 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Compare states directly if oldExists != newExists || oldState != newState { - changed = true + + // If the interface is up, we need to reinitialize the OS resolver + if newState != "" && !strings.Contains(newState, "down") { + changed = true + changedIface = ifaceName + changedIfaceState = newState + } + mainLog.Load().Warn(). Str("interface", ifaceName). Str("old_state", oldState). @@ -1364,11 +1362,33 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { return } - if activeInterfaceExists { - p.reinitializeOSResolver(true) - } else { + if !activeInterfaceExists { mainLog.Load().Warn().Msg("No active interfaces found, skipping reinitialization") + return } + + // Use the defaultRouteIP() result or fallback to the changed interface's IP from the delta. + selfIP := defaultRouteIP() + if selfIP == "" && changedIface != "" { + selfIP = extractIPv4FromState(changedIfaceState) + mainLog.Load().Info().Msgf("defaultRouteIP returned empty, using changed iface '%s' IP: %s", changedIface, selfIP) + } + + // Extract IPv6 from the changed interface state. + ipv6 := extractIPv6FromState(changedIfaceState) + + if ip := net.ParseIP(selfIP); ip != nil { + ctrld.SetDefaultLocalIPv4(ip) + // if we have a new IP, set the client info to the new IP + if !isMobile() && p.ciTable != nil { + p.ciTable.SetSelfIP(selfIP) + } + } + if ip := net.ParseIP(ipv6); ip != nil { + ctrld.SetDefaultLocalIPv6(ip) + } + mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) + p.reinitializeOSResolver(true) }) mon.Start() @@ -1423,3 +1443,33 @@ func parseInterfaceState(state *netmon.State) map[string]string { return result } + +// extractIPv4FromState extracts an IPv4 address from an interface state string. +// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239". +// If no valid IP can be found, it returns an empty string. +func extractIPv4FromState(state string) string { + trimmed := strings.Trim(state, "[]") + parts := strings.Fields(trimmed) + for _, part := range parts { + ipPart := strings.Split(part, "/")[0] + if ip := net.ParseIP(ipPart); ip != nil && ip.To4() != nil { + return ipPart + } + } + return "" +} + +// extractIPv6FromState extracts an IPv6 address from an interface state string. +// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239". +// If no valid IP can be found, it returns an empty string. +func extractIPv6FromState(state string) string { + trimmed := strings.Trim(state, "[]") + parts := strings.Fields(trimmed) + for _, part := range parts { + ipPart := strings.Split(part, "/")[0] + if ip := net.ParseIP(ipPart); ip != nil && ip.To4() == nil { + return ipPart + } + } + return "" +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 41dc2c4..c7eba13 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -504,6 +504,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if err := p.serveDNS(ctx, listenerNum); err != nil { mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } + mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr) }(listenerNum) } go func() { diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 367ffe7..9d37d68 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -3,11 +3,38 @@ package cli import ( "net" "net/netip" + "os" "path/filepath" + "strings" + "time" "github.com/fsnotify/fsnotify" ) +// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found. +// Returns nil if no nameservers are found. +func (p *prog) parseResolvConfNameservers(path string) ([]string, error) { + content, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + // Parse the file for "nameserver" lines + var currentNS []string + lines := strings.Split(string(content), "\n") + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "nameserver") { + parts := strings.Fields(trimmed) + if len(parts) >= 2 { + currentNS = append(currentNS, parts[1]) + } + } + } + + return currentNS, nil +} + // watchResolvConf watches any changes to /etc/resolv.conf file, // and reverting to the original config set by ctrld. func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) { @@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f continue } if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { - mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting") - if err := watcher.Remove(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to pause watcher") - continue + mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") + + // Convert expected nameservers to strings for comparison + expectedNS := make([]string, len(ns)) + for i, addr := range ns { + expectedNS[i] = addr.String() } - if err := setDnsFn(iface, ns); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + + var foundNS []string + var err error + + maxRetries := 1 + for retry := 0; retry < maxRetries; retry++ { + foundNS, err = p.parseResolvConfNameservers(resolvConfPath) + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content") + break + } + + // If we found nameservers, break out of retry loop + if len(foundNS) > 0 { + break + } + + // Only retry if we found no nameservers + if retry < maxRetries-1 { + mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) + select { + case <-p.stopCh: + return + case <-p.dnsWatcherStopCh: + return + case <-time.After(2 * time.Second): + continue + } + } else { + mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries") + } } - if err := watcher.Add(watchDir); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") - return + + // If we found nameservers, check if they match what we expect + if len(foundNS) > 0 { + // Check if the nameservers match exactly what we expect + matches := len(foundNS) == len(expectedNS) + if matches { + for i := range foundNS { + if foundNS[i] != expectedNS[i] { + matches = false + break + } + } + } + + mainLog.Load().Debug(). + Strs("found", foundNS). + Strs("expected", expectedNS). + Bool("matches", matches). + Msg("checking nameservers") + + // Only revert if the nameservers don't match + if !matches { + if err := watcher.Remove(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to pause watcher") + continue + } + + if err := setDnsFn(iface, ns); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + } + + if err := watcher.Add(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") + return + } + } } } case err, ok := <-watcher.Errors: diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index e37db4d..fc5d65d 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -42,14 +42,24 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { return um } -// increaseFailureCount increase failed queries count for an upstream by 1. +// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information. func (um *upstreamMonitor) increaseFailureCount(upstream string) { um.mu.Lock() defer um.mu.Unlock() um.failureReq[upstream] += 1 failedCount := um.failureReq[upstream] - um.down[upstream] = failedCount >= maxFailureRequest + + // Log the updated failure count + mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) + + // Check if the failure count has reached the threshold to mark the upstream as down. + if failedCount >= maxFailureRequest { + um.down[upstream] = true + mainLog.Load().Warn().Msgf("upstream %q marked as down (failure count: %d)", upstream, failedCount) + } else { + um.down[upstream] = false + } } // isDown reports whether the given upstream is being marked as down. diff --git a/config.go b/config.go index 099f75b..e1454f9 100644 --- a/config.go +++ b/config.go @@ -458,7 +458,7 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { if uc.rebootstrap.CompareAndSwap(false, true) { - ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip") + ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc) } return true, nil }) diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 04ec4c3..e6bda79 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -93,6 +93,7 @@ type Table struct { quitCh chan struct{} stopCh chan struct{} selfIP string + selfIPLock sync.RWMutex cdUID string ptrNameservers []string } @@ -160,10 +161,20 @@ func (t *Table) Stop() { <-t.quitCh } +// SelfIP returns the selfIP value of the Table in a thread-safe manner. func (t *Table) SelfIP() string { + t.selfIPLock.RLock() + defer t.selfIPLock.RUnlock() return t.selfIP } +// SetSelfIP sets the selfIP value of the Table in a thread-safe manner. +func (t *Table) SetSelfIP(ip string) { + t.selfIPLock.Lock() + defer t.selfIPLock.Unlock() + t.selfIP = ip +} + func (t *Table) init() { // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { diff --git a/resolver.go b/resolver.go index 34a6cdd..01348dc 100644 --- a/resolver.go +++ b/resolver.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/netip" + "runtime" "slices" "sync" "sync/atomic" @@ -50,8 +51,10 @@ var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53") var localResolver = newLocalResolver() var ( - resolverMutex sync.Mutex - or *osResolver + resolverMutex sync.Mutex + or *osResolver + defaultLocalIPv4 atomic.Value // holds net.IP (IPv4) + defaultLocalIPv6 atomic.Value // holds net.IP (IPv6) ) func newLocalResolver() Resolver { @@ -216,6 +219,108 @@ type publicResponse struct { server string } +// SetDefaultLocalIPv4 updates the stored local IPv4. +func SetDefaultLocalIPv4(ip net.IP) { + Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip) + defaultLocalIPv4.Store(ip) +} + +// SetDefaultLocalIPv6 updates the stored local IPv6. +func SetDefaultLocalIPv6(ip net.IP) { + Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip) + defaultLocalIPv6.Store(ip) +} + +// GetDefaultLocalIPv4 returns the stored local IPv4 or nil if none. +func GetDefaultLocalIPv4() net.IP { + if v := defaultLocalIPv4.Load(); v != nil { + return v.(net.IP) + } + return nil +} + +// GetDefaultLocalIPv6 returns the stored local IPv6 or nil if none. +func GetDefaultLocalIPv6() net.IP { + if v := defaultLocalIPv6.Load(); v != nil { + return v.(net.IP) + } + return nil +} + +// debugDialer is a helper type that wraps a net.Dialer and logs +// the local IP address used when dialing out. +type debugDialer struct { + *net.Dialer +} + +// DialContext wraps the underlying DialContext and logs the local address of the connection. +func (d *debugDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := d.Dialer.DialContext(ctx, network, addr) + if err != nil { + // Log the error even before a connection is established. + if d.Dialer.LocalAddr != nil { + Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v (local addr: %v)", addr, err, d.Dialer.LocalAddr) + } else { + Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v", addr, err) + } + return nil, err + } + // Log the local address (source IP) used for this connection. + Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s succeeded; local address: %s", + addr, conn.LocalAddr().String()) + return conn, nil +} + +// customDNSExchange wraps the DNS exchange to use our debug dialer. +// It uses dns.ExchangeWithConn so that our custom dialer is used directly. +func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, error) { + baseDialer := &net.Dialer{ + Timeout: 3 * time.Second, + Resolver: &net.Resolver{PreferGo: true}, + } + if desiredLocalIP != nil { + baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0} + } + dd := &debugDialer{Dialer: baseDialer} + + // Attempt UDP first. + udpConn, err := dd.DialContext(ctx, "udp", server) + if err != nil { + return nil, err + } + defer udpConn.Close() + udpDnsConn := &dns.Conn{Conn: udpConn} + if err = udpDnsConn.WriteMsg(msg); err != nil { + return nil, err + } + reply, err := udpDnsConn.ReadMsg() + if err != nil { + return nil, err + } + + // If the UDP reply is not truncated, return it. + if !reply.Truncated { + return reply, nil + } + + // If truncated, retry over TCP once. + Log(ctx, ProxyLogger.Load().Debug(), "UDP response truncated, switching to TCP (1 retry)") + tcpConn, err := dd.DialContext(ctx, "tcp", server) + if err != nil { + return reply, nil // fallback to UDP reply if TCP dial fails. + } + defer tcpConn.Close() + tcpDnsConn := &dns.Conn{Conn: tcpConn} + if err = tcpDnsConn.WriteMsg(msg); err != nil { + return reply, nil // fallback if TCP write fails. + } + tcpReply, err := tcpDnsConn.ReadMsg() + if err != nil { + return reply, nil // fallback if TCP read fails. + } + return tcpReply, nil +} + // Resolve resolves DNS queries using pre-configured nameservers. // Query is sent to all nameservers concurrently, and the first // success response will be returned. @@ -237,7 +342,6 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error ctx, cancel := context.WithCancel(ctx) defer cancel() - dnsClient := &dns.Client{Net: "udp", Timeout: 3 * time.Second} ch := make(chan *osResolverResult, numServers) wg := &sync.WaitGroup{} wg.Add(numServers) @@ -250,7 +354,22 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error for _, server := range servers { go func(server string) { defer wg.Done() - answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server) + var answer *dns.Msg + var err error + var localOSResolverIP net.IP + if runtime.GOOS == "darwin" { + host, _, err := net.SplitHostPort(server) + if err == nil { + ip := net.ParseIP(host) + if ip != nil && ip.To4() == nil { + // IPv6 nameserver; use default IPv6 address (if set) + localOSResolverIP = GetDefaultLocalIPv6() + } else { + localOSResolverIP = GetDefaultLocalIPv4() + } + } + } + answer, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP) ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan} }(server) } From 1c50c2b6af88127aaf601262ce06aaacb162d942 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 6 Feb 2025 15:16:32 +0700 Subject: [PATCH 066/100] Set deadline for custom UDP/TCP conn Otherwise, OS resolver may hang forever if the server does not reply. While at it, also removing unused method stopClientInfoDiscover. Updates #344 --- cmd/cli/prog.go | 7 ------- resolver.go | 2 ++ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index c7eba13..07a7592 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -581,13 +581,6 @@ func (p *prog) runClientInfoDiscover(ctx context.Context) { }() } -// stopClientInfoDiscover stops the current client info discover goroutine. -// It blocks until the goroutine terminated. -func (p *prog) stopClientInfoDiscover() { - p.ciTable.Stop() - mainLog.Load().Debug().Msg("stopped client info discover") -} - // metricsEnabled reports whether prometheus exporter is enabled/disabled. func (p *prog) metricsEnabled() bool { return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != "" diff --git a/resolver.go b/resolver.go index 01348dc..f4299e6 100644 --- a/resolver.go +++ b/resolver.go @@ -289,6 +289,7 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired return nil, err } defer udpConn.Close() + udpConn.SetDeadline(time.Now().Add(3 * time.Second)) udpDnsConn := &dns.Conn{Conn: udpConn} if err = udpDnsConn.WriteMsg(msg); err != nil { return nil, err @@ -310,6 +311,7 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired return reply, nil // fallback to UDP reply if TCP dial fails. } defer tcpConn.Close() + tcpConn.SetDeadline(time.Now().Add(3 * time.Second)) tcpDnsConn := &dns.Conn{Conn: tcpConn} if err = tcpDnsConn.WriteMsg(msg); err != nil { return reply, nil // fallback if TCP write fails. From 2716ae29bda75f661e83231bdfe45645c955bb08 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 6 Feb 2025 15:37:03 +0700 Subject: [PATCH 067/100] cmd/cli: remove unnecessary prog wait group Since the client info is now only run once, we don't need to propagate the wait group to other places for controlling new run. --- cmd/cli/prog.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 07a7592..5b36175 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -72,7 +72,6 @@ var useSystemdResolved = false type prog struct { mu sync.Mutex - wg sync.WaitGroup waitCh chan struct{} stopCh chan struct{} reloadCh chan struct{} // For Windows. @@ -450,8 +449,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } var wg sync.WaitGroup - p.wg = wg - p.wg.Add(len(p.cfg.Listener)) + wg.Add(len(p.cfg.Listener)) for _, nc := range p.cfg.Network { for _, cidr := range nc.Cidrs { @@ -486,7 +484,10 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { // Newer versions of android and iOS denies permission which breaks connectivity. if !isMobile() && !reload { - p.runClientInfoDiscover(ctx) + wg.Add(1) + go func() { + p.runClientInfoDiscover(ctx) + }() go p.watchLinkState(ctx) } @@ -510,7 +511,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { go func() { defer func() { cancelFunc() - p.wg.Done() + wg.Done() }() select { case <-p.stopCh: @@ -531,19 +532,19 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { close(p.onStartedDone) - p.wg.Add(1) + wg.Add(1) go func() { - defer p.wg.Done() + defer wg.Done() // Check for possible DNS loop. p.checkDnsLoop() // Start check DNS loop ticker. p.checkDnsLoopTicker(ctx) }() - p.wg.Add(1) + wg.Add(1) // Prometheus exporter goroutine. go func() { - defer p.wg.Done() + defer wg.Done() p.runMetricsServer(ctx, reloadCh) }() @@ -558,7 +559,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.postRun() p.initInternalLogging(logWriters) } - p.wg.Wait() + wg.Wait() } // setupClientInfoDiscover performs necessary works for running client info discover. @@ -571,14 +572,10 @@ func (p *prog) setupClientInfoDiscover(selfIP string) { } } -// runClientInfoDiscover runs the client info discover in background. +// runClientInfoDiscover runs the client info discover. func (p *prog) runClientInfoDiscover(ctx context.Context) { - p.wg.Add(1) - go func() { - defer p.wg.Done() - p.ciTable.Init() - p.ciTable.RefreshLoop(ctx) - }() + p.ciTable.Init() + p.ciTable.RefreshLoop(ctx) } // metricsEnabled reports whether prometheus exporter is enabled/disabled. From 3132d1b0328945015c77f3a9d79c5fcc05e038ed Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 6 Feb 2025 22:56:37 +0700 Subject: [PATCH 068/100] Remove debug dialer Since its puporse is solely for debugging, it could be one now. --- resolver.go | 71 ++++------------------------------------------------- 1 file changed, 5 insertions(+), 66 deletions(-) diff --git a/resolver.go b/resolver.go index f4299e6..650fd05 100644 --- a/resolver.go +++ b/resolver.go @@ -247,33 +247,9 @@ func GetDefaultLocalIPv6() net.IP { return nil } -// debugDialer is a helper type that wraps a net.Dialer and logs -// the local IP address used when dialing out. -type debugDialer struct { - *net.Dialer -} - -// DialContext wraps the underlying DialContext and logs the local address of the connection. -func (d *debugDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := d.Dialer.DialContext(ctx, network, addr) - if err != nil { - // Log the error even before a connection is established. - if d.Dialer.LocalAddr != nil { - Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v (local addr: %v)", addr, err, d.Dialer.LocalAddr) - } else { - Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s failed: %v", addr, err) - } - return nil, err - } - // Log the local address (source IP) used for this connection. - Log(ctx, ProxyLogger.Load().Debug(), "debugDialer: dial to %s succeeded; local address: %s", - addr, conn.LocalAddr().String()) - return conn, nil -} - // customDNSExchange wraps the DNS exchange to use our debug dialer. // It uses dns.ExchangeWithConn so that our custom dialer is used directly. -func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, error) { +func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, time.Duration, error) { baseDialer := &net.Dialer{ Timeout: 3 * time.Second, Resolver: &net.Resolver{PreferGo: true}, @@ -281,46 +257,9 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired if desiredLocalIP != nil { baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0} } - dd := &debugDialer{Dialer: baseDialer} - - // Attempt UDP first. - udpConn, err := dd.DialContext(ctx, "udp", server) - if err != nil { - return nil, err - } - defer udpConn.Close() - udpConn.SetDeadline(time.Now().Add(3 * time.Second)) - udpDnsConn := &dns.Conn{Conn: udpConn} - if err = udpDnsConn.WriteMsg(msg); err != nil { - return nil, err - } - reply, err := udpDnsConn.ReadMsg() - if err != nil { - return nil, err - } - - // If the UDP reply is not truncated, return it. - if !reply.Truncated { - return reply, nil - } - - // If truncated, retry over TCP once. - Log(ctx, ProxyLogger.Load().Debug(), "UDP response truncated, switching to TCP (1 retry)") - tcpConn, err := dd.DialContext(ctx, "tcp", server) - if err != nil { - return reply, nil // fallback to UDP reply if TCP dial fails. - } - defer tcpConn.Close() - tcpConn.SetDeadline(time.Now().Add(3 * time.Second)) - tcpDnsConn := &dns.Conn{Conn: tcpConn} - if err = tcpDnsConn.WriteMsg(msg); err != nil { - return reply, nil // fallback if TCP write fails. - } - tcpReply, err := tcpDnsConn.ReadMsg() - if err != nil { - return reply, nil // fallback if TCP read fails. - } - return tcpReply, nil + dnsClient := &dns.Client{Net: "udp"} + dnsClient.Dialer = baseDialer + return dnsClient.ExchangeContext(ctx, msg, server) } // Resolve resolves DNS queries using pre-configured nameservers. @@ -371,7 +310,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } } } - answer, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP) + answer, _, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP) ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan} }(server) } From ae6945cedfe6627efccbc048d22899dfb8697231 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 7 Feb 2025 01:02:06 +0700 Subject: [PATCH 069/100] cmd/cli: fix missing wg.Done call --- cmd/cli/prog.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 5b36175..b26b814 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -486,6 +486,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { if !isMobile() && !reload { wg.Add(1) go func() { + defer wg.Done() p.runClientInfoDiscover(ctx) }() go p.watchLinkState(ctx) From 38064d6ad54b75ac340568affd1889cd1eb39f56 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 13:30:11 -0500 Subject: [PATCH 070/100] parse InterfaceIPs for network delta, not just ifs block --- cmd/cli/dns_proxy.go | 141 +++++++++++++++++++++++++++++++------------ 1 file changed, 101 insertions(+), 40 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 01e1673..f9d6478 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/hex" + "encoding/json" "errors" "fmt" "net" @@ -11,6 +12,7 @@ import ( "os/exec" "runtime" "slices" + "sort" "strconv" "strings" "sync" @@ -1309,44 +1311,68 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Get map of valid interfaces validIfaces := validInterfacesMap() - // log the delta for debugging + // Log the delta for debugging mainLog.Load().Warn(). Interface("old_state", delta.Old). Interface("new_state", delta.New). Msg("Network change detected") - // Parse old and new interface states - oldIfs := parseInterfaceState(delta.Old) - newIfs := parseInterfaceState(delta.New) + // Parse old and new interface states, and extract IPs as well. + oldIfs, oldIPs := parseInterfaceState(delta.Old) + newIfs, newIPs := parseInterfaceState(delta.New) - // Check for changes in valid interfaces changed := false var changedIface, changedIfaceState string activeInterfaceExists := false + // Iterate over valid interfaces. for ifaceName := range validIfaces { - oldState, oldExists := oldIfs[strings.ToLower(ifaceName)] - newState, newExists := newIfs[strings.ToLower(ifaceName)] + lname := strings.ToLower(ifaceName) + oldState, oldExists := oldIfs[lname] + newState, newExists := newIfs[lname] + // Check if the interface appears active in the new state. if newState != "" && !strings.Contains(newState, "down") { activeInterfaceExists = true } - // Compare states directly - if oldExists != newExists || oldState != newState { + // Compare raw state strings... + stateChanged := (oldExists != newExists || oldState != newState) - // If the interface is up, we need to reinitialize the OS resolver + // ... and also compare the parsed IP slices. + ipChanged := false + oldIPSlice, okOld := oldIPs[lname] + newIPSlice, okNew := newIPs[lname] + if okOld && okNew { + // Create copies and sort them so that order does not matter. + sortedOld := append([]string(nil), oldIPSlice...) + sortedNew := append([]string(nil), newIPSlice...) + sort.Strings(sortedOld) + sort.Strings(sortedNew) + if !slices.Equal(sortedOld, sortedNew) { + ipChanged = true + } + } else if okOld != okNew { + ipChanged = true + } + + // If either the state string or the IPs have changed... + if stateChanged || ipChanged { if newState != "" && !strings.Contains(newState, "down") { changed = true changedIface = ifaceName - changedIfaceState = newState + // Prefer newState if present; if not, generate one from the IP slice. + if newState == "" && okNew { + changedIfaceState = "[" + strings.Join(newIPSlice, " ") + "]" + } else { + changedIfaceState = newState + } } - mainLog.Load().Warn(). Str("interface", ifaceName). Str("old_state", oldState). Str("new_state", newState). - Msg("Valid interface changed state") + Msg("Valid interface changed state (IP change detected: " + strconv.FormatBool(ipChanged) + ")") break } else { mainLog.Load().Warn(). @@ -1367,19 +1393,19 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { return } - // Use the defaultRouteIP() result or fallback to the changed interface's IP from the delta. + // Use the defaultRouteIP() result, or fall back to the changed interface's IPv4 from the new state. selfIP := defaultRouteIP() if selfIP == "" && changedIface != "" { selfIP = extractIPv4FromState(changedIfaceState) - mainLog.Load().Info().Msgf("defaultRouteIP returned empty, using changed iface '%s' IP: %s", changedIface, selfIP) + mainLog.Load().Info().Msgf("defaultRouteIP returned empty, using changed iface '%s' IPv4: %s", changedIface, selfIP) } - // Extract IPv6 from the changed interface state. + // Extract IPv6 from the changed state. ipv6 := extractIPv6FromState(changedIfaceState) if ip := net.ParseIP(selfIP); ip != nil { ctrld.SetDefaultLocalIPv4(ip) - // if we have a new IP, set the client info to the new IP + // If we have a new IP, update the client info. if !isMobile() && p.ciTable != nil { p.ciTable.SetSelfIP(selfIP) } @@ -1396,52 +1422,87 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { return nil } -// parseInterfaceState parses the interface state string into a map of interface name -> state -func parseInterfaceState(state *netmon.State) map[string]string { +// parseInterfaceState parses the netmon state into two maps: +// 1. stateMap: a mapping from interfaces (lowercase) to their original state string, +// formatted in square brackets (e.g. "[192.168.1.200/24 fe80::69f6:e16e:8bdb:0000/64]"). +// 2. ipMap: a mapping from interfaces (lowercase) to a slice of IP addresses extracted from that state. +// +// It first attempts JSON parsing to pull out both the "Interface" and "InterfaceIPs" fields. +// If JSON parsing fails, it falls back to the legacy parsing logic. +func parseInterfaceState(state *netmon.State) (map[string]string, map[string][]string) { + result := make(map[string]string) // Interface name -> state string. + ipMap := make(map[string][]string) // Interface name -> slice of IP addresses. + if state == nil { - return nil + return result, ipMap } - - result := make(map[string]string) - stateStr := state.String() - // Extract interface information - ifsStart := strings.Index(stateStr, "ifs={") - if ifsStart == -1 { - return result + // Attempt to parse the state string as JSON so we can extract both "Interface" and "InterfaceIPs". + var raw map[string]json.RawMessage + if err := json.Unmarshal([]byte(stateStr), &raw); err == nil { + var interfaces map[string]interface{} + var interfaceIPs map[string][]string + + if v, ok := raw["Interface"]; ok { + _ = json.Unmarshal(v, &interfaces) + } + if v, ok := raw["InterfaceIPs"]; ok { + _ = json.Unmarshal(v, &interfaceIPs) + } + // For every interface in the "Interface" section, check for its IPs. + for name := range interfaces { + lowerName := strings.ToLower(name) + if ips, ok := interfaceIPs[name]; ok && len(ips) > 0 { + result[lowerName] = "[" + strings.Join(ips, " ") + "]" + ipMap[lowerName] = ips + } else { + result[lowerName] = "[]" + ipMap[lowerName] = []string{} + } + } + return result, ipMap } + // Fallback: try parsing the legacy "ifs={...}" section from the state string. + ifsStart := strings.Index(stateStr, "ifs={") + if ifsStart == -1 { + return result, ipMap + } ifsStr := stateStr[ifsStart+5:] ifsEnd := strings.Index(ifsStr, "}") if ifsEnd == -1 { - return result + return result, ipMap } - - // Get the content between ifs={ } ifsContent := strings.TrimSpace(ifsStr[:ifsEnd]) - - // Split on "] " to get each interface entry entries := strings.Split(ifsContent, "] ") - for _, entry := range entries { if entry == "" { continue } - - // Split on ":[" parts := strings.Split(entry, ":[") if len(parts) != 2 { continue } - name := strings.TrimSpace(parts[0]) - state := "[" + strings.TrimSuffix(parts[1], "]") + "]" + stateEntry := "[" + strings.TrimSuffix(parts[1], "]") + "]" + lowerName := strings.ToLower(name) + result[lowerName] = stateEntry - result[strings.ToLower(name)] = state + // Attempt to extract IP addresses from stateEntry. + ipList := []string{} + trimmed := strings.Trim(stateEntry, "[]") + fields := strings.Fields(trimmed) + for _, f := range fields { + // We assume the IP is the part before the "/", if present. + candidate := strings.Split(f, "/")[0] + if ip := net.ParseIP(candidate); ip != nil { + ipList = append(ipList, candidate) + } + } + ipMap[lowerName] = ipList } - - return result + return result, ipMap } // extractIPv4FromState extracts an IPv4 address from an interface state string. From 41a97a6609e7c2c1e4eb86108e29db397609ba6a Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 15:00:46 -0500 Subject: [PATCH 071/100] clean up network change state logic --- cmd/cli/dns_proxy.go | 228 +++++++++++-------------------------------- 1 file changed, 58 insertions(+), 170 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index f9d6478..9f721a2 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "encoding/hex" - "encoding/json" "errors" "fmt" "net" @@ -12,7 +11,6 @@ import ( "os/exec" "runtime" "slices" - "sort" "strconv" "strings" "sync" @@ -1311,101 +1309,76 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Get map of valid interfaces validIfaces := validInterfacesMap() - // Log the delta for debugging - mainLog.Load().Warn(). + mainLog.Load().Debug(). Interface("old_state", delta.Old). Interface("new_state", delta.New). Msg("Network change detected") - // Parse old and new interface states, and extract IPs as well. - oldIfs, oldIPs := parseInterfaceState(delta.Old) - newIfs, newIPs := parseInterfaceState(delta.New) - changed := false - var changedIface, changedIfaceState string activeInterfaceExists := false - // Iterate over valid interfaces. + // Check each valid interface for changes for ifaceName := range validIfaces { - lname := strings.ToLower(ifaceName) - oldState, oldExists := oldIfs[lname] - newState, newExists := newIfs[lname] + oldIface := delta.Old.Interface[ifaceName] + newIface, exists := delta.New.Interface[ifaceName] + if !exists { + continue + } + oldIPs := delta.Old.InterfaceIPs[ifaceName] + newIPs := delta.New.InterfaceIPs[ifaceName] - // Check if the interface appears active in the new state. - if newState != "" && !strings.Contains(newState, "down") { + // Check if interface is up and has IPs + if newIface.IsUp() && len(newIPs) > 0 { activeInterfaceExists = true } - // Compare raw state strings... - stateChanged := (oldExists != newExists || oldState != newState) - - // ... and also compare the parsed IP slices. - ipChanged := false - oldIPSlice, okOld := oldIPs[lname] - newIPSlice, okNew := newIPs[lname] - if okOld && okNew { - // Create copies and sort them so that order does not matter. - sortedOld := append([]string(nil), oldIPSlice...) - sortedNew := append([]string(nil), newIPSlice...) - sort.Strings(sortedOld) - sort.Strings(sortedNew) - if !slices.Equal(sortedOld, sortedNew) { - ipChanged = true - } - } else if okOld != okNew { - ipChanged = true - } - - // If either the state string or the IPs have changed... - if stateChanged || ipChanged { - if newState != "" && !strings.Contains(newState, "down") { + // Compare interface states and IPs + if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) { + if newIface.IsUp() && len(newIPs) > 0 { changed = true - changedIface = ifaceName - // Prefer newState if present; if not, generate one from the IP slice. - if newState == "" && okNew { - changedIfaceState = "[" + strings.Join(newIPSlice, " ") + "]" - } else { - changedIfaceState = newState - } + mainLog.Load().Debug(). + Str("interface", ifaceName). + Interface("old_ips", oldIPs). + Interface("new_ips", newIPs). + Msg("Interface state or IPs changed") + break } - mainLog.Load().Warn(). - Str("interface", ifaceName). - Str("old_state", oldState). - Str("new_state", newState). - Msg("Valid interface changed state (IP change detected: " + strconv.FormatBool(ipChanged) + ")") - break - } else { - mainLog.Load().Warn(). - Str("interface", ifaceName). - Str("old_state", oldState). - Str("new_state", newState). - Msg("Valid interface unchanged") } } if !changed { - mainLog.Load().Warn().Msg("Ignoring interface change - no valid interfaces affected") + mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected") return } if !activeInterfaceExists { - mainLog.Load().Warn().Msg("No active interfaces found, skipping reinitialization") + mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization") return } - // Use the defaultRouteIP() result, or fall back to the changed interface's IPv4 from the new state. + // Get IPs from default route interface in new state selfIP := defaultRouteIP() - if selfIP == "" && changedIface != "" { - selfIP = extractIPv4FromState(changedIfaceState) - mainLog.Load().Info().Msgf("defaultRouteIP returned empty, using changed iface '%s' IPv4: %s", changedIface, selfIP) + var ipv6 string + + if delta.New.DefaultRouteInterface != "" { + for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] { + addr := ip.Addr() + if addr.Is4() && selfIP == "" && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + selfIP = addr.String() + } + if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + ipv6 = addr.String() + } + } } - // Extract IPv6 from the changed state. - ipv6 := extractIPv6FromState(changedIfaceState) + if selfIP == "" { + mainLog.Load().Debug().Msg("No valid IPv4 found on default route interface") + return + } if ip := net.ParseIP(selfIP); ip != nil { ctrld.SetDefaultLocalIPv4(ip) - // If we have a new IP, update the client info. if !isMobile() && p.ciTable != nil { p.ciTable.SetSelfIP(selfIP) } @@ -1422,115 +1395,30 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { return nil } -// parseInterfaceState parses the netmon state into two maps: -// 1. stateMap: a mapping from interfaces (lowercase) to their original state string, -// formatted in square brackets (e.g. "[192.168.1.200/24 fe80::69f6:e16e:8bdb:0000/64]"). -// 2. ipMap: a mapping from interfaces (lowercase) to a slice of IP addresses extracted from that state. -// -// It first attempts JSON parsing to pull out both the "Interface" and "InterfaceIPs" fields. -// If JSON parsing fails, it falls back to the legacy parsing logic. -func parseInterfaceState(state *netmon.State) (map[string]string, map[string][]string) { - result := make(map[string]string) // Interface name -> state string. - ipMap := make(map[string][]string) // Interface name -> slice of IP addresses. - - if state == nil { - return result, ipMap +// interfaceStatesEqual compares two interface states +func interfaceStatesEqual(a, b *netmon.Interface) bool { + if a == nil || b == nil { + return a == b } - stateStr := state.String() - - // Attempt to parse the state string as JSON so we can extract both "Interface" and "InterfaceIPs". - var raw map[string]json.RawMessage - if err := json.Unmarshal([]byte(stateStr), &raw); err == nil { - var interfaces map[string]interface{} - var interfaceIPs map[string][]string - - if v, ok := raw["Interface"]; ok { - _ = json.Unmarshal(v, &interfaces) - } - if v, ok := raw["InterfaceIPs"]; ok { - _ = json.Unmarshal(v, &interfaceIPs) - } - // For every interface in the "Interface" section, check for its IPs. - for name := range interfaces { - lowerName := strings.ToLower(name) - if ips, ok := interfaceIPs[name]; ok && len(ips) > 0 { - result[lowerName] = "[" + strings.Join(ips, " ") + "]" - ipMap[lowerName] = ips - } else { - result[lowerName] = "[]" - ipMap[lowerName] = []string{} - } - } - return result, ipMap - } - - // Fallback: try parsing the legacy "ifs={...}" section from the state string. - ifsStart := strings.Index(stateStr, "ifs={") - if ifsStart == -1 { - return result, ipMap - } - ifsStr := stateStr[ifsStart+5:] - ifsEnd := strings.Index(ifsStr, "}") - if ifsEnd == -1 { - return result, ipMap - } - ifsContent := strings.TrimSpace(ifsStr[:ifsEnd]) - entries := strings.Split(ifsContent, "] ") - for _, entry := range entries { - if entry == "" { - continue - } - parts := strings.Split(entry, ":[") - if len(parts) != 2 { - continue - } - name := strings.TrimSpace(parts[0]) - stateEntry := "[" + strings.TrimSuffix(parts[1], "]") + "]" - lowerName := strings.ToLower(name) - result[lowerName] = stateEntry - - // Attempt to extract IP addresses from stateEntry. - ipList := []string{} - trimmed := strings.Trim(stateEntry, "[]") - fields := strings.Fields(trimmed) - for _, f := range fields { - // We assume the IP is the part before the "/", if present. - candidate := strings.Split(f, "/")[0] - if ip := net.ParseIP(candidate); ip != nil { - ipList = append(ipList, candidate) - } - } - ipMap[lowerName] = ipList - } - return result, ipMap + return a.IsUp() == b.IsUp() } -// extractIPv4FromState extracts an IPv4 address from an interface state string. -// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239". -// If no valid IP can be found, it returns an empty string. -func extractIPv4FromState(state string) string { - trimmed := strings.Trim(state, "[]") - parts := strings.Fields(trimmed) - for _, part := range parts { - ipPart := strings.Split(part, "/")[0] - if ip := net.ParseIP(ipPart); ip != nil && ip.To4() != nil { - return ipPart - } +// interfaceIPsEqual compares two slices of IP prefixes +func interfaceIPsEqual(a, b []netip.Prefix) bool { + if len(a) != len(b) { + return false } - return "" -} -// extractIPv6FromState extracts an IPv6 address from an interface state string. -// For example, given "[172.16.226.239/22 llu6]", it returns "172.16.226.239". -// If no valid IP can be found, it returns an empty string. -func extractIPv6FromState(state string) string { - trimmed := strings.Trim(state, "[]") - parts := strings.Fields(trimmed) - for _, part := range parts { - ipPart := strings.Split(part, "/")[0] - if ip := net.ParseIP(ipPart); ip != nil && ip.To4() == nil { - return ipPart + // Create maps for easier comparison + aMap := make(map[netip.Prefix]bool) + for _, ip := range a { + aMap[ip] = true + } + + for _, ip := range b { + if !aMap[ip] { + return false } } - return "" + return true } From 72f0b89fdc4174fbce9707743cf837d40b30d1d8 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 15:03:23 -0500 Subject: [PATCH 072/100] remove redundant return --- cmd/cli/dns_proxy.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 9f721a2..a91c417 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1372,11 +1372,6 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { } } - if selfIP == "" { - mainLog.Load().Debug().Msg("No valid IPv4 found on default route interface") - return - } - if ip := net.ParseIP(selfIP); ip != nil { ctrld.SetDefaultLocalIPv4(ip) if !isMobile() && p.ciTable != nil { From 6644ce53f2d0aaade7288c299f2ef797c34ac417 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 15:22:58 -0500 Subject: [PATCH 073/100] fix interface IP CIDR parsing --- cmd/cli/dns_proxy.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index a91c417..779b29a 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1362,7 +1362,9 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { if delta.New.DefaultRouteInterface != "" { for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] { - addr := ip.Addr() + // Parse the CIDR notation to get just the IP + ipAddr, _ := netip.ParsePrefix(ip.String()) + addr := ipAddr.Addr() if addr.Is4() && selfIP == "" && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { selfIP = addr.String() } From 4a05fb6b289ce6115ecc912560cc60b3b8f168ea Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 15:30:54 -0500 Subject: [PATCH 074/100] use the changed iface if no default route is set yet --- cmd/cli/dns_proxy.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 779b29a..b00a355 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1316,7 +1316,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { changed := false activeInterfaceExists := false - + changeIPs := []netip.Prefix{} // Check each valid interface for changes for ifaceName := range validIfaces { oldIface := delta.Old.Interface[ifaceName] @@ -1336,6 +1336,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) { if newIface.IsUp() && len(newIPs) > 0 { changed = true + changeIPs = newIPs mainLog.Load().Debug(). Str("interface", ifaceName). Interface("old_ips", oldIPs). @@ -1372,6 +1373,18 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { ipv6 = addr.String() } } + } else { + // If no default route interface is set yet, use the changed IPs + for _, ip := range changeIPs { + ipAddr, _ := netip.ParsePrefix(ip.String()) + addr := ipAddr.Addr() + if addr.Is4() && selfIP == "" && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + selfIP = addr.String() + } + if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + ipv6 = addr.String() + } + } } if ip := net.ParseIP(selfIP); ip != nil { From fef85cadebb05979308953acd2ef46c1140456f5 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 16:03:52 -0500 Subject: [PATCH 075/100] filter non usabel IPs from state changes --- cmd/cli/dns_proxy.go | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index b00a355..ad65209 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1419,14 +1419,39 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool { return false } - // Create maps for easier comparison - aMap := make(map[netip.Prefix]bool) - for _, ip := range a { - aMap[ip] = true + isUsableIP := func(ip netip.Prefix) bool { + addr := ip.Addr() + return !addr.IsLinkLocalUnicast() && // fe80::/10 + !addr.IsLoopback() && // 127.0.0.1/8, ::1 + !addr.IsMulticast() && // 224.0.0.0/4, ff00::/8 + !addr.IsUnspecified() && // 0.0.0.0, :: + !addr.IsLinkLocalMulticast() && // 224.0.0.0/24 + !(addr.Is4() && addr.String() == "255.255.255.255") && // broadcast + !tsaddr.CGNATRange().Contains(addr) // 100.64.0.0/10 CGNAT } + // Filter and create maps for comparison + aMap := make(map[netip.Prefix]bool) + for _, ip := range a { + if isUsableIP(ip) { + aMap[ip] = true + } + } + + bMap := make(map[netip.Prefix]bool) for _, ip := range b { - if !aMap[ip] { + if isUsableIP(ip) { + bMap[ip] = true + } + } + + // Compare the filtered IP sets + if len(aMap) != len(bMap) { + return false + } + + for ip := range aMap { + if !bMap[ip] { return false } } From 917052723df132deea0720ddbe5879b6d687c457 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 16:12:10 -0500 Subject: [PATCH 076/100] don't overwrite OS resolver nameservers if there arent any --- cmd/cli/dns_proxy.go | 8 ++++++-- cmd/cli/prog.go | 2 +- resolver.go | 9 +++++++-- resolver_test.go | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index ad65209..5aa8e52 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1249,8 +1249,12 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { mainLog.Load().Debug().Msg("DNS reset completed") mainLog.Load().Debug().Msg("initializing OS resolver") - ns := ctrld.InitializeOsResolver() - mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) + ns := ctrld.InitializeOsResolver(true) + if len(ns) == 0 { + mainLog.Load().Warn().Msgf("no nameservers found, using existing OS resolver values") + } else { + mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) + } // start leaking queries immediately// start leaking queries immediately if networkChange { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index b26b814..d7a9a95 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -268,7 +268,7 @@ func (p *prog) preRun() { func (p *prog) postRun() { if !service.Interactive() { p.resetDNS() - ns := ctrld.InitializeOsResolver() + ns := ctrld.InitializeOsResolver(false) mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() p.csSetDnsDone <- struct{}{} diff --git a/resolver.go b/resolver.go index 650fd05..19ebc1f 100644 --- a/resolver.go +++ b/resolver.go @@ -130,8 +130,13 @@ func availableNameservers() []string { // // It's the caller's responsibility to ensure the system DNS is in a clean state before // calling this function. -func InitializeOsResolver() []string { - ns := initializeOsResolver(availableNameservers()) +func InitializeOsResolver(guardAgainstNoNameservers bool) []string { + nameservers := availableNameservers() + // if no nameservers, return empty slice so we dont remove all nameservers + if len(nameservers) == 0 && guardAgainstNoNameservers { + return []string{} + } + ns := initializeOsResolver(nameservers) resolverMutex.Lock() defer resolverMutex.Unlock() or = newResolverWithNameserver(ns) diff --git a/resolver_test.go b/resolver_test.go index de8cca0..e96e875 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -115,7 +115,7 @@ func Test_osResolver_InitializationRace(t *testing.T) { for range n { go func() { defer wg.Done() - InitializeOsResolver() + InitializeOsResolver(false) }() } wg.Wait() From bb2210b06ae236c36f95517a31ee1e265d6316ac Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 16:28:38 -0500 Subject: [PATCH 077/100] ip detection debugging --- cmd/cli/dns_proxy.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 5aa8e52..492e79e 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1366,12 +1366,15 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { var ipv6 string if delta.New.DefaultRouteInterface != "" { + mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface]) for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] { - // Parse the CIDR notation to get just the IP ipAddr, _ := netip.ParsePrefix(ip.String()) addr := ipAddr.Addr() - if addr.Is4() && selfIP == "" && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { - selfIP = addr.String() + if selfIP == "" && addr.Is4() { + mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + selfIP = addr.String() + } } if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { ipv6 = addr.String() @@ -1379,11 +1382,15 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { } } else { // If no default route interface is set yet, use the changed IPs + mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs) for _, ip := range changeIPs { ipAddr, _ := netip.ParsePrefix(ip.String()) addr := ipAddr.Addr() - if addr.Is4() && selfIP == "" && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { - selfIP = addr.String() + if selfIP == "" && addr.Is4() { + mainLog.Load().Debug().Msgf("checking IP: %s", addr.String()) + if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { + selfIP = addr.String() + } } if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() { ipv6 = addr.String() From 9618efbcde7b7c817301b8de2b44b90e4fda4cfc Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Feb 2025 18:18:07 -0500 Subject: [PATCH 078/100] improve network change ip filtering logic --- cmd/cli/dns_proxy.go | 81 ++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 40 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 492e79e..d5b77c7 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1313,9 +1313,12 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { // Get map of valid interfaces validIfaces := validInterfacesMap() + isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New) + mainLog.Load().Debug(). Interface("old_state", delta.Old). Interface("new_state", delta.New). + Bool("is_major_change", isMajorChange). Msg("Network change detected") changed := false @@ -1331,20 +1334,23 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { oldIPs := delta.Old.InterfaceIPs[ifaceName] newIPs := delta.New.InterfaceIPs[ifaceName] - // Check if interface is up and has IPs - if newIface.IsUp() && len(newIPs) > 0 { + // Filter new IPs to only those that are usable. + usableNewIPs := filterUsableIPs(newIPs) + + // Check if interface is up and has usable IPs. + if newIface.IsUp() && len(usableNewIPs) > 0 { activeInterfaceExists = true } - // Compare interface states and IPs + // Compare interface states and IPs (interfaceIPsEqual will itself filter the IPs). if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) { - if newIface.IsUp() && len(newIPs) > 0 { + if newIface.IsUp() && len(usableNewIPs) > 0 { changed = true - changeIPs = newIPs + changeIPs = usableNewIPs mainLog.Load().Debug(). Str("interface", ifaceName). Interface("old_ips", oldIPs). - Interface("new_ips", newIPs). + Interface("new_ips", usableNewIPs). Msg("Interface state or IPs changed") break } @@ -1424,45 +1430,40 @@ func interfaceStatesEqual(a, b *netmon.Interface) bool { return a.IsUp() == b.IsUp() } -// interfaceIPsEqual compares two slices of IP prefixes +// filterUsableIPs is a helper that returns only "usable" IP prefixes, +// filtering out link-local, loopback, multicast, unspecified, broadcast, or CGNAT addresses. +func filterUsableIPs(prefixes []netip.Prefix) []netip.Prefix { + var usable []netip.Prefix + for _, p := range prefixes { + addr := p.Addr() + if addr.IsLinkLocalUnicast() || + addr.IsLoopback() || + addr.IsMulticast() || + addr.IsUnspecified() || + addr.IsLinkLocalMulticast() || + (addr.Is4() && addr.String() == "255.255.255.255") || + tsaddr.CGNATRange().Contains(addr) { + continue + } + usable = append(usable, p) + } + return usable +} + +// Modified interfaceIPsEqual compares only the usable (non-link local, non-loopback, etc.) IP addresses. func interfaceIPsEqual(a, b []netip.Prefix) bool { - if len(a) != len(b) { + aUsable := filterUsableIPs(a) + bUsable := filterUsableIPs(b) + if len(aUsable) != len(bUsable) { return false } - isUsableIP := func(ip netip.Prefix) bool { - addr := ip.Addr() - return !addr.IsLinkLocalUnicast() && // fe80::/10 - !addr.IsLoopback() && // 127.0.0.1/8, ::1 - !addr.IsMulticast() && // 224.0.0.0/4, ff00::/8 - !addr.IsUnspecified() && // 0.0.0.0, :: - !addr.IsLinkLocalMulticast() && // 224.0.0.0/24 - !(addr.Is4() && addr.String() == "255.255.255.255") && // broadcast - !tsaddr.CGNATRange().Contains(addr) // 100.64.0.0/10 CGNAT + aMap := make(map[string]bool) + for _, ip := range aUsable { + aMap[ip.String()] = true } - - // Filter and create maps for comparison - aMap := make(map[netip.Prefix]bool) - for _, ip := range a { - if isUsableIP(ip) { - aMap[ip] = true - } - } - - bMap := make(map[netip.Prefix]bool) - for _, ip := range b { - if isUsableIP(ip) { - bMap[ip] = true - } - } - - // Compare the filtered IP sets - if len(aMap) != len(bMap) { - return false - } - - for ip := range aMap { - if !bMap[ip] { + for _, ip := range bUsable { + if !aMap[ip.String()] { return false } } From fb49cb71e3a99ec103387e82c6968e4e7f84109b Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 00:09:03 -0500 Subject: [PATCH 079/100] debounce upstream failure checking and failure counts --- cmd/cli/dns_proxy.go | 35 +++++++++++++++++++++++++++-------- cmd/cli/upstream_monitor.go | 17 ++++++++++++++++- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index d5b77c7..f7bbe6e 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -559,7 +559,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { if isNetworkErr { p.um.increaseFailureCount(upstreams[n]) if p.um.isDown(upstreams[n]) { - go p.checkUpstream(upstreams[n], upstreamConfig) + p.um.mu.RLock() + if !p.um.checking[upstreams[n]] { + go p.checkUpstream(upstreams[n], upstreamConfig) + } + p.um.mu.RUnlock() } } // For timeout error (i.e: context deadline exceed), force re-bootstrapping. @@ -569,6 +573,12 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } return nil } + // if we have an answer, we should reset the failure count + if answer != nil { + p.um.mu.Lock() + p.um.failureReq[upstreams[n]] = 0 + p.um.mu.Unlock() + } return answer } for n, upstreamConfig := range upstreamConfigs { @@ -1029,15 +1039,21 @@ func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamCon upstreamCh := make(chan string, len(failedUpstreams)) for name, uc := range failedUpstreams { go func(name string, uc *ctrld.UpstreamConfig) { - mainLog.Load().Debug(). - Str("upstream", name). - Msg("checking upstream") - for { select { case <-ctx.Done(): return default: + // make sure this upstream is not already being checked + p.um.mu.RLock() + if p.um.checking[name] { + p.um.mu.RUnlock() + continue + } + mainLog.Load().Debug(). + Str("upstream", name). + Msg("checking upstream") + p.checkUpstream(name, uc) mainLog.Load().Debug(). Str("upstream", name). @@ -1256,12 +1272,15 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) } - // start leaking queries immediately// start leaking queries immediately + // start leaking queries immediately if networkChange { // set all upstreams to failed and provide to performLeakingQuery failedUpstreams := make(map[string]*ctrld.UpstreamConfig) - for _, upstream := range p.cfg.Upstream { - failedUpstreams[upstream.Name] = upstream + // Iterate over both key and upstream to ensure that we have a fallback key + for key, upstream := range p.cfg.Upstream { + mainLog.Load().Debug().Msgf("network change upstream checking: %v, key: %q", upstream, key) + mapKey := upstreamPrefix + key + failedUpstreams[mapKey] = upstream } go p.performLeakingQuery(failedUpstreams, "all") diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index fc5d65d..df52a14 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -21,10 +21,11 @@ const ( type upstreamMonitor struct { cfg *ctrld.Config - mu sync.Mutex + mu sync.RWMutex checking map[string]bool down map[string]bool failureReq map[string]uint64 + recovered map[string]bool } func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { @@ -33,6 +34,7 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { checking: make(map[string]bool), down: make(map[string]bool), failureReq: make(map[string]uint64), + recovered: make(map[string]bool), } for n := range cfg.Upstream { upstream := upstreamPrefix + n @@ -47,6 +49,11 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { um.mu.Lock() defer um.mu.Unlock() + if um.recovered[upstream] { + mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream) + return + } + um.failureReq[upstream] += 1 failedCount := um.failureReq[upstream] @@ -77,6 +84,14 @@ func (um *upstreamMonitor) reset(upstream string) { um.failureReq[upstream] = 0 um.down[upstream] = false + um.recovered[upstream] = true + go func() { + // debounce the recovery to avoid incrementing failure counts already in flight + time.Sleep(1 * time.Second) + um.mu.Lock() + um.recovered[upstream] = false + um.mu.Unlock() + }() } // checkUpstream checks the given upstream status, periodically sending query to upstream From 1d207379cb9eca98940ed8f24bf58ff1cb145ad7 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 00:59:47 -0500 Subject: [PATCH 080/100] wait for healthy upstream before accepting queries on network change --- cmd/cli/dns_proxy.go | 141 ++++++++++++++++++++++++++++++++++++++----- cmd/cli/prog.go | 3 + 2 files changed, 128 insertions(+), 16 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index f7bbe6e..cd5fb60 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1250,7 +1250,7 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M // and re-initializing the OS resolver with the nameservers // applying listener back to the interface func (p *prog) reinitializeOSResolver(networkChange bool) { - // Cancel any existing operations + // Cancel any existing operations. p.resetCtxMu.Lock() defer p.resetCtxMu.Unlock() @@ -1261,9 +1261,11 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { }() mainLog.Load().Debug().Msg("attempting to reset DNS") + // Remove the listener immediately. p.resetDNS() mainLog.Load().Debug().Msg("DNS reset completed") + // Initialize OS resolver regardless of upstream recovery. mainLog.Load().Debug().Msg("initializing OS resolver") ns := ctrld.InitializeOsResolver(true) if len(ns) == 0 { @@ -1272,18 +1274,38 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) } - // start leaking queries immediately if networkChange { - // set all upstreams to failed and provide to performLeakingQuery - failedUpstreams := make(map[string]*ctrld.UpstreamConfig) - // Iterate over both key and upstream to ensure that we have a fallback key - for key, upstream := range p.cfg.Upstream { - mainLog.Load().Debug().Msgf("network change upstream checking: %v, key: %q", upstream, key) - mapKey := upstreamPrefix + key - failedUpstreams[mapKey] = upstream + // If we're already waiting on a recovery from a previous network change, + // cancel that wait to avoid stale recovery. + p.recoveryCancelMu.Lock() + if p.recoveryCancel != nil { + mainLog.Load().Debug().Msg("Cancelling previous recovery wait due to new network change") + p.recoveryCancel() + p.recoveryCancel = nil } - go p.performLeakingQuery(failedUpstreams, "all") + // Create a new context (with a timeout) for this recovery wait. + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + p.recoveryCancel = cancel + p.recoveryCancelMu.Unlock() + // Launch a goroutine that monitors the non-OS upstreams. + go func() { + recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx) + if err != nil { + mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener") + return + } + mainLog.Load().Info().Msgf("Non-OS upstream %q recovered; reattaching DNS", recoveredUpstream) + p.setDNS() + p.logInterfacesState() + + // Clear the recovery cancel func as recovery has been achieved. + p.recoveryCancelMu.Lock() + p.recoveryCancel = nil + p.recoveryCancelMu.Unlock() + }() + + // Optionally flush DNS caches (if needed). if err := FlushDNSCache(); err != nil { mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache") } @@ -1291,13 +1313,11 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { // delay putting back the ctrld listener to allow for captive portal to trigger time.Sleep(5 * time.Second) } + } else { + // For non-network-change cases, immediately re-enable the listener. + p.setDNS() + p.logInterfacesState() } - - mainLog.Load().Debug().Msg("setting DNS configuration") - p.setDNS() - mainLog.Load().Debug().Msg("DNS configuration set successfully") - p.logInterfacesState() - } // FlushDNSCache flushes the DNS cache on macOS. @@ -1488,3 +1508,92 @@ func interfaceIPsEqual(a, b []netip.Prefix) bool { } return true } + +// checkUpstreamOnce sends a test query to the specified upstream. +// Returns nil if the upstream responds successfully. +func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error { + mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream) + + resolver, err := ctrld.NewResolver(uc) + if err != nil { + mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream) + return err + } + + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + + timeout := 1000 * time.Millisecond + if uc.Timeout > 0 { + timeout = time.Millisecond * time.Duration(uc.Timeout) + } + mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + uc.ReBootstrap() + mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream) + + start := time.Now() + _, err = resolver.Resolve(ctx, msg) + duration := time.Since(start) + + if err != nil { + mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration) + } else { + mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration) + } + return err +} + +// waitForNonOSResolverRecovery spawns a health check for each non-OS upstream +// and returns when the first one recovers. +func (p *prog) waitForNonOSResolverRecovery(ctx context.Context) (string, error) { + recoveredCh := make(chan string, 1) + var wg sync.WaitGroup + + // Loop over your upstream configuration; skip the OS resolver. + for k, uc := range p.cfg.Upstream { + if uc.Type == ctrld.ResolverTypeOS { + continue + } + + upstreamName := upstreamPrefix + k + mainLog.Load().Debug().Msgf("Launching recovery check for upstream: %s", upstreamName) + wg.Add(1) + go func(name string, uc *ctrld.UpstreamConfig) { + defer wg.Done() + for { + select { + case <-ctx.Done(): + mainLog.Load().Debug().Msgf("Context done for upstream %s; stopping recovery check", name) + return + default: + if err := p.checkUpstreamOnce(name, uc); err == nil { + mainLog.Load().Debug().Msgf("Upstream %s is healthy; signaling recovery", name) + select { + case recoveredCh <- name: + default: + } + return + } else { + mainLog.Load().Debug().Msgf("Upstream %s not healthy, retrying...", name) + } + time.Sleep(checkUpstreamBackoffSleep) + } + } + }(upstreamName, uc) + } + + var recovered string + select { + case recovered = <-recoveredCh: + mainLog.Load().Debug().Msgf("Received recovered upstream: %s", recovered) + case <-ctx.Done(): + return "", ctx.Err() + } + + wg.Wait() + return recovered, nil +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index d7a9a95..3dc9e1b 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -124,6 +124,9 @@ type prog struct { resetCtxMu sync.Mutex + recoveryCancelMu sync.Mutex + recoveryCancel context.CancelFunc + started chan struct{} onStartedDone chan struct{} onStarted []func() From 375844ff1a3d874bd3bb73619192f2ea99f0322b Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 01:50:24 -0500 Subject: [PATCH 081/100] remove handler log line --- cmd/cli/dns_proxy.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index cd5fb60..1185c6c 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -99,7 +99,6 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { } handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { - mainLog.Load().Debug().Msgf("serveDNS handler called") p.sema.acquire() defer p.sema.release() if len(m.Question) == 0 { From 4b05b6da7bb78dfb5f39b4be4bbf9d6819ba3d52 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 01:57:37 -0500 Subject: [PATCH 082/100] fix missing unlock --- cmd/cli/dns_proxy.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 1185c6c..c5e5fd2 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1049,6 +1049,7 @@ func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamCon p.um.mu.RUnlock() continue } + p.um.mu.RUnlock() mainLog.Load().Debug(). Str("upstream", name). Msg("checking upstream") From 0c74838740e089c92ba6f4c6991cbf667e8f4f5e Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 02:08:44 -0500 Subject: [PATCH 083/100] init os resolver after upstream recovers --- cmd/cli/dns_proxy.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index c5e5fd2..1b08c9d 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1265,15 +1265,6 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { p.resetDNS() mainLog.Load().Debug().Msg("DNS reset completed") - // Initialize OS resolver regardless of upstream recovery. - mainLog.Load().Debug().Msg("initializing OS resolver") - ns := ctrld.InitializeOsResolver(true) - if len(ns) == 0 { - mainLog.Load().Warn().Msgf("no nameservers found, using existing OS resolver values") - } else { - mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) - } - if networkChange { // If we're already waiting on a recovery from a previous network change, // cancel that wait to avoid stale recovery. @@ -1295,7 +1286,18 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener") return } - mainLog.Load().Info().Msgf("Non-OS upstream %q recovered; reattaching DNS", recoveredUpstream) + + mainLog.Load().Info().Msgf("Non-OS upstream %q recovered; initializing OS resolver and attaching DNS listener", recoveredUpstream) + + // Initialize OS resolver regardless of upstream recovery. + mainLog.Load().Debug().Msg("initializing OS resolver") + ns := ctrld.InitializeOsResolver(true) + if len(ns) == 0 { + mainLog.Load().Warn().Msgf("no nameservers found, using existing OS resolver values") + } else { + mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) + } + p.setDNS() p.logInterfacesState() From 715bcc4aa1c98094e258ad1498a1664473b6babf Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 7 Feb 2025 14:25:18 +0700 Subject: [PATCH 084/100] internal/clientinfo: make SetSelfIP to update new data So after network changes, the new data will be used instead of the stale old one. --- internal/clientinfo/client_info.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index e6bda79..35d5dbb 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -173,6 +173,8 @@ func (t *Table) SetSelfIP(ip string) { t.selfIPLock.Lock() defer t.selfIPLock.Unlock() t.selfIP = ip + t.dhcp.selfIP = t.selfIP + t.dhcp.addSelf() } func (t *Table) init() { From 7a23f82192c528e636bffc87dd14c39a301ec8de Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 02:53:19 -0500 Subject: [PATCH 085/100] set leakingQueryReset to prevent watchdogs from resetting dns --- cmd/cli/dns_proxy.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 1b08c9d..6b26a33 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1281,6 +1281,9 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { // Launch a goroutine that monitors the non-OS upstreams. go func() { + p.leakingQueryReset.Store(true) + defer p.leakingQueryReset.Store(false) + recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx) if err != nil { mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener") From e1301ade963b9e6b1388daf214f1a711af4d8ec1 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 02:55:42 -0500 Subject: [PATCH 086/100] remove context timeout --- cmd/cli/dns_proxy.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 6b26a33..0eab69b 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1274,8 +1274,8 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { p.recoveryCancel() p.recoveryCancel = nil } - // Create a new context (with a timeout) for this recovery wait. - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + + ctx, cancel := context.WithCancel(context.Background()) p.recoveryCancel = cancel p.recoveryCancelMu.Unlock() From 398f71fd007182070673d38a46d68ff6205a2437 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 03:22:31 -0500 Subject: [PATCH 087/100] fix leakingQueryReset usages --- cmd/cli/dns_proxy.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 0eab69b..d9f124e 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1255,10 +1255,6 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { defer p.resetCtxMu.Unlock() p.leakingQueryReset.Store(true) - defer func() { - time.Sleep(time.Second) - p.leakingQueryReset.Store(false) - }() mainLog.Load().Debug().Msg("attempting to reset DNS") // Remove the listener immediately. @@ -1281,9 +1277,6 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { // Launch a goroutine that monitors the non-OS upstreams. go func() { - p.leakingQueryReset.Store(true) - defer p.leakingQueryReset.Store(false) - recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx) if err != nil { mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener") @@ -1304,6 +1297,9 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { p.setDNS() p.logInterfacesState() + // allow watchers to reset changes + p.leakingQueryReset.Store(false) + // Clear the recovery cancel func as recovery has been achieved. p.recoveryCancelMu.Lock() p.recoveryCancel = nil @@ -1322,6 +1318,9 @@ func (p *prog) reinitializeOSResolver(networkChange bool) { // For non-network-change cases, immediately re-enable the listener. p.setDNS() p.logInterfacesState() + + // allow watchers to reset changes + p.leakingQueryReset.Store(false) } } From caf98b4dfe5db2eda7fa4ce12a12cc3ba56e8879 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 7 Feb 2025 15:30:03 +0700 Subject: [PATCH 088/100] cmd/cli: ignore log file config for interactive logging Otherwise, the interactive commands may clobber the existed log file of ctrld daemon, causing it stops writing log until restarted. --- cmd/cli/main.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 819797a..73a601d 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -108,8 +108,14 @@ func initLogging() []io.Writer { // initInteractiveLogging is like initLogging, but the ProxyLogger is discarded // to be used for all interactive commands. +// +// Current log file config will also be ignored. func initInteractiveLogging() { - initLogging() + old := cfg.Service.LogPath + cfg.Service.LogPath = "" + zerolog.TimeFieldFormat = time.RFC3339 + ".000" + initLoggingWithBackup(false) + cfg.Service.LogPath = old l := zerolog.New(io.Discard) ctrld.ProxyLogger.Store(&l) } From 253a57ca017a21e2d69456e9c495254a693b5232 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 7 Feb 2025 15:46:07 +0700 Subject: [PATCH 089/100] cmd/cli: make validating remote config non-fatal during restart Since we already have a config on disk, it's better to enforce what we have instead of fatal. --- cmd/cli/cli.go | 11 +++++++++-- cmd/cli/commands.go | 21 ++++++++++----------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 49adca3..07abf3c 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1807,10 +1807,17 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceRe } // doValidateCdRemoteConfig fetches and validates custom config for cdUID. -func doValidateCdRemoteConfig(cdUID string) { +func doValidateCdRemoteConfig(cdUID string, fatal bool) { rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) if err != nil { - mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) + logger := mainLog.Load().Fatal() + if !fatal { + logger = mainLog.Load().Warn() + } + logger.Err(err).Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) + if !fatal { + return + } } // validateCdRemoteConfig clobbers v, saving it here to restore later. oldV := v diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 43ec485..d340574 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -373,7 +373,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c } if cdUID != "" { - doValidateCdRemoteConfig(cdUID) + doValidateCdRemoteConfig(cdUID, true) } else if uid := cdUIDFromProvToken(); uid != "" { cdUID = uid mainLog.Load().Debug().Msg("using uid from provision token") @@ -698,7 +698,7 @@ func initRestartCmd() *cobra.Command { initInteractiveLogging() if cdMode { - doValidateCdRemoteConfig(cdUID) + doValidateCdRemoteConfig(cdUID, false) } if ir := runningIface(s); ir != nil { @@ -751,17 +751,16 @@ func initRestartCmd() *cobra.Command { } if doRestart() { - dir, err := socketDir() - if err != nil { + if dir, err := socketDir(); err == nil { + cc := newSocketControlClient(context.TODO(), s, dir) + if cc == nil { + mainLog.Load().Error().Msg("Could not complete service restart") + os.Exit(1) + } + _, _ = cc.post(ifacePath, nil) + } else { mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") - return } - cc := newSocketControlClient(context.TODO(), s, dir) - if cc == nil { - mainLog.Load().Error().Msg("Could not complete service restart") - os.Exit(1) - } - _, _ = cc.post(ifacePath, nil) mainLog.Load().Notice().Msg("Service restarted") } else { mainLog.Load().Error().Msg("Service restart failed") From af4b826b680cb128b9366901b4b1c67fb7a8c68e Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 7 Feb 2025 19:09:20 +0700 Subject: [PATCH 090/100] cmd/cli: implement valid interfaces map for all systems Previously, a valid interfaces map is only meaningful on Windows and Darwin, where ctrld needs to set DNS for all physical interfaces. With new network monitor, the valid interfaces is used for checking new changes, thus we have to implement the valid interfaces map for all systems. - On Linux, just retrieving all non-virtual interfaces. - On others, fallback to use default route interface only. --- cmd/cli/dns_proxy.go | 2 +- cmd/cli/net_linux.go | 52 +++++++++++++++++++++++++++++++++++++++++++ cmd/cli/net_others.go | 17 +++++++++++--- 3 files changed, 67 insertions(+), 4 deletions(-) create mode 100644 cmd/cli/net_linux.go diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index d9f124e..0447eef 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1366,7 +1366,7 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { changed := false activeInterfaceExists := false - changeIPs := []netip.Prefix{} + var changeIPs []netip.Prefix // Check each valid interface for changes for ifaceName := range validIfaces { oldIface := delta.Old.Interface[ifaceName] diff --git a/cmd/cli/net_linux.go b/cmd/cli/net_linux.go new file mode 100644 index 0000000..ea17d3d --- /dev/null +++ b/cmd/cli/net_linux.go @@ -0,0 +1,52 @@ +package cli + +import ( + "net" + "net/netip" + "os" + "strings" + + "tailscale.com/net/netmon" +) + +func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } + +// validInterface reports whether the *net.Interface is a valid one. +// Only non-virtual interfaces are considered valid. +func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { + _, ok := validIfacesMap[iface.Name] + return ok +} + +// validInterfacesMap returns a set containing non virtual interfaces. +func validInterfacesMap() map[string]struct{} { + m := make(map[string]struct{}) + vis := virtualInterfaces() + netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) { + if _, existed := vis[i.Name]; existed { + return + } + m[i.Name] = struct{}{} + }) + // Fallback to default route interface if found nothing. + if len(m) == 0 { + defaultRoute, err := netmon.DefaultRoute() + if err != nil { + return m + } + m[defaultRoute.InterfaceName] = struct{}{} + } + return m +} + +// virtualInterfaces returns a map of virtual interfaces on current machine. +func virtualInterfaces() map[string]struct{} { + s := make(map[string]struct{}) + entries, _ := os.ReadDir("/sys/devices/virtual/net") + for _, entry := range entries { + if entry.IsDir() { + s[strings.TrimSpace(entry.Name())] = struct{}{} + } + } + return s +} diff --git a/cmd/cli/net_others.go b/cmd/cli/net_others.go index edd89ec..f347278 100644 --- a/cmd/cli/net_others.go +++ b/cmd/cli/net_others.go @@ -1,11 +1,22 @@ -//go:build !darwin && !windows +//go:build !darwin && !windows && !linux package cli -import "net" +import ( + "net" + + "tailscale.com/net/netmon" +) func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true } -func validInterfacesMap() map[string]struct{} { return nil } +// validInterfacesMap returns a set containing only default route interfaces. +func validInterfacesMap() map[string]struct{} { + defaultRoute, err := netmon.DefaultRoute() + if err != nil { + return nil + } + return map[string]struct{}{defaultRoute.InterfaceName: {}} +} From 98042d8dbd7a482f6d9c0763f6442fb254f985fb Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 15:25:19 -0500 Subject: [PATCH 091/100] remove leaking logic in favor of recovery logic. --- cmd/cli/dns_proxy.go | 370 ++++++++++++++---------------------- cmd/cli/prog.go | 35 ++-- cmd/cli/resolvconf.go | 2 +- cmd/cli/upstream_monitor.go | 13 ++ 4 files changed, 174 insertions(+), 246 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 0447eef..31e8aa8 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -432,23 +432,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) - upstreamMapKey := strings.Join(upstreams, "_") - - leaked := false - if len(upstreamConfigs) > 0 { - p.leakingQueryMu.Lock() - if p.leakingQueryRunning[upstreamMapKey] || p.leakingQueryRunning["all"] { - upstreamConfigs = nil - leaked = true - if p.leakingQueryRunning["all"] { - ctrld.Log(ctx, mainLog.Load().Debug(), "all upstreams marked down for network change, leaking query to OS resolver") - } else { - ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams) - } - } - p.leakingQueryMu.Unlock() - } - if len(upstreamConfigs) == 0 { upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} upstreams = []string{upstreamOS} @@ -472,11 +455,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // 4. Try remote upstream. isLanOrPtrQuery := false if req.ufr.matched { - if leaked { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v (leaked)", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) - } else { - ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) - } + ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) } else { switch { case isSrvLookup(req.msg): @@ -557,13 +536,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { isNetworkErr := errNetworkError(err) if isNetworkErr { p.um.increaseFailureCount(upstreams[n]) - if p.um.isDown(upstreams[n]) { - p.um.mu.RLock() - if !p.um.checking[upstreams[n]] { - go p.checkUpstream(upstreams[n], upstreamConfig) - } - p.um.mu.RUnlock() - } } // For timeout error (i.e: context deadline exceed), force re-bootstrapping. var e net.Error @@ -594,16 +566,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.Log(ctx, logger, "DNS loop detected") continue } - if p.um.isDown(upstreams[n]) { - // never skip the OS resolver, since we usually query this resolver when we - // have no other upstreams to query - if upstreams[n] != upstreamOS { - logger. - Bool("is_os_resolver", upstreams[n] == upstreamOS) - ctrld.Log(ctx, logger, "Upstream is down") - continue - } - } answer := resolve(n, upstreamConfig, req.msg) if answer == nil { if serveStaleCache && staleAnswer != nil { @@ -651,20 +613,29 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return res } ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) - if p.leakOnUpstreamFailure() { - p.leakingQueryMu.Lock() - // get the map key as concact of upstreams - if !p.leakingQueryRunning[upstreamMapKey] { - p.leakingQueryRunning[upstreamMapKey] = true - // get a map of the failed upstreams - failedUpstreams := make(map[string]*ctrld.UpstreamConfig) - for n, upstream := range upstreamConfigs { - failedUpstreams[upstreams[n]] = upstream + + // if we have no healthy upstreams, trigger recovery flow + if p.recoverOnUpstreamFailure() { + if p.um.countHealthy(upstreams) == 0 { + p.recoveryCancelMu.Lock() + if p.recoveryCancel == nil { + var reason RecoveryReason + if upstreams[0] == upstreamOS { + reason = RecoveryReasonOSFailure + } else { + reason = RecoveryReasonRegularFailure + } + mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason) + go p.handleRecovery(reason) + } else { + mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection") } - go p.performLeakingQuery(failedUpstreams, upstreamMapKey) + p.recoveryCancelMu.Unlock() + } else { + mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger") } - p.leakingQueryMu.Unlock() } + answer := new(dns.Msg) answer.SetRcode(req.msg, dns.RcodeServerFailure) res.answer = answer @@ -994,86 +965,6 @@ func (p *prog) selfUninstallCoolOfPeriod() { p.selfUninstallMu.Unlock() } -// performLeakingQuery performs necessary works to leak queries to OS resolver. -// once we store the leakingQuery flag, we are leaking queries to OS resolver -// we then start testing all the upstreams forever, waiting for success, but in parallel -func (p *prog) performLeakingQuery(failedUpstreams map[string]*ctrld.UpstreamConfig, upstreamMapKey string) { - - mainLog.Load().Warn().Msgf("leaking queries for failed upstreams [%v] to OS resolver", failedUpstreams) - - // Signal dns watchers to stop, so changes made below won't be reverted. - p.leakingQueryMu.Lock() - p.leakingQueryRunning[upstreamMapKey] = true - p.leakingQueryMu.Unlock() - defer func() { - p.leakingQueryMu.Lock() - p.leakingQueryRunning[upstreamMapKey] = false - p.leakingQueryMu.Unlock() - mainLog.Load().Warn().Msg("stop leaking query") - }() - - // we only want to reset DNS when our resolver is broken - // this allows us to find the new OS resolver nameservers - // we skip the all upstream lock key to prevent duplicate calls - if p.um.isDown(upstreamOS) && upstreamMapKey != "all" { - - mainLog.Load().Debug().Msg("OS resolver is down, reinitializing") - p.reinitializeOSResolver(false) - - } - - // Test all failed upstreams in parallel - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // if a network change, delay upstream checks by 1s - // this is to ensure we actually leak queries to OS resolver - // We have observed some captive portals leak queries to public upstreams - // This can cause the captive portal on MacOS to not trigger a popup - if upstreamMapKey != "all" { - mainLog.Load().Debug().Msg("network change leaking queries, delaying upstream checks by 1s") - time.Sleep(1 * time.Second) - } - - upstreamCh := make(chan string, len(failedUpstreams)) - for name, uc := range failedUpstreams { - go func(name string, uc *ctrld.UpstreamConfig) { - for { - select { - case <-ctx.Done(): - return - default: - // make sure this upstream is not already being checked - p.um.mu.RLock() - if p.um.checking[name] { - p.um.mu.RUnlock() - continue - } - p.um.mu.RUnlock() - mainLog.Load().Debug(). - Str("upstream", name). - Msg("checking upstream") - - p.checkUpstream(name, uc) - mainLog.Load().Debug(). - Str("upstream", name). - Msg("upstream recovered") - upstreamCh <- name - return - } - } - }(name, uc) - } - - // Wait for any upstream to recover - name := <-upstreamCh - - mainLog.Load().Info(). - Str("upstream", name). - Msg("stopping leak as upstream recovered") - -} - // forceFetchingAPI sends signal to force syncing API config if run in cd mode, // and the domain == "cdUID.verify.controld.com" func (p *prog) forceFetchingAPI(domain string) { @@ -1245,85 +1136,6 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M return answer } -// reinitializeOSResolver reinitializes the OS resolver -// by removing ctrld listenr from the interface, collecting the network nameservers -// and re-initializing the OS resolver with the nameservers -// applying listener back to the interface -func (p *prog) reinitializeOSResolver(networkChange bool) { - // Cancel any existing operations. - p.resetCtxMu.Lock() - defer p.resetCtxMu.Unlock() - - p.leakingQueryReset.Store(true) - - mainLog.Load().Debug().Msg("attempting to reset DNS") - // Remove the listener immediately. - p.resetDNS() - mainLog.Load().Debug().Msg("DNS reset completed") - - if networkChange { - // If we're already waiting on a recovery from a previous network change, - // cancel that wait to avoid stale recovery. - p.recoveryCancelMu.Lock() - if p.recoveryCancel != nil { - mainLog.Load().Debug().Msg("Cancelling previous recovery wait due to new network change") - p.recoveryCancel() - p.recoveryCancel = nil - } - - ctx, cancel := context.WithCancel(context.Background()) - p.recoveryCancel = cancel - p.recoveryCancelMu.Unlock() - - // Launch a goroutine that monitors the non-OS upstreams. - go func() { - recoveredUpstream, err := p.waitForNonOSResolverRecovery(ctx) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("No non-OS upstream recovered within the timeout; not re-enabling the listener") - return - } - - mainLog.Load().Info().Msgf("Non-OS upstream %q recovered; initializing OS resolver and attaching DNS listener", recoveredUpstream) - - // Initialize OS resolver regardless of upstream recovery. - mainLog.Load().Debug().Msg("initializing OS resolver") - ns := ctrld.InitializeOsResolver(true) - if len(ns) == 0 { - mainLog.Load().Warn().Msgf("no nameservers found, using existing OS resolver values") - } else { - mainLog.Load().Warn().Msgf("re-initialized OS resolver with nameservers: %v", ns) - } - - p.setDNS() - p.logInterfacesState() - - // allow watchers to reset changes - p.leakingQueryReset.Store(false) - - // Clear the recovery cancel func as recovery has been achieved. - p.recoveryCancelMu.Lock() - p.recoveryCancel = nil - p.recoveryCancelMu.Unlock() - }() - - // Optionally flush DNS caches (if needed). - if err := FlushDNSCache(); err != nil { - mainLog.Load().Warn().Err(err).Msg("failed to flush DNS cache") - } - if runtime.GOOS == "darwin" { - // delay putting back the ctrld listener to allow for captive portal to trigger - time.Sleep(5 * time.Second) - } - } else { - // For non-network-change cases, immediately re-enable the listener. - p.setDNS() - p.logInterfacesState() - - // allow watchers to reset changes - p.leakingQueryReset.Store(false) - } -} - // FlushDNSCache flushes the DNS cache on macOS. func FlushDNSCache() error { // if not macOS, return @@ -1457,7 +1269,10 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { ctrld.SetDefaultLocalIPv6(ip) } mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6) - p.reinitializeOSResolver(true) + + if p.recoverOnUpstreamFailure() { + p.handleRecovery(RecoveryReasonNetworkChange) + } }) mon.Start() @@ -1551,53 +1366,154 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro return err } -// waitForNonOSResolverRecovery spawns a health check for each non-OS upstream -// and returns when the first one recovers. -func (p *prog) waitForNonOSResolverRecovery(ctx context.Context) (string, error) { +// handleRecovery performs a unified recovery by removing DNS settings, +// canceling existing recovery checks for network changes, but coalescing duplicate +// upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout), +// and then re-applying the DNS settings. +func (p *prog) handleRecovery(reason RecoveryReason) { + mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings") + + // For network changes, cancel any existing recovery check because the network state has changed. + if reason == RecoveryReasonNetworkChange { + p.recoveryCancelMu.Lock() + if p.recoveryCancel != nil { + mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)") + p.recoveryCancel() + p.recoveryCancel = nil + } + p.recoveryCancelMu.Unlock() + } else { + // For upstream failures, if a recovery is already in progress, do nothing new. + p.recoveryCancelMu.Lock() + if p.recoveryCancel != nil { + mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger") + p.recoveryCancelMu.Unlock() + return + } + p.recoveryCancelMu.Unlock() + } + + // Create a new recovery context without a fixed timeout. + p.recoveryCancelMu.Lock() + recoveryCtx, cancel := context.WithCancel(context.Background()) + p.recoveryCancel = cancel + p.recoveryCancelMu.Unlock() + + // Immediately remove our DNS settings from the interface. + // set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface + p.recoveryRunning.Store(true) + p.resetDNS() + + // For an OS failure, reinitialize OS resolver nameservers immediately. + if reason == RecoveryReasonOSFailure { + mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers") + ns := ctrld.InitializeOsResolver(true) + if len(ns) == 0 { + mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values") + } else { + mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + } + } + + // Build upstream map based on the recovery reason. + upstreams := p.buildRecoveryUpstreams(reason) + + // Wait indefinitely until one of the upstreams recovers. + recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams) + if err != nil { + mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed") + p.recoveryCancelMu.Lock() + p.recoveryCancel = nil + p.recoveryCancelMu.Unlock() + return + } + mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered) + + // For network changes we also reinitialize the OS resolver. + if reason == RecoveryReasonNetworkChange { + ns := ctrld.InitializeOsResolver(true) + if len(ns) == 0 { + mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values") + } else { + mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns) + } + } + + // Apply our DNS settings back and log the interface state. + p.setDNS() + p.logInterfacesState() + + // allow watchdogs to put the listener back on the interface if its changed for any reason + p.recoveryRunning.Store(false) + + // Clear the recovery cancellation for a clean slate. + p.recoveryCancelMu.Lock() + p.recoveryCancel = nil + p.recoveryCancelMu.Unlock() +} + +// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers. +// It returns the name of the recovered upstream or an error if the check times out. +func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) { recoveredCh := make(chan string, 1) var wg sync.WaitGroup - // Loop over your upstream configuration; skip the OS resolver. - for k, uc := range p.cfg.Upstream { - if uc.Type == ctrld.ResolverTypeOS { - continue - } + mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams)) - upstreamName := upstreamPrefix + k - mainLog.Load().Debug().Msgf("Launching recovery check for upstream: %s", upstreamName) + for name, uc := range upstreams { wg.Add(1) go func(name string, uc *ctrld.UpstreamConfig) { defer wg.Done() + mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name) for { select { case <-ctx.Done(): - mainLog.Load().Debug().Msgf("Context done for upstream %s; stopping recovery check", name) + mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name) return default: + // checkUpstreamOnce will reset any failure counters on success. if err := p.checkUpstreamOnce(name, uc); err == nil { - mainLog.Load().Debug().Msgf("Upstream %s is healthy; signaling recovery", name) + mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name) select { case recoveredCh <- name: + mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name) default: + mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered") } return - } else { - mainLog.Load().Debug().Msgf("Upstream %s not healthy, retrying...", name) } + mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name) time.Sleep(checkUpstreamBackoffSleep) } } - }(upstreamName, uc) + }(name, uc) } var recovered string select { case recovered = <-recoveredCh: - mainLog.Load().Debug().Msgf("Received recovered upstream: %s", recovered) case <-ctx.Done(): return "", ctx.Err() } - wg.Wait() return recovered, nil } + +// buildRecoveryUpstreams constructs the map of upstream configurations to test. +// For OS failures we supply the manual OS resolver upstream configuration. +// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS). +func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig { + upstreams := make(map[string]*ctrld.UpstreamConfig) + switch reason { + case RecoveryReasonOSFailure: + upstreams[upstreamOS] = osUpstreamConfig + case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure: + // Use all configured upstreams except any OS type. + for k, uc := range p.cfg.Upstream { + if uc.Type != ctrld.ResolverTypeOS { + upstreams[upstreamPrefix+k] = uc + } + } + } + return upstreams +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 3dc9e1b..8a86bcf 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -48,6 +48,17 @@ const ( ctrldServiceName = "ctrld" ) +// RecoveryReason provides context for why we are waiting for recovery. +// recovery involves removing the listener IP from the interface and +// waiting for the upstreams to work before returning +type RecoveryReason int + +const ( + RecoveryReasonNetworkChange RecoveryReason = iota + RecoveryReasonRegularFailure + RecoveryReasonOSFailure +) + // ControlSocketName returns name for control unix socket. func ControlSocketName() string { if isMobile() { @@ -118,14 +129,9 @@ type prog struct { loopMu sync.Mutex loop map[string]bool - leakingQueryMu sync.Mutex - leakingQueryRunning map[string]bool - leakingQueryReset atomic.Bool - - resetCtxMu sync.Mutex - recoveryCancelMu sync.Mutex recoveryCancel context.CancelFunc + recoveryRunning atomic.Bool started chan struct{} onStartedDone chan struct{} @@ -429,7 +435,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) - p.leakingQueryRunning = make(map[string]bool) p.lanLoopGuard = newLoopGuard() p.ptrLoopGuard = newLoopGuard() p.cacheFlushDomainsMap = nil @@ -779,7 +784,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces mainLog.Load().Debug().Msg("stop dns watchdog") return case <-ticker.C: - if p.leakingQueryReset.Load() { + if p.recoveryRunning.Load() { return } if dnsChanged(iface, ns) { @@ -980,16 +985,10 @@ func findWorkingInterface(currentIface string) string { return currentIface } -// leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams. -func (p *prog) leakOnUpstreamFailure() bool { - if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil { - return *ptr - } - // Default is false on routers, since this leaking is only useful for devices that move between networks. - if router.Name() != "" { - return false - } - return true +// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure. +func (p *prog) recoverOnUpstreamFailure() bool { + // Default is false on routers, since this recovery flow is only useful for devices that move between networks. + return router.Name() == "" } func randomLocalIP() string { diff --git a/cmd/cli/resolvconf.go b/cmd/cli/resolvconf.go index 9d37d68..0f3f731 100644 --- a/cmd/cli/resolvconf.go +++ b/cmd/cli/resolvconf.go @@ -67,7 +67,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath) return case event, ok := <-watcher.Events: - if p.leakingQueryReset.Load() { + if p.recoveryRunning.Load() { return } if !ok { diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index df52a14..e42b3c1 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -145,3 +145,16 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { time.Sleep(checkUpstreamBackoffSleep) } } + +// countHealthy returns the number of upstreams in the provided map that are considered healthy. +func (um *upstreamMonitor) countHealthy(upstreams []string) int { + var count int + um.mu.RLock() + defer um.mu.RUnlock() + for _, upstream := range upstreams { + if !um.isDown(upstream) { + count++ + } + } + return count +} From d37d0e942ca502e4e9860989445d8b9832f4b758 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 15:46:12 -0500 Subject: [PATCH 092/100] fix countHealthy locking --- cmd/cli/upstream_monitor.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index e42b3c1..7489091 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -150,11 +150,11 @@ func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { func (um *upstreamMonitor) countHealthy(upstreams []string) int { var count int um.mu.RLock() - defer um.mu.RUnlock() for _, upstream := range upstreams { - if !um.isDown(upstream) { + if !um.down[upstream] { count++ } } + um.mu.RUnlock() return count } From 60e65a37a692f94871427397480b72d862984d44 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 7 Feb 2025 16:03:36 -0500 Subject: [PATCH 093/100] do the reset after recovery finished --- cmd/cli/dns_proxy.go | 3 ++ cmd/cli/upstream_monitor.go | 58 +------------------------------------ 2 files changed, 4 insertions(+), 57 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 31e8aa8..d2065ef 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1429,6 +1429,9 @@ func (p *prog) handleRecovery(reason RecoveryReason) { } mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered) + // reset the upstream failure count and down state + p.um.reset(recovered) + // For network changes we also reinitialize the OS resolver. if reason == RecoveryReasonNetworkChange { ns := ctrld.InitializeOsResolver(true) diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 7489091..507a06f 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -1,12 +1,9 @@ package cli import ( - "context" "sync" "time" - "github.com/miekg/dns" - "github.com/Control-D-Inc/ctrld" ) @@ -80,11 +77,10 @@ func (um *upstreamMonitor) isDown(upstream string) bool { // reset marks an upstream as up and set failed queries counter to zero. func (um *upstreamMonitor) reset(upstream string) { um.mu.Lock() - defer um.mu.Unlock() - um.failureReq[upstream] = 0 um.down[upstream] = false um.recovered[upstream] = true + um.mu.Unlock() go func() { // debounce the recovery to avoid incrementing failure counts already in flight time.Sleep(1 * time.Second) @@ -94,58 +90,6 @@ func (um *upstreamMonitor) reset(upstream string) { }() } -// checkUpstream checks the given upstream status, periodically sending query to upstream -// until successfully. An upstream status/counter will be reset once it becomes reachable. -func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { - p.um.mu.Lock() - isChecking := p.um.checking[upstream] - if isChecking { - p.um.mu.Unlock() - return - } - p.um.checking[upstream] = true - p.um.mu.Unlock() - defer func() { - p.um.mu.Lock() - p.um.checking[upstream] = false - p.um.mu.Unlock() - }() - - resolver, err := ctrld.NewResolver(uc) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not check upstream") - return - } - msg := new(dns.Msg) - msg.SetQuestion(".", dns.TypeNS) - timeout := 1000 * time.Millisecond - if uc.Timeout > 0 { - timeout = time.Duration(uc.Timeout) * time.Millisecond - } - check := func() error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - uc.ReBootstrap() - _, err := resolver.Resolve(ctx, msg) - return err - } - endpoint := uc.Endpoint - if endpoint == "" { - endpoint = uc.Name - } - mainLog.Load().Warn().Msgf("upstream %q is offline", endpoint) - for { - if err := check(); err == nil { - mainLog.Load().Warn().Msgf("upstream %q is online", endpoint) - p.um.reset(upstream) - return - } else { - mainLog.Load().Debug().Msgf("checked upstream %q failed: %v", endpoint, err) - } - time.Sleep(checkUpstreamBackoffSleep) - } -} - // countHealthy returns the number of upstreams in the provided map that are considered healthy. func (um *upstreamMonitor) countHealthy(upstreams []string) int { var count int From 5007a87d3a3de4d169b0824b3a983ce408256a90 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Sat, 8 Feb 2025 08:17:43 +0700 Subject: [PATCH 094/100] cmd/cli: better error message when doing restart In case of remote config validation error during start, it's likely that there's problem with connecting to ControlD API. The ctrld daemon was restarted in this case, but may not ready to receive requests yet. This commit changes the error message to explicitly state that instead of a mis-leading "could not complete service restart". --- cmd/cli/cli.go | 13 ++++++++++--- cmd/cli/commands.go | 24 +++++++++++++++++------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 07abf3c..9d01206 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1478,6 +1478,12 @@ func removeOrgFlagsFromArgs(sc *service.Config) { // newSocketControlClient returns new control client after control server was started. func newSocketControlClient(ctx context.Context, s service.Service, dir string) *controlClient { + return newSocketControlClientWithTimeout(ctx, s, dir, dialSocketControlServerTimeout) +} + +// newSocketControlClientWithTimeout returns new control client after control server was started. +// The timeoutDuration controls how long to wait for the server. +func newSocketControlClientWithTimeout(ctx context.Context, s service.Service, dir string, timeoutDuration time.Duration) *controlClient { // Return early if service is not running. if status, err := s.Status(); err != nil || status != service.StatusRunning { return nil @@ -1486,7 +1492,7 @@ func newSocketControlClient(ctx context.Context, s service.Service, dir string) bo.LogLongerThan = 10 * time.Second cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) - timeout := time.NewTimer(30 * time.Second) + timeout := time.NewTimer(timeoutDuration) defer timeout.Stop() // The socket control server may not start yet, so attempt to ping @@ -1807,7 +1813,7 @@ func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceRe } // doValidateCdRemoteConfig fetches and validates custom config for cdUID. -func doValidateCdRemoteConfig(cdUID string, fatal bool) { +func doValidateCdRemoteConfig(cdUID string, fatal bool) error { rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) if err != nil { logger := mainLog.Load().Fatal() @@ -1816,7 +1822,7 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) { } logger.Err(err).Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) if !fatal { - return + return err } } // validateCdRemoteConfig clobbers v, saving it here to restore later. @@ -1847,6 +1853,7 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) { mainLog.Load().Warn().Msg("disregarding invalid custom config") } v = oldV + return nil } // uninstallInvalidCdUID performs self-uninstallation because the ControlD device does not exist. diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index d340574..49dfb8f 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -30,6 +30,9 @@ import ( "github.com/Control-D-Inc/ctrld/internal/router" ) +// dialSocketControlServerTimeout is the default timeout to wait when ping control server. +const dialSocketControlServerTimeout = 30 * time.Second + func initLogCmd() *cobra.Command { warnRuntimeLoggingNotEnabled := func() { mainLog.Load().Warn().Msg("runtime debug logging is not enabled") @@ -373,7 +376,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c } if cdUID != "" { - doValidateCdRemoteConfig(cdUID, true) + _ = doValidateCdRemoteConfig(cdUID, true) } else if uid := cdUIDFromProvToken(); uid != "" { cdUID = uid mainLog.Load().Debug().Msg("using uid from provision token") @@ -697,8 +700,9 @@ func initRestartCmd() *cobra.Command { initInteractiveLogging() + var validateConfigErr error if cdMode { - doValidateCdRemoteConfig(cdUID, false) + validateConfigErr = doValidateCdRemoteConfig(cdUID, false) } if ir := runningIface(s); ir != nil { @@ -752,12 +756,18 @@ func initRestartCmd() *cobra.Command { if doRestart() { if dir, err := socketDir(); err == nil { - cc := newSocketControlClient(context.TODO(), s, dir) - if cc == nil { - mainLog.Load().Error().Msg("Could not complete service restart") - os.Exit(1) + timeout := dialSocketControlServerTimeout + // If we failed to validate remote config above, it's likely that + // we are having problem with network connection. So using a shorter + // timeout than default one for better UX. + if validateConfigErr != nil { + timeout = 5 * time.Second + } + if cc := newSocketControlClientWithTimeout(context.TODO(), s, dir, timeout); cc != nil { + _, _ = cc.post(ifacePath, nil) + } else { + mainLog.Load().Warn().Err(err).Msg("Service was restarted, but ctrld process may not be ready yet") } - _, _ = cc.post(ifacePath, nil) } else { mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server") } From e3b99bf339d010c7305bc215035c504e5a6ee922 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 10 Feb 2025 18:24:29 -0500 Subject: [PATCH 095/100] mark upstream as down after 10s of no successful queries --- cmd/cli/upstream_monitor.go | 42 +++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 507a06f..acc02bb 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -9,7 +9,7 @@ import ( const ( // maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down. - maxFailureRequest = 100 + maxFailureRequest = 50 // checkUpstreamBackoffSleep is the time interval between each upstream checks. checkUpstreamBackoffSleep = 2 * time.Second ) @@ -23,15 +23,19 @@ type upstreamMonitor struct { down map[string]bool failureReq map[string]uint64 recovered map[string]bool + + // failureTimerActive tracks if a timer is already running for a given upstream. + failureTimerActive map[string]bool } func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { um := &upstreamMonitor{ - cfg: cfg, - checking: make(map[string]bool), - down: make(map[string]bool), - failureReq: make(map[string]uint64), - recovered: make(map[string]bool), + cfg: cfg, + checking: make(map[string]bool), + down: make(map[string]bool), + failureReq: make(map[string]uint64), + recovered: make(map[string]bool), + failureTimerActive: make(map[string]bool), } for n := range cfg.Upstream { upstream := upstreamPrefix + n @@ -42,6 +46,8 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { } // increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information. +// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down +// within 10 seconds if failures persist, without spawning duplicate goroutines. func (um *upstreamMonitor) increaseFailureCount(upstream string) { um.mu.Lock() defer um.mu.Unlock() @@ -54,13 +60,31 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { um.failureReq[upstream] += 1 failedCount := um.failureReq[upstream] - // Log the updated failure count + // Log the updated failure count. mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount) - // Check if the failure count has reached the threshold to mark the upstream as down. + // If this is the first failure and no timer is running, start a 10-second timer. + if failedCount == 1 && !um.failureTimerActive[upstream] { + um.failureTimerActive[upstream] = true + go func(upstream string) { + time.Sleep(10 * time.Second) + um.mu.Lock() + defer um.mu.Unlock() + // If no success occurred during the 10-second window (i.e. counter remains > 0) + // and the upstream is not in a recovered state, mark it as down. + if um.failureReq[upstream] > 0 && !um.recovered[upstream] { + um.down[upstream] = true + mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream]) + } + // Reset the timer flag so that a new timer can be spawned if needed. + um.failureTimerActive[upstream] = false + }(upstream) + } + + // If the failure count quickly reaches the threshold, mark the upstream as down immediately. if failedCount >= maxFailureRequest { um.down[upstream] = true - mainLog.Load().Warn().Msgf("upstream %q marked as down (failure count: %d)", upstream, failedCount) + mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) } else { um.down[upstream] = false } From 41a00c68ac68871f809f0d58045c1c555686ca9f Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 10 Feb 2025 18:39:45 -0500 Subject: [PATCH 096/100] fix down state handling --- cmd/cli/dns_proxy.go | 2 ++ cmd/cli/upstream_monitor.go | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index d2065ef..44582c5 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -545,9 +545,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { return nil } // if we have an answer, we should reset the failure count + // we dont use reset here since we dont want to prevent failure counts from being incremented if answer != nil { p.um.mu.Lock() p.um.failureReq[upstreams[n]] = 0 + p.um.down[upstreams[n]] = false p.um.mu.Unlock() } return answer diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index acc02bb..6e19e38 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -85,8 +85,6 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) { if failedCount >= maxFailureRequest { um.down[upstream] = true mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount) - } else { - um.down[upstream] = false } } From 9e83085f2a8719b3c0c73a1d564a6984503b919d Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 10 Feb 2025 19:50:20 -0500 Subject: [PATCH 097/100] handle old state missing interface crash --- cmd/cli/dns_proxy.go | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 44582c5..b7b5c5f 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1183,14 +1183,32 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error { var changeIPs []netip.Prefix // Check each valid interface for changes for ifaceName := range validIfaces { - oldIface := delta.Old.Interface[ifaceName] - newIface, exists := delta.New.Interface[ifaceName] - if !exists { + oldIface, oldExists := delta.Old.Interface[ifaceName] + newIface, newExists := delta.New.Interface[ifaceName] + if !newExists { continue } + oldIPs := delta.Old.InterfaceIPs[ifaceName] newIPs := delta.New.InterfaceIPs[ifaceName] + // if a valid interface did not exist in old + // check that its up and has usable IPs + if !oldExists { + // The interface is new (was not present in the old state). + usableNewIPs := filterUsableIPs(newIPs) + if newIface.IsUp() && len(usableNewIPs) > 0 { + changed = true + changeIPs = usableNewIPs + mainLog.Load().Debug(). + Str("interface", ifaceName). + Interface("new_ips", usableNewIPs). + Msg("Interface newly appeared (was not present in old state)") + break + } + continue + } + // Filter new IPs to only those that are usable. usableNewIPs := filterUsableIPs(newIPs) From 0fae584e653d0122017b218eb3a84a7c74c98689 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 10 Feb 2025 19:58:15 -0500 Subject: [PATCH 098/100] OS resolver retry catch all --- cmd/cli/dns_proxy.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index b7b5c5f..a4390de 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -638,6 +638,19 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } } + // attempt query to OS resolver while as a retry catch all + if upstreams[0] != upstreamOS { + ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all") + answer := resolve(0, osUpstreamConfig, req.msg) + if answer != nil { + ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful") + res.answer = answer + res.upstream = osUpstreamConfig.Endpoint + return res + } + ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed") + } + answer := new(dns.Msg) answer.SetRcode(req.msg, dns.RcodeServerFailure) res.answer = answer From 7d07d738dcd8430867e683b8a696962ffce5eb03 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 10 Feb 2025 20:23:16 -0500 Subject: [PATCH 099/100] fix failure count on OS retry --- cmd/cli/dns_proxy.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index a4390de..24660f8 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -509,8 +509,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { staleAnswer = answer } } - resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { - ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) + resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { + ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") @@ -525,17 +525,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { } return dnsResolver.Resolve(resolveCtx, msg) } - resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { + resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request") ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) } - answer, err := resolve1(n, upstreamConfig, msg) + answer, err := resolve1(upstream, upstreamConfig, msg) if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") isNetworkErr := errNetworkError(err) if isNetworkErr { - p.um.increaseFailureCount(upstreams[n]) + p.um.increaseFailureCount(upstream) } // For timeout error (i.e: context deadline exceed), force re-bootstrapping. var e net.Error @@ -548,8 +548,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // we dont use reset here since we dont want to prevent failure counts from being incremented if answer != nil { p.um.mu.Lock() - p.um.failureReq[upstreams[n]] = 0 - p.um.down[upstreams[n]] = false + p.um.failureReq[upstream] = 0 + p.um.down[upstream] = false p.um.mu.Unlock() } return answer @@ -568,7 +568,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctrld.Log(ctx, logger, "DNS loop detected") continue } - answer := resolve(n, upstreamConfig, req.msg) + answer := resolve(upstreams[n], upstreamConfig, req.msg) if answer == nil { if serveStaleCache && staleAnswer != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response") @@ -641,7 +641,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { // attempt query to OS resolver while as a retry catch all if upstreams[0] != upstreamOS { ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all") - answer := resolve(0, osUpstreamConfig, req.msg) + answer := resolve(upstreamOS, osUpstreamConfig, req.msg) if answer != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful") res.answer = answer From 81e0bad739f2ce21ad8326a36aa865d3379f2378 Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 10 Feb 2025 20:34:37 -0500 Subject: [PATCH 100/100] increase failure count for all queries with no answer --- cmd/cli/dns_proxy.go | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 24660f8..0bc042e 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -531,19 +531,6 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) } answer, err := resolve1(upstream, upstreamConfig, msg) - if err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") - isNetworkErr := errNetworkError(err) - if isNetworkErr { - p.um.increaseFailureCount(upstream) - } - // For timeout error (i.e: context deadline exceed), force re-bootstrapping. - var e net.Error - if errors.As(err, &e) && e.Timeout() { - upstreamConfig.ReBootstrap() - } - return nil - } // if we have an answer, we should reset the failure count // we dont use reset here since we dont want to prevent failure counts from being incremented if answer != nil { @@ -551,8 +538,24 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { p.um.failureReq[upstream] = 0 p.um.down[upstream] = false p.um.mu.Unlock() + return answer } - return answer + + ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") + + // increase failure count when there is no answer + // rehardless of what kind of error we get + p.um.increaseFailureCount(upstream) + + if err != nil { + // For timeout error (i.e: context deadline exceed), force re-bootstrapping. + var e net.Error + if errors.As(err, &e) && e.Timeout() { + upstreamConfig.ReBootstrap() + } + } + + return nil } for n, upstreamConfig := range upstreamConfigs { if upstreamConfig == nil {