From 00e9d2bdd30602f23d2776fd6cdcc14af5e77fd5 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 6 May 2025 19:59:11 +0700 Subject: [PATCH 1/9] all: do not listen on 0.0.0.0 on desktop clients Listening on 0.0.0.0 exposes the listener to the entire local network, which may create security vulnerabilities such as DNS amplification or other abuse. --- cmd/cli/cli.go | 11 ++++++++--- cmd/cli/dns_proxy.go | 4 +++- desktop_darwin.go | 7 +++++++ desktop_others.go | 9 +++++++++ desktop_windows.go | 7 +++++++ 5 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 desktop_darwin.go create mode 100644 desktop_others.go create mode 100644 desktop_windows.go diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 31e1fcb..b99c48f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1216,13 +1216,18 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti // For Windows server with local Dns server running, we can only try on random local IP. hasLocalDnsServer := hasLocalDnsServerRunning() notRouter := router.Name() == "" + isDesktop := ctrld.IsDesktopPlatform() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { listener.IP = "0.0.0.0" - if hasLocalDnsServer { - // Windows Server lies to us that we could listen on 0.0.0.0:53 - // even there's a process already done that, stick to local IP only. + // Windows Server lies to us that we could listen on 0.0.0.0:53 + // even when another process is already listening there, so stick to the local IP only. + // + // For desktop clients, also stick the listener to the local IP only. + // Listening on 0.0.0.0 would expose it to the entire local network, potentially + // creating security vulnerabilities (such as DNS amplification or other abuse). 
+ if hasLocalDnsServer || isDesktop { listener.IP = "127.0.0.1" } lcc[n].IP = true diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 6a214e5..2311260 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -1042,8 +1042,10 @@ func (p *prog) queryFromSelf(ip string) bool { return false } +// needRFC1918Listeners reports whether ctrld need to spawn listener for RFC 1918 addresses. +// This is helpful for non-desktop platforms to receive queries from LAN clients. func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool { - return lc.IP == "127.0.0.1" && lc.Port == 53 + return lc.IP == "127.0.0.1" && lc.Port == 53 && !ctrld.IsDesktopPlatform() } // ipFromARPA parses a FQDN arpa domain and return the IP address if valid. diff --git a/desktop_darwin.go b/desktop_darwin.go new file mode 100644 index 0000000..039c0fa --- /dev/null +++ b/desktop_darwin.go @@ -0,0 +1,7 @@ +package ctrld + +// IsDesktopPlatform indicates if ctrld is running on a desktop platform, +// currently defined as macOS or Windows workstation. +func IsDesktopPlatform() bool { + return true +} diff --git a/desktop_others.go b/desktop_others.go new file mode 100644 index 0000000..de486e7 --- /dev/null +++ b/desktop_others.go @@ -0,0 +1,9 @@ +//go:build !windows && !darwin + +package ctrld + +// IsDesktopPlatform indicates if ctrld is running on a desktop platform, +// currently defined as macOS or Windows workstation. +func IsDesktopPlatform() bool { + return false +} diff --git a/desktop_windows.go b/desktop_windows.go new file mode 100644 index 0000000..4e9526b --- /dev/null +++ b/desktop_windows.go @@ -0,0 +1,7 @@ +package ctrld + +// IsDesktopPlatform indicates if ctrld is running on a desktop platform, +// currently defined as macOS or Windows workstation. 
+func IsDesktopPlatform() bool { + return isWindowsWorkStation() +} From 62f73bcaa291ca26da402ccbeb62c76e6d1855af Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 8 May 2025 22:29:59 +0700 Subject: [PATCH 2/9] all: preserve search domains settings This ensures bare hostnames are still resolved as expected while ctrld is running. --- cmd/cli/os_darwin.go | 5 +++- cmd/cli/os_freebsd.go | 15 ++++++++-- cmd/cli/os_linux.go | 8 +++++- cmd/cli/os_windows.go | 4 +++ cmd/cli/search_domains_unix.go | 14 ++++++++++ cmd/cli/search_domains_windows.go | 43 +++++++++++++++++++++++++++++ internal/resolvconffile/dns.go | 12 +++++++- internal/resolvconffile/dns_test.go | 2 +- nameservers_unix.go | 4 +-- 9 files changed, 99 insertions(+), 8 deletions(-) create mode 100644 cmd/cli/search_domains_unix.go create mode 100644 cmd/cli/search_domains_windows.go diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index 841be76..4c358b0 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -47,6 +47,9 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e // networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1 // TODO(cuonglm): use system API func setDNS(iface *net.Interface, nameservers []string) error { + // Note that networksetup won't modify the search domains settings. + // This assignment is just a placeholder to silence the linter. + _ = searchDomains cmd := "networksetup" args := []string{"-setdnsservers", iface.Name} args = append(args, nameservers...) @@ -88,7 +91,7 @@ func restoreDNS(iface *net.Interface) (err error) { } func currentDNS(_ *net.Interface) []string { - return resolvconffile.NameServers("") + return resolvconffile.NameServers() } // currentStaticDNS returns the current static DNS settings of given interface. 
diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index 72da485..d66e4bf 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -7,6 +7,7 @@ import ( "tailscale.com/control/controlknobs" "tailscale.com/health" + "tailscale.com/util/dnsname" "github.com/Control-D-Inc/ctrld/internal/dns" "github.com/Control-D-Inc/ctrld/internal/resolvconffile" @@ -50,7 +51,17 @@ func setDNS(iface *net.Interface, nameservers []string) error { ns = append(ns, netip.MustParseAddr(nameserver)) } - if err := r.SetDNS(dns.OSConfig{Nameservers: ns}); err != nil { + osConfig := dns.OSConfig{ + Nameservers: ns, + SearchDomains: []dnsname.FQDN{}, + } + if sds, err := searchDomains(); err == nil { + osConfig.SearchDomains = sds + } else { + mainLog.Load().Debug().Err(err).Msg("failed to get search domains list") + } + + if err := r.SetDNS(osConfig); err != nil { mainLog.Load().Error().Err(err).Msg("failed to set DNS") return err } @@ -83,7 +94,7 @@ func restoreDNS(iface *net.Interface) (err error) { } func currentDNS(_ *net.Interface) []string { - return resolvconffile.NameServers("") + return resolvconffile.NameServers() } // currentStaticDNS returns the current static DNS settings of given interface. 
diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index e2302a3..8caad63 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -71,6 +71,11 @@ func setDNS(iface *net.Interface, nameservers []string) error { Nameservers: ns, SearchDomains: []dnsname.FQDN{}, } + if sds, err := searchDomains(); err == nil { + osConfig.SearchDomains = sds + } else { + mainLog.Load().Debug().Err(err).Msg("failed to get search domains list") + } trySystemdResolve := false if err := r.SetDNS(osConfig); err != nil { if strings.Contains(err.Error(), "Rejected send message") && @@ -196,7 +201,8 @@ func restoreDNS(iface *net.Interface) (err error) { } func currentDNS(iface *net.Interface) []string { - for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} { + resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() } + for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} { if ns := fn(iface.Name); len(ns) > 0 { return ns } diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index e1bcd9a..7ebc54a 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -100,6 +100,10 @@ func setDNS(iface *net.Interface, nameservers []string) error { } } + // Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function. + // searchDomains is still implemented for Windows just in case Windows API changes in future versions. 
+ _ = searchDomains + if len(serversV4) == 0 && len(serversV6) == 0 { return errors.New("invalid DNS nameservers") } diff --git a/cmd/cli/search_domains_unix.go b/cmd/cli/search_domains_unix.go new file mode 100644 index 0000000..de3998e --- /dev/null +++ b/cmd/cli/search_domains_unix.go @@ -0,0 +1,14 @@ +//go:build unix + +package cli + +import ( + "tailscale.com/util/dnsname" + + "github.com/Control-D-Inc/ctrld/internal/resolvconffile" +) + +// searchDomains returns the current search domains config. +func searchDomains() ([]dnsname.FQDN, error) { + return resolvconffile.SearchDomains() +} diff --git a/cmd/cli/search_domains_windows.go b/cmd/cli/search_domains_windows.go new file mode 100644 index 0000000..320a322 --- /dev/null +++ b/cmd/cli/search_domains_windows.go @@ -0,0 +1,43 @@ +package cli + +import ( + "fmt" + "syscall" + + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/util/dnsname" +) + +// searchDomains returns the current search domains config. +func searchDomains() ([]dnsname.FQDN, error) { + flags := winipcfg.GAAFlagIncludeGateways | + winipcfg.GAAFlagIncludePrefix + + aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags) + if err != nil { + return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err) + } + + var sds []dnsname.FQDN + for _, aa := range aas { + if aa.OperStatus != winipcfg.IfOperStatusUp { + continue + } + + // Skip if software loopback or other non-physical types + // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows + if aa.IfType == winipcfg.IfTypeSoftwareLoopback { + continue + } + + for a := aa.FirstDNSSuffix; a != nil; a = a.Next { + d, err := dnsname.ToFQDN(a.String()) + if err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String()) + continue + } + sds = append(sds, d) + } + } + return sds, nil +} diff --git a/internal/resolvconffile/dns.go b/internal/resolvconffile/dns.go index 3ce0f91..0d532eb 100644 --- 
a/internal/resolvconffile/dns.go +++ b/internal/resolvconffile/dns.go @@ -6,6 +6,7 @@ import ( "net" "tailscale.com/net/dns/resolvconffile" + "tailscale.com/util/dnsname" ) const resolvconfPath = "/etc/resolv.conf" @@ -22,7 +23,7 @@ func NameServersWithPort() []string { return ns } -func NameServers(_ string) []string { +func NameServers() []string { c, err := resolvconffile.ParseFile(resolvconfPath) if err != nil { return nil @@ -33,3 +34,12 @@ func NameServers(_ string) []string { } return ns } + +// SearchDomains returns the current search domains config in /etc/resolv.conf file. +func SearchDomains() ([]dnsname.FQDN, error) { + c, err := resolvconffile.ParseFile(resolvconfPath) + if err != nil { + return nil, err + } + return c.SearchDomains, nil +} diff --git a/internal/resolvconffile/dns_test.go b/internal/resolvconffile/dns_test.go index ba571af..7f7a64c 100644 --- a/internal/resolvconffile/dns_test.go +++ b/internal/resolvconffile/dns_test.go @@ -9,7 +9,7 @@ import ( ) func TestNameServers(t *testing.T) { - ns := NameServers("") + ns := NameServers() require.NotNil(t, ns) t.Log(ns) } diff --git a/nameservers_unix.go b/nameservers_unix.go index d7af521..d8e6035 100644 --- a/nameservers_unix.go +++ b/nameservers_unix.go @@ -14,7 +14,7 @@ import ( // currentNameserversFromResolvconf returns the current nameservers set from /etc/resolv.conf file. func currentNameserversFromResolvconf() []string { - return resolvconffile.NameServers("") + return resolvconffile.NameServers() } // dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file. 
@@ -34,7 +34,7 @@ func dnsFromResolvConf() []string { time.Sleep(retryInterval) } - nss := resolvconffile.NameServers("") + nss := resolvconffile.NameServers() var localDNS []string seen := make(map[string]bool) From a983dfaee2af68dc7e0d94dcb00d5190df785f86 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 21 May 2025 19:33:54 +0700 Subject: [PATCH 3/9] all: optimizing multiple queries to upstreams To guard ctrld from possible DoS to remote upstreams, this commit implements the following: - Coalescing concurrent queries for the same domain and qtype through a singleflight group, so only one such query is in flight to remote upstreams at any time. - Adding a hot cache with a 1-second TTL, so repeated queries reuse the cached result if one exists, preventing unnecessary requests to remote upstreams. --- resolver.go | 83 ++++++++++++++++++++++++++++++-- resolver_test.go | 120 +++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 189 insertions(+), 14 deletions(-) diff --git a/resolver.go b/resolver.go index a44ddb2..52515f9 100644 --- a/resolver.go +++ b/resolver.go @@ -9,12 +9,14 @@ import ( "net/netip" "runtime" "slices" + "strings" "sync" "sync/atomic" "time" "github.com/miekg/dns" "github.com/rs/zerolog" + "golang.org/x/sync/singleflight" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" ) @@ -216,6 +218,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { type osResolver struct { lanServers atomic.Pointer[[]string] publicServers atomic.Pointer[[]string] + group *singleflight.Group + cache *sync.Map } type osResolverResult struct { @@ -273,10 +277,75 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired return dnsClient.ExchangeContext(ctx, msg, server) } +const hotCacheTTL = time.Second + // Resolve resolves DNS queries using pre-configured nameservers. 
-// Query is sent to all nameservers concurrently, and the first +// The Query is sent to all nameservers concurrently, and the first // success response will be returned. +// +// To guard against unexpected DoS to upstreams, multiple queries of +// the same Qtype to a domain will be shared, so there's only 1 qps +// for each upstream at any time. +// +// Further, a hot cache will be used, so repeated queries will be cached +// for a short period (currently 1 second), reducing unnecessary traffics +// sent to upstreams. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if len(msg.Question) == 0 { + return nil, errors.New("no question found") + } + domain := strings.TrimSuffix(msg.Question[0].Name, ".") + qtype := msg.Question[0].Qtype + + // Unique key for the singleflight group. + key := fmt.Sprintf("%s:%d:", domain, qtype) + + // Checking the cache first. + if val, ok := o.cache.Load(key); ok { + if val, ok := val.(*dns.Msg); ok { + Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) + res := val.Copy() + res.SetRcode(msg, val.Rcode) + return res, nil + } + } + + // Ensure only one DNS query is in flight for the key. + v, err, shared := o.group.Do(key, func() (interface{}, error) { + msg, err := o.resolve(ctx, msg) + if err != nil { + return nil, err + } + // If we got an answer, storing it to the hot cache for hotCacheTTL + // This prevents possible DoS to upstream, ensuring there's only 1 QPS. + o.cache.Store(key, msg) + // Depends on go runtime scheduling, the result may end up in hot cache longer + // than hotCacheTTL duration. However, this is fine since we only want to guard + // against DoS attack. The result will be cleaned from the cache eventually. 
+ time.AfterFunc(hotCacheTTL, func() { + o.removeCache(key) + }) + return msg, nil + }) + if err != nil { + return nil, err + } + + sharedMsg, ok := v.(*dns.Msg) + if !ok { + return nil, fmt.Errorf("invalid answer for key: %s", key) + } + res := sharedMsg.Copy() + res.SetRcode(msg, sharedMsg.Rcode) + if shared { + Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) + } + + return res, nil +} + +// resolve sends the query to current nameservers. +func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { publicServers := *o.publicServers.Load() var nss []string if p := o.lanServers.Load(); p != nil { @@ -431,6 +500,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error return nil, errors.Join(errs...) } +func (o *osResolver) removeCache(key string) { + o.cache.Delete(key) +} + type legacyResolver struct { uc *UpstreamConfig } @@ -627,10 +700,10 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // newResolverWithNameserver returns an OS resolver from given nameservers list. // The caller must ensure each server in list is formed "ip:53". 
func newResolverWithNameserver(nameservers []string) *osResolver { - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers) - r := &osResolver{} + r := &osResolver{ + group: &singleflight.Group{}, + cache: &sync.Map{}, + } var publicNss []string var lanNss []string for _, ns := range slices.Sorted(slices.Values(nameservers)) { diff --git a/resolver_test.go b/resolver_test.go index fb6831b..a75e748 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync" + "sync/atomic" "testing" "time" @@ -16,8 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() - resolver := &osResolver{} - resolver.publicServers.Store(&[]string{"127.0.0.127:5353"}) + resolver := newResolverWithNameserver([]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -50,8 +50,7 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) { t.Error("not a LAN query") return } - resolver := &osResolver{} - resolver.publicServers.Store(&[]string{"76.76.2.0:53"}) + resolver := newResolverWithNameserver([]string{"76.76.2.0:53"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -107,11 +106,9 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { }() // We now create an osResolver which has both a LAN and public nameserver. - resolver := &osResolver{} - // Explicitly store the LAN nameserver. - resolver.lanServers.Store(&[]string{lanAddr}) - // And store the public nameservers. - resolver.publicServers.Store(&publicNS) + nss := []string{lanAddr} + nss = append(nss, publicNS...) 
+ resolver := newResolverWithNameserver(nss) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) @@ -139,6 +136,102 @@ func Test_osResolver_InitializationRace(t *testing.T) { wg.Wait() } +func Test_osResolver_Singleflight(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + n := 10 + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + _, err := or.Resolve(context.Background(), m) + if err != nil { + t.Error(err) + } + }() + } + wg.Wait() + + // All above queries should only make 1 call to server. + if call.Load() != 1 { + t.Fatalf("expected 1 result from singleflight lookup, got %d", call) + } +} + +func Test_osResolver_HotCache(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + + // Make 2 repeated queries to server, should hit hot cache. 
+ for i := 0; i < 2; i++ { + if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + t.Fatal(err) + } + } + if call.Load() != 1 { + t.Fatalf("cache not hit, server was called: %d", call.Load()) + } + + timeoutChan := make(chan struct{}) + time.AfterFunc(5*time.Second, func() { + close(timeoutChan) + }) + + for { + select { + case <-timeoutChan: + t.Fatal("timed out waiting for cache cleaned") + default: + count := 0 + or.cache.Range(func(key, value interface{}) bool { + count++ + return true + }) + if count != 0 { + t.Logf("hot cache is not empty: %d elements", count) + continue + } + } + break + } + + if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + t.Fatal(err) + } + if call.Load() != 2 { + t.Fatal("cache hit unexpectedly") + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -208,3 +301,12 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc { w.WriteMsg(m) } } + +func countHandler(call *atomic.Int64) dns.HandlerFunc { + return func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, dns.RcodeSuccess) + w.WriteMsg(m) + call.Add(1) + } +} From b4faf82f76a577c56cc83a954cf6fe9b01560587 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 26 May 2025 20:49:03 +0700 Subject: [PATCH 4/9] all: set edns0 cookie for shared message For cached or singleflight messages, the edns0 cookie is currently shared among all of them, causing mismatch cookie warning from clients. The ctrld proxy should re-set client cookies for each request separately, even though they use the same shared answer. 
--- cmd/cli/dns_proxy.go | 2 +- dns.go | 30 ++++++++++++++++++ resolver.go | 4 +-- resolver_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 3 deletions(-) create mode 100644 dns.go diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 2311260..33012fa 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -500,7 +500,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { continue } answer := cachedValue.Msg.Copy() - answer.SetRcode(req.msg, answer.Rcode) + ctrld.SetCacheReply(answer, req.msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..f2b71a5 --- /dev/null +++ b/dns.go @@ -0,0 +1,30 @@ +package ctrld + +import ( + "github.com/miekg/dns" +) + +// SetCacheReply extracts and stores the necessary data from the message for a cached answer. +func SetCacheReply(answer, msg *dns.Msg, code int) { + answer.SetRcode(msg, code) + cCookie := getEdns0Cookie(msg.IsEdns0()) + sCookie := getEdns0Cookie(answer.IsEdns0()) + if cCookie != nil && sCookie != nil { + // Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes. + // See https://datatracker.ietf.org/doc/html/rfc7873#section-4 + sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:] + } +} + +// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present. 
+func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE { + if opt == nil { + return nil + } + for _, o := range opt.Option { + if e, ok := o.(*dns.EDNS0_COOKIE); ok { + return e + } + } + return nil +} diff --git a/resolver.go b/resolver.go index 52515f9..c20f1f5 100644 --- a/resolver.go +++ b/resolver.go @@ -305,7 +305,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error if val, ok := val.(*dns.Msg); ok { Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) res := val.Copy() - res.SetRcode(msg, val.Rcode) + SetCacheReply(res, msg, val.Rcode) return res, nil } } @@ -336,7 +336,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error return nil, fmt.Errorf("invalid answer for key: %s", key) } res := sharedMsg.Copy() - res.SetRcode(msg, sharedMsg.Rcode) + SetCacheReply(res, msg, sharedMsg.Rcode) if shared { Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) } diff --git a/resolver_test.go b/resolver_test.go index a75e748..ebcad16 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -2,6 +2,8 @@ package ctrld import ( "context" + "crypto/rand" + "encoding/hex" "net" "sync" "sync/atomic" @@ -232,6 +234,54 @@ func Test_osResolver_HotCache(t *testing.T) { } } +func Test_Edns0_CacheReply(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + + do := func() *dns.Msg { + msg := m.Copy() + msg.SetEdns0(4096, true) + cookieOption := 
new(dns.EDNS0_COOKIE) + cookieOption.Code = dns.EDNS0COOKIE + cookieOption.Cookie = generateEdns0ClientCookie() + msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption) + + answer, err := or.Resolve(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return answer + } + answer1 := do() + answer2 := do() + // Ensure the cache was hit, so we can check that edns0 cookie must be modified. + if call.Load() != 1 { + t.Fatalf("cache not hit, server was called: %d", call.Load()) + } + cookie1 := getEdns0Cookie(answer1.IsEdns0()) + cookie2 := getEdns0Cookie(answer2.IsEdns0()) + if cookie1 == nil || cookie2 == nil { + t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2) + } + if cookie1.Cookie == cookie2.Cookie { + t.Fatalf("edns0 cookie is not modified: %v", cookie1) + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -306,7 +356,32 @@ func countHandler(call *atomic.Int64) dns.HandlerFunc { return func(w dns.ResponseWriter, msg *dns.Msg) { m := new(dns.Msg) m.SetRcode(msg, dns.RcodeSuccess) + if cookie := getEdns0Cookie(msg.IsEdns0()); cookie != nil { + if m.IsEdns0() == nil { + m.SetEdns0(4096, false) + } + cookieOption := new(dns.EDNS0_COOKIE) + cookieOption.Code = dns.EDNS0COOKIE + cookieOption.Cookie = generateEdns0ServerCookie(cookie.Cookie) + m.IsEdns0().Option = append(m.IsEdns0().Option, cookieOption) + } w.WriteMsg(m) call.Add(1) } } + +func generateEdns0ClientCookie() string { + cookie := make([]byte, 8) + if _, err := rand.Read(cookie); err != nil { + panic(err) + } + return hex.EncodeToString(cookie) +} + +func generateEdns0ServerCookie(clientCookie string) string { + cookie := make([]byte, 32) + if _, err := rand.Read(cookie); err != nil { + panic(err) + } + return clientCookie + hex.EncodeToString(cookie) +} From 8dc34f8bf5fa189196ee4efa0a7169b7fa92a827 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 3 Jun 2025 18:39:07 +0700 Subject: [PATCH 5/9] 
internal/net: improve IPv6 support detection with multiple common ports Changed the IPv6 support detection to try multiple common ports (HTTP/HTTPS) instead of just testing against a DNS port. The function now returns both the IPv6 support status and the successful port that confirmed the connectivity. This makes the IPv6 detection more reliable by not depending solely on DNS port availability. Previously, the function only tested connectivity to a DNS port (53) over IPv6. Now it tries to connect to commonly available ports like HTTP (80) and HTTPS (443) until it finds a working one, making the detection more robust in environments where certain ports might be blocked. --- internal/net/net.go | 38 ++++++++++++++++++++++++++++++-------- internal/net/net_test.go | 7 ++++++- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/internal/net/net.go b/internal/net/net.go index d5bd75e..f4b5586 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -17,10 +17,17 @@ import ( ) const ( - v4BootstrapDNS = "76.76.2.22:53" - v6BootstrapDNS = "[2606:1a40::22]:53" + v4BootstrapDNS = "76.76.2.22:53" + v6BootstrapDNS = "[2606:1a40::22]:53" + v6BootstrapIP = "2606:1a40::22" + defaultHTTPSPort = "443" + defaultHTTPPort = "80" + defaultDNSPort = "53" + probeStackTimeout = 2 * time.Second ) +var commonIPv6Ports = []string{defaultHTTPSPort, defaultHTTPPort, defaultDNSPort} + var Dialer = &net.Dialer{ Resolver: &net.Resolver{ PreferGo: true, @@ -33,8 +40,6 @@ var Dialer = &net.Dialer{ }, } -const probeStackTimeout = 2 * time.Second - var probeStackDialer = &net.Dialer{ Resolver: Dialer.Resolver, Timeout: probeStackTimeout, @@ -50,12 +55,28 @@ func init() { stackOnce.Store(new(sync.Once)) } -func supportIPv6(ctx context.Context) bool { - c, err := probeStackDialer.DialContext(ctx, "tcp6", v6BootstrapDNS) +// supportIPv6 checks for IPv6 connectivity by attempting to connect to predefined ports +// on a specific IPv6 address. 
+// Returns a boolean indicating if IPv6 is supported and the port on which the connection succeeded. +// If no connection is successful, returns false and an empty string. +func supportIPv6(ctx context.Context) (supported bool, successPort string) { + for _, port := range commonIPv6Ports { + if canConnectToIPv6Port(ctx, port) { + return true, string(port) + } + } + return false, "" +} + +// canConnectToIPv6Port attempts to establish a TCP connection to the specified port +// using IPv6. Returns true if the connection was successful. +func canConnectToIPv6Port(ctx context.Context, port string) bool { + address := net.JoinHostPort(v6BootstrapIP, port) + conn, err := probeStackDialer.DialContext(ctx, "tcp6", address) if err != nil { return false } - c.Close() + _ = conn.Close() return true } @@ -110,7 +131,8 @@ func SupportsIPv6ListenLocal() bool { // IPv6Available is like SupportsIPv6, but always do the check without caching. func IPv6Available(ctx context.Context) bool { - return supportIPv6(ctx) + hasV6, _ := supportIPv6(ctx) + return hasV6 } // IsIPv6 checks if the provided IP is v6. diff --git a/internal/net/net_test.go b/internal/net/net_test.go index d28dbed..7df3e09 100644 --- a/internal/net/net_test.go +++ b/internal/net/net_test.go @@ -12,7 +12,12 @@ func TestProbeStackTimeout(t *testing.T) { go func() { defer close(done) close(started) - supportIPv6(context.Background()) + hasV6, port := supportIPv6(context.Background()) + if hasV6 { + t.Logf("connect to port %s using ipv6: %v", port, hasV6) + } else { + t.Log("ipv6 is not available") + } }() <-started From 628c4302aa4a140737265098c1f76ea24793c176 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 4 Jun 2025 17:40:27 +0700 Subject: [PATCH 6/9] cmd/cli: preserve search domains when reverting resolv.conf Fixes search domains not being preserved when the resolv.conf file is reverted to its previous state. 
This ensures that important domain search configuration is maintained during DNS configuration changes. The search domains handling was missing in setResolvConf function, which is responsible for restoring DNS settings. --- cmd/cli/prog.go | 7 +++++++ cmd/cli/prog_linux.go | 5 +---- cmd/cli/resolvconf_not_darwin_unix.go | 16 +++++++++++++--- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 089bfd0..dd8de9f 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -70,10 +70,17 @@ func ControlSocketName() string { } } +// logf is a function variable used for logging formatted debug messages with optional arguments. +// This is used only when creating a new DNS OS configurator. var logf = func(format string, args ...any) { mainLog.Load().Debug().Msgf(format, args...) } +// noopLogf is like logf but discards formatted log messages and arguments without any processing. +// +//lint:ignore U1000 use in newLoopbackOSConfigurator +var noopLogf = func(format string, args ...any) {} + var svcConfig = &service.Config{ Name: ctrldServiceName, DisplayName: "Control-D Helper Service", diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index cc0046b..2e5c7c7 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -9,15 +9,12 @@ import ( "strings" "github.com/kardianos/service" - "tailscale.com/control/controlknobs" - "tailscale.com/health" - "github.com/Control-D-Inc/ctrld/internal/dns" "github.com/Control-D-Inc/ctrld/internal/router" ) func init() { - if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo"); err == nil { + if r, err := newLoopbackOSConfigurator(); err == nil { useSystemdResolved = r.Mode() == "systemd-resolved" } // Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911 diff --git a/cmd/cli/resolvconf_not_darwin_unix.go b/cmd/cli/resolvconf_not_darwin_unix.go index 7181e95..af33572 100644 --- 
a/cmd/cli/resolvconf_not_darwin_unix.go +++ b/cmd/cli/resolvconf_not_darwin_unix.go @@ -13,9 +13,9 @@ import ( "github.com/Control-D-Inc/ctrld/internal/dns" ) -// setResolvConf sets the content of resolv.conf file using the given nameservers list. +// setResolvConf sets the content of the resolv.conf file using the given nameservers list. func setResolvConf(iface *net.Interface, ns []netip.Addr) error { - r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter. + r, err := newLoopbackOSConfigurator() if err != nil { return err } @@ -24,12 +24,17 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error { Nameservers: ns, SearchDomains: []dnsname.FQDN{}, } + if sds, err := searchDomains(); err == nil { + oc.SearchDomains = sds + } else { + mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file") + } return r.SetDNS(oc) } // shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator. func shouldWatchResolvconf() bool { - r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter. + r, err := newLoopbackOSConfigurator() if err != nil { return false } @@ -40,3 +45,8 @@ func shouldWatchResolvconf() bool { return false } } + +// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface. 
+func newLoopbackOSConfigurator() (dns.OSConfigurator, error) { + return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo") +} From a20fbf95de48dcee637980f080d619b2f73c8fa0 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 6 Jun 2025 20:19:44 +0700 Subject: [PATCH 7/9] all: enhanced TLS certificate verification error messages Added more descriptive error messages for TLS certificate verification failures across DoH, DoT, DoQ, and DoH3 protocols. The error messages now include: - Certificate subject information - Issuer organization details - Common name of the certificate This helps users and developers better understand certificate validation failures by providing specific details about the untrusted certificate, rather than just a generic "unknown authority" message. Example error message change: Before: "certificate signed by unknown authority" After: "certificate signed by unknown authority: TestCA, TestOrg, TestIssuerOrg" --- doh.go | 51 +++++++++++ doh_test.go | 243 ++++++++++++++++++++++++++++++++++++++++++++++++++++ doq.go | 2 +- doq_test.go | 223 +++++++++++++++++++++++++++++++++++++++++++++++ dot.go | 3 +- 5 files changed, 519 insertions(+), 3 deletions(-) create mode 100644 doq_test.go diff --git a/doh.go b/doh.go index 73b2764..3459cb8 100644 --- a/doh.go +++ b/doh.go @@ -2,6 +2,7 @@ package ctrld import ( "context" + "crypto/tls" "encoding/base64" "errors" "fmt" @@ -120,6 +121,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro resp, err = c.Do(req.Clone(retryCtx)) } if err != nil { + err = wrapUrlError(err) if r.isDoH3 { if closer, ok := c.Transport.(io.Closer); ok { closer.Close() @@ -208,3 +210,52 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header { } return header } + +// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer. 
+// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output. +// If no certificate-related information is available, it simply returns the original error unmodified. +func wrapCertificateVerificationError(err error) error { + var tlsErr *tls.CertificateVerificationError + if errors.As(err, &tlsErr) { + if len(tlsErr.UnverifiedCertificates) > 0 { + cert := tlsErr.UnverifiedCertificates[0] + // Extract a more user-friendly issuer name + var issuer string + var organization string + if len(cert.Issuer.Organization) > 0 { + organization = cert.Issuer.Organization[0] + issuer = organization + } else if cert.Issuer.CommonName != "" { + issuer = cert.Issuer.CommonName + } else { + issuer = cert.Issuer.String() + } + + // Get the organization from the subject field as well + if len(cert.Subject.Organization) > 0 { + organization = cert.Subject.Organization[0] + } + + // Extract the subject information + subjectCN := cert.Subject.CommonName + if subjectCN == "" && len(cert.Subject.Organization) > 0 { + subjectCN = cert.Subject.Organization[0] + } + return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer) + } + } + return err +} + +// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context. 
+func wrapUrlError(err error) error { + var urlErr *url.Error + if errors.As(err, &urlErr) { + var tlsErr *tls.CertificateVerificationError + if errors.As(urlErr.Err, &tlsErr) { + urlErr.Err = wrapCertificateVerificationError(tlsErr) + return urlErr + } + } + return err +} diff --git a/doh_test.go b/doh_test.go index 8d3e011..92fa79f 100644 --- a/doh_test.go +++ b/doh_test.go @@ -1,8 +1,22 @@ package ctrld import ( + "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "net" + "net/http" + "net/http/httptest" + "net/url" "runtime" + "strings" "testing" + "time" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go/http3" ) func Test_dohOsHeaderValue(t *testing.T) { @@ -21,3 +35,232 @@ func Test_dohOsHeaderValue(t *testing.T) { t.Fatalf("missing decoding value for: %q", runtime.GOOS) } } + +func Test_wrapUrlError(t *testing.T) { + tests := []struct { + name string + err error + wantErr string + }{ + { + name: "No wrapping for non-URL errors", + err: errors.New("plain error"), + wantErr: "plain error", + }, + { + name: "URL error without TLS error", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: errors.New("underlying error"), + }, + wantErr: "Get \"https://example.com\": underlying error", + }, + { + name: "TLS error with missing unverified certificate data", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: &tls.CertificateVerificationError{ + UnverifiedCertificates: nil, + Err: &x509.UnknownAuthorityError{}, + }, + }, + wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority`, + }, + { + name: "TLS error with valid certificate data", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: &tls.CertificateVerificationError{ + UnverifiedCertificates: []*x509.Certificate{ + { + Subject: pkix.Name{ + CommonName: "BadSubjectCN", + Organization: []string{"BadSubjectOrg"}, + }, + Issuer: pkix.Name{ + CommonName: "BadIssuerCN", + 
Organization: []string{"BadIssuerOrg"}, + }, + }, + }, + Err: &x509.UnknownAuthorityError{}, + }, + }, + wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority: BadSubjectCN, BadSubjectOrg, BadIssuerOrg`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := wrapUrlError(tt.err) + if gotErr.Error() != tt.wantErr { + t.Errorf("wrapCertificateVerificationError() error = %v, want %v", gotErr, tt.wantErr) + } + }) + } +} + +func Test_ClientCertificateVerificationError(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/dns-message") + }) + tlsServer, cert := testTLSServer(t, handler) + tlsServerUrl, err := url.Parse(tlsServer.URL) + if err != nil { + t.Fatal(err) + } + quicServer := newTestQUICServer(t) + http3Server := newTestHTTP3Server(t, handler) + + tests := []struct { + name string + uc *UpstreamConfig + }{ + { + "doh", + &UpstreamConfig{ + Name: "doh", + Type: ResolverTypeDOH, + Endpoint: tlsServer.URL, + Timeout: 1000, + }, + }, + { + "doh3", + &UpstreamConfig{ + Name: "doh3", + Type: ResolverTypeDOH3, + Endpoint: http3Server.addr, + Timeout: 5000, + }, + }, + { + "doq", + &UpstreamConfig{ + Name: "doq", + Type: ResolverTypeDOQ, + Endpoint: quicServer.addr, + Timeout: 5000, + }, + }, + { + "dot", + &UpstreamConfig{ + Name: "dot", + Type: ResolverTypeDOT, + Endpoint: net.JoinHostPort(tlsServerUrl.Hostname(), tlsServerUrl.Port()), + Timeout: 1000, + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tc.uc.Init() + tc.uc.SetupBootstrapIP() + r, err := NewResolver(tc.uc) + if err != nil { + t.Fatal(err) + } + msg := new(dns.Msg) + msg.SetQuestion("verify.controld.com.", dns.TypeA) + msg.RecursionDesired = true + _, err = r.Resolve(context.Background(), msg) + // Verify the error contains the expected certificate information + 
if err == nil { + t.Fatal("expected certificate verification error, got nil") + } + + // You can check the error contains information about the test certificate + if !strings.Contains(err.Error(), cert.Issuer.CommonName) { + t.Fatalf("error should contain issuer information %q, got: %v", cert.Issuer.CommonName, err) + } + }) + } +} + +// testTLSServer creates an HTTPS test server with a self-signed certificate +// returns the server and its certificate for verification testing +// testTLSServer creates an HTTPS test server with a self-signed certificate +func testTLSServer(t *testing.T, handler http.Handler) (*httptest.Server, *x509.Certificate) { + t.Helper() + + testCert := generateTestCertificate(t) + + // Create a test server + server := httptest.NewUnstartedServer(handler) + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + } + server.StartTLS() + + // Add cleanup + t.Cleanup(server.Close) + + return server, testCert.cert +} + +// testHTTP3Server represents a structure for an HTTP/3 test server with its server instance, TLS certificate, and address. +type testHTTP3Server struct { + server *http3.Server + cert *x509.Certificate + addr string +} + +// newTestHTTP3Server creates and starts a test HTTP/3 server with a given handler and returns the server instance. 
+func newTestHTTP3Server(t *testing.T, handler http.Handler) *testHTTP3Server { + t.Helper() + + testCert := generateTestCertificate(t) + + // First create a listener to get the actual port + udpAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + t.Fatalf("failed to create UDP listener: %v", err) + } + + // Get the actual address + actualAddr := udpConn.LocalAddr().String() + + // Create TLS config + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + NextProtos: []string{"h3"}, // HTTP/3 protocol identifier + } + + // Create HTTP/3 server + server := &http3.Server{ + Handler: handler, + TLSConfig: tlsConfig, + } + + // Start the server with the existing UDP connection + go func() { + if err := server.Serve(udpConn); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Logf("HTTP/3 server error: %v", err) + } + }() + + h3Server := &testHTTP3Server{ + server: server, + cert: testCert.cert, + addr: actualAddr, + } + + // Add cleanup + t.Cleanup(func() { + server.Close() + udpConn.Close() + }) + + // Wait a bit for the server to be ready + time.Sleep(100 * time.Millisecond) + + return h3Server +} diff --git a/doq.go b/doq.go index 3c3f9e8..0903411 100644 --- a/doq.go +++ b/doq.go @@ -43,7 +43,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls. 
continue } if err != nil { - return nil, err + return nil, wrapCertificateVerificationError(err) } return answer, nil } diff --git a/doq_test.go b/doq_test.go new file mode 100644 index 0000000..430a22a --- /dev/null +++ b/doq_test.go @@ -0,0 +1,223 @@ +// test_helpers.go +package ctrld + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "strings" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go" +) + +// testCertificate represents a test certificate with its components +type testCertificate struct { + cert *x509.Certificate + tlsCert tls.Certificate + template *x509.Certificate +} + +// generateTestCertificate creates a self-signed certificate for testing +func generateTestCertificate(t *testing.T) *testCertificate { + t.Helper() + + // Generate private key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate private key: %v", err) + } + + // Create certificate template + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + CommonName: "Test CA", + }, + Issuer: pkix.Name{ + Organization: []string{"Test Issuer Org"}, + CommonName: "Test Issuer CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + } + + // Create certificate + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + t.Fatalf("failed to parse certificate: %v", err) + } + + // Create TLS certificate + tlsCert := 
tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: privateKey, + } + + return &testCertificate{ + cert: cert, + tlsCert: tlsCert, + template: template, + } +} + +// testQUICServer is a structure representing a test QUIC server for handling connections and streams. +// listener is the QUIC listener used to accept incoming connections. +// cert is the x509 certificate used by the server for authentication. +// addr is the address on which the test server is running. +type testQUICServer struct { + listener *quic.Listener + cert *x509.Certificate + addr string +} + +// newTestQUICServer creates and initializes a test QUIC server with TLS configuration and starts accepting connections. +func newTestQUICServer(t *testing.T) *testQUICServer { + t.Helper() + + testCert := generateTestCertificate(t) + + // Create TLS config + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + NextProtos: []string{"doq"}, + } + + // Create QUIC listener + listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil) + if err != nil { + t.Fatalf("failed to create QUIC listener: %v", err) + } + + server := &testQUICServer{ + listener: listener, + cert: testCert.cert, + addr: listener.Addr().String(), + } + + // Start handling connections + go server.serve(t) + + // Add cleanup + t.Cleanup(func() { + listener.Close() + }) + + return server +} + +// serve handles incoming connections on the QUIC listener and delegates them to connection handlers in separate goroutines. 
+func (s *testQUICServer) serve(t *testing.T) { + for { + conn, err := s.listener.Accept(context.Background()) + if err != nil { + // Check if the error is due to the listener being closed + if strings.Contains(err.Error(), "server closed") { + return + } + t.Logf("failed to accept connection: %v", err) + continue + } + + go s.handleConnection(t, conn) + } +} + +// handleConnection manages an individual QUIC connection by accepting and handling incoming streams in separate goroutines. +func (s *testQUICServer) handleConnection(t *testing.T, conn quic.Connection) { + for { + stream, err := conn.AcceptStream(context.Background()) + if err != nil { + return + } + + go s.handleStream(t, stream) + } +} + +// handleStream processes a single QUIC stream, reads DNS messages, generates a response, and sends it back to the client. +func (s *testQUICServer) handleStream(t *testing.T, stream quic.Stream) { + defer stream.Close() + + // Read length (2 bytes) + lenBuf := make([]byte, 2) + _, err := stream.Read(lenBuf) + if err != nil { + t.Logf("failed to read message length: %v", err) + return + } + msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1]) + + // Read message + msgBuf := make([]byte, msgLen) + _, err = stream.Read(msgBuf) + if err != nil { + t.Logf("failed to read message: %v", err) + return + } + + // Parse DNS message + msg := new(dns.Msg) + if err := msg.Unpack(msgBuf); err != nil { + t.Logf("failed to unpack DNS message: %v", err) + return + } + + // Create response + response := new(dns.Msg) + response.SetReply(msg) + response.Authoritative = true + + // Add a test answer + if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA { + response.Answer = append(response.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP("192.0.2.1"), // TEST-NET-1 address + }) + } + + // Pack response + respBytes, err := response.Pack() + if err != nil { + t.Logf("failed to pack 
response: %v", err) + return + } + + // Write length + respLen := uint16(len(respBytes)) + _, err = stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)}) + if err != nil { + t.Logf("failed to write response length: %v", err) + return + } + + // Write response + _, err = stream.Write(respBytes) + if err != nil { + t.Logf("failed to write response: %v", err) + return + } +} diff --git a/dot.go b/dot.go index 67d1ff8..295134c 100644 --- a/dot.go +++ b/dot.go @@ -23,7 +23,6 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - tcpNet, _ := r.uc.netForDNSType(dnsTyp) dnsClient := &dns.Client{ Net: tcpNet, @@ -39,5 +38,5 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro } answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint) - return answer, err + return answer, wrapCertificateVerificationError(err) } From 7cea5305e1377b8e26e14deccaf789c334dcd237 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 10 Jun 2025 19:13:46 +0700 Subject: [PATCH 8/9] all: fix a regression causing invalid reloading timeout In v1.4.3, ControlD bootstrap DNS is used again for the bootstrapping process. When this happens, the default system nameservers are retrieved first, then ControlD DNS is used if none are available. However, retrieving the default system nameservers may take longer than the reload command timeout, causing an invalid error message to be printed. To fix this, ensure the default system nameservers are retrieved only once. --- resolver.go | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/resolver.go b/resolver.go index c20f1f5..27c0108 100644 --- a/resolver.go +++ b/resolver.go @@ -542,11 +542,26 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err // LookupIP looks up domain using current system nameservers settings.
// It returns a slice of that host's IPv4 and IPv6 addresses. func LookupIP(domain string) []string { - return lookupIP(domain, -1, defaultNameservers()) + nss := initDefaultOsResolver() + return lookupIP(domain, -1, nss) +} + +// initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet. +// It returns the combined list of LAN and public nameservers currently held by the resolver. +func initDefaultOsResolver() []string { + resolverMutex.Lock() + defer resolverMutex.Unlock() + if or == nil { + ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers") + or = newResolverWithNameserver(defaultNameservers()) + } + nss := *or.lanServers.Load() + nss = append(nss, *or.publicServers.Load()...) + return nss } // lookupIP looks up domain with given timeout and bootstrapDNS. -// If timeout is negative, default timeout 2000 ms will be used. +// If the timeout is negative, default timeout 2000 ms will be used. // It returns nil if bootstrapDNS is nil or empty. func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) { if net.ParseIP(domain) != nil { @@ -650,13 +665,7 @@ func NewBootstrapResolver(servers ...string) Resolver { // // This is useful for doing PTR lookup in LAN network. 
func NewPrivateResolver() Resolver { - resolverMutex.Lock() - if or == nil { - ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver in NewPrivateResolver") - or = newResolverWithNameserver(defaultNameservers()) - } - nss := *or.lanServers.Load() - resolverMutex.Unlock() + nss := initDefaultOsResolver() resolveConfNss := currentNameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() n := 0 From c4efa1ab97892359feba6eca388e14c909371170 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 11 Jun 2025 22:23:24 +0700 Subject: [PATCH 9/9] Initializing default os resolver during upstream bootstrap Calling defaultNameservers on every retry may block the whole bootstrap process if there are no valid DNS servers available, so initialize the default OS resolver once before the retry loop. --- config.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index fdea19f..96f6686 100644 --- a/config.go +++ b/config.go @@ -437,8 +437,9 @@ func (uc *UpstreamConfig) UID() string { func (uc *UpstreamConfig) SetupBootstrapIP() { b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second) isControlD := uc.IsControlD() + nss := initDefaultOsResolver() for { - uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, defaultNameservers()) + uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss) // For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses, // filtering them out here to prevent weird behavior. if isControlD {