diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 31e1fcb..b99c48f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1216,13 +1216,18 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti // For Windows server with local Dns server running, we can only try on random local IP. hasLocalDnsServer := hasLocalDnsServerRunning() notRouter := router.Name() == "" + isDesktop := ctrld.IsDesktopPlatform() for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { listener.IP = "0.0.0.0" - if hasLocalDnsServer { - // Windows Server lies to us that we could listen on 0.0.0.0:53 - // even there's a process already done that, stick to local IP only. + // Windows Server lies to us that we could listen on 0.0.0.0:53 + // even there's a process already done that, stick to local IP only. + // + // For desktop clients, also stick the listener to the local IP only. + // Listening on 0.0.0.0 would expose it to the entire local network, potentially + // creating security vulnerabilities (such as DNS amplification or abusing). + if hasLocalDnsServer || isDesktop { listener.IP = "127.0.0.1" } lcc[n].IP = true diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 6a214e5..33012fa 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -500,7 +500,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { continue } answer := cachedValue.Msg.Copy() - answer.SetRcode(req.msg, answer.Rcode) + ctrld.SetCacheReply(answer, req.msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") @@ -1042,8 +1042,10 @@ func (p *prog) queryFromSelf(ip string) bool { return false } +// needRFC1918Listeners reports whether ctrld need to spawn listener for RFC 1918 addresses. +// This is helpful for non-desktop platforms to receive queries from LAN clients. func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool { - return lc.IP == "127.0.0.1" && lc.Port == 53 + return lc.IP == "127.0.0.1" && lc.Port == 53 && !ctrld.IsDesktopPlatform() } // ipFromARPA parses a FQDN arpa domain and return the IP address if valid. diff --git a/cmd/cli/os_darwin.go b/cmd/cli/os_darwin.go index 841be76..4c358b0 100644 --- a/cmd/cli/os_darwin.go +++ b/cmd/cli/os_darwin.go @@ -47,6 +47,9 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e // networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1 // TODO(cuonglm): use system API func setDNS(iface *net.Interface, nameservers []string) error { + // Note that networksetup won't modify search domains settings, + // This assignment is just a placeholder to silent linter. + _ = searchDomains cmd := "networksetup" args := []string{"-setdnsservers", iface.Name} args = append(args, nameservers...) @@ -88,7 +91,7 @@ func restoreDNS(iface *net.Interface) (err error) { } func currentDNS(_ *net.Interface) []string { - return resolvconffile.NameServers("") + return resolvconffile.NameServers() } // currentStaticDNS returns the current static DNS settings of given interface. diff --git a/cmd/cli/os_freebsd.go b/cmd/cli/os_freebsd.go index 72da485..d66e4bf 100644 --- a/cmd/cli/os_freebsd.go +++ b/cmd/cli/os_freebsd.go @@ -7,6 +7,7 @@ import ( "tailscale.com/control/controlknobs" "tailscale.com/health" + "tailscale.com/util/dnsname" "github.com/Control-D-Inc/ctrld/internal/dns" "github.com/Control-D-Inc/ctrld/internal/resolvconffile" @@ -50,7 +51,17 @@ func setDNS(iface *net.Interface, nameservers []string) error { ns = append(ns, netip.MustParseAddr(nameserver)) } - if err := r.SetDNS(dns.OSConfig{Nameservers: ns}); err != nil { + osConfig := dns.OSConfig{ + Nameservers: ns, + SearchDomains: []dnsname.FQDN{}, + } + if sds, err := searchDomains(); err == nil { + osConfig.SearchDomains = sds + } else { + mainLog.Load().Debug().Err(err).Msg("failed to get search domains list") + } + + if err := r.SetDNS(osConfig); err != nil { mainLog.Load().Error().Err(err).Msg("failed to set DNS") return err } @@ -83,7 +94,7 @@ func restoreDNS(iface *net.Interface) (err error) { } func currentDNS(_ *net.Interface) []string { - return resolvconffile.NameServers("") + return resolvconffile.NameServers() } // currentStaticDNS returns the current static DNS settings of given interface. diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index e2302a3..8caad63 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -71,6 +71,11 @@ func setDNS(iface *net.Interface, nameservers []string) error { Nameservers: ns, SearchDomains: []dnsname.FQDN{}, } + if sds, err := searchDomains(); err == nil { + osConfig.SearchDomains = sds + } else { + mainLog.Load().Debug().Err(err).Msg("failed to get search domains list") + } trySystemdResolve := false if err := r.SetDNS(osConfig); err != nil { if strings.Contains(err.Error(), "Rejected send message") && @@ -196,7 +201,8 @@ func restoreDNS(iface *net.Interface) (err error) { } func currentDNS(iface *net.Interface) []string { - for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} { + resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() } + for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} { if ns := fn(iface.Name); len(ns) > 0 { return ns } diff --git a/cmd/cli/os_windows.go b/cmd/cli/os_windows.go index e1bcd9a..7ebc54a 100644 --- a/cmd/cli/os_windows.go +++ b/cmd/cli/os_windows.go @@ -100,6 +100,10 @@ func setDNS(iface *net.Interface, nameservers []string) error { } } + // Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function. + // searchDomains is still implemented for Windows just in case Windows API changes in future versions. + _ = searchDomains + if len(serversV4) == 0 && len(serversV6) == 0 { return errors.New("invalid DNS nameservers") } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 089bfd0..dd8de9f 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -70,10 +70,17 @@ func ControlSocketName() string { } } +// logf is a function variable used for logging formatted debug messages with optional arguments. +// This is used only when creating a new DNS OS configurator. var logf = func(format string, args ...any) { mainLog.Load().Debug().Msgf(format, args...) } +// noopLogf is like logf but discards formatted log messages and arguments without any processing. +// +//lint:ignore U1000 use in newLoopbackOSConfigurator +var noopLogf = func(format string, args ...any) {} + var svcConfig = &service.Config{ Name: ctrldServiceName, DisplayName: "Control-D Helper Service", diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index cc0046b..2e5c7c7 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -9,15 +9,12 @@ import ( "strings" "github.com/kardianos/service" - "tailscale.com/control/controlknobs" - "tailscale.com/health" - "github.com/Control-D-Inc/ctrld/internal/dns" "github.com/Control-D-Inc/ctrld/internal/router" ) func init() { - if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo"); err == nil { + if r, err := newLoopbackOSConfigurator(); err == nil { useSystemdResolved = r.Mode() == "systemd-resolved" } // Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911 diff --git a/cmd/cli/resolvconf_not_darwin_unix.go b/cmd/cli/resolvconf_not_darwin_unix.go index 7181e95..af33572 100644 --- a/cmd/cli/resolvconf_not_darwin_unix.go +++ b/cmd/cli/resolvconf_not_darwin_unix.go @@ -13,9 +13,9 @@ import ( "github.com/Control-D-Inc/ctrld/internal/dns" ) -// setResolvConf sets the content of resolv.conf file using the given nameservers list. +// setResolvConf sets the content of the resolv.conf file using the given nameservers list. func setResolvConf(iface *net.Interface, ns []netip.Addr) error { - r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter. + r, err := newLoopbackOSConfigurator() if err != nil { return err } @@ -24,12 +24,17 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error { Nameservers: ns, SearchDomains: []dnsname.FQDN{}, } + if sds, err := searchDomains(); err == nil { + oc.SearchDomains = sds + } else { + mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file") + } return r.SetDNS(oc) } // shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator. func shouldWatchResolvconf() bool { - r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter. + r, err := newLoopbackOSConfigurator() if err != nil { return false } @@ -40,3 +45,8 @@ func shouldWatchResolvconf() bool { return false } } + +// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface. +func newLoopbackOSConfigurator() (dns.OSConfigurator, error) { + return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo") +} diff --git a/cmd/cli/search_domains_unix.go b/cmd/cli/search_domains_unix.go new file mode 100644 index 0000000..de3998e --- /dev/null +++ b/cmd/cli/search_domains_unix.go @@ -0,0 +1,14 @@ +//go:build unix + +package cli + +import ( + "tailscale.com/util/dnsname" + + "github.com/Control-D-Inc/ctrld/internal/resolvconffile" +) + +// searchDomains returns the current search domains config. +func searchDomains() ([]dnsname.FQDN, error) { + return resolvconffile.SearchDomains() +} diff --git a/cmd/cli/search_domains_windows.go b/cmd/cli/search_domains_windows.go new file mode 100644 index 0000000..320a322 --- /dev/null +++ b/cmd/cli/search_domains_windows.go @@ -0,0 +1,43 @@ +package cli + +import ( + "fmt" + "syscall" + + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/util/dnsname" +) + +// searchDomains returns the current search domains config. +func searchDomains() ([]dnsname.FQDN, error) { + flags := winipcfg.GAAFlagIncludeGateways | + winipcfg.GAAFlagIncludePrefix + + aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags) + if err != nil { + return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err) + } + + var sds []dnsname.FQDN + for _, aa := range aas { + if aa.OperStatus != winipcfg.IfOperStatusUp { + continue + } + + // Skip if software loopback or other non-physical types + // This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows + if aa.IfType == winipcfg.IfTypeSoftwareLoopback { + continue + } + + for a := aa.FirstDNSSuffix; a != nil; a = a.Next { + d, err := dnsname.ToFQDN(a.String()) + if err != nil { + mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String()) + continue + } + sds = append(sds, d) + } + } + return sds, nil +} diff --git a/config.go b/config.go index fdea19f..96f6686 100644 --- a/config.go +++ b/config.go @@ -437,8 +437,9 @@ func (uc *UpstreamConfig) UID() string { func (uc *UpstreamConfig) SetupBootstrapIP() { b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second) isControlD := uc.IsControlD() + nss := initDefaultOsResolver() for { - uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, defaultNameservers()) + uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss) // For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses, // filtering them out here to prevent weird behavior. if isControlD { diff --git a/desktop_darwin.go b/desktop_darwin.go new file mode 100644 index 0000000..039c0fa --- /dev/null +++ b/desktop_darwin.go @@ -0,0 +1,7 @@ +package ctrld + +// IsDesktopPlatform indicates if ctrld is running on a desktop platform, +// currently defined as macOS or Windows workstation. +func IsDesktopPlatform() bool { + return true +} diff --git a/desktop_others.go b/desktop_others.go new file mode 100644 index 0000000..de486e7 --- /dev/null +++ b/desktop_others.go @@ -0,0 +1,9 @@ +//go:build !windows && !darwin + +package ctrld + +// IsDesktopPlatform indicates if ctrld is running on a desktop platform, +// currently defined as macOS or Windows workstation. +func IsDesktopPlatform() bool { + return false +} diff --git a/desktop_windows.go b/desktop_windows.go new file mode 100644 index 0000000..4e9526b --- /dev/null +++ b/desktop_windows.go @@ -0,0 +1,7 @@ +package ctrld + +// IsDesktopPlatform indicates if ctrld is running on a desktop platform, +// currently defined as macOS or Windows workstation. +func IsDesktopPlatform() bool { + return isWindowsWorkStation() +} diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..f2b71a5 --- /dev/null +++ b/dns.go @@ -0,0 +1,30 @@ +package ctrld + +import ( + "github.com/miekg/dns" +) + +// SetCacheReply extracts and stores the necessary data from the message for a cached answer. +func SetCacheReply(answer, msg *dns.Msg, code int) { + answer.SetRcode(msg, code) + cCookie := getEdns0Cookie(msg.IsEdns0()) + sCookie := getEdns0Cookie(answer.IsEdns0()) + if cCookie != nil && sCookie != nil { + // Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes. + // See https://datatracker.ietf.org/doc/html/rfc7873#section-4 + sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:] + } +} + +// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present. +func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE { + if opt == nil { + return nil + } + for _, o := range opt.Option { + if e, ok := o.(*dns.EDNS0_COOKIE); ok { + return e + } + } + return nil +} diff --git a/doh.go b/doh.go index 73b2764..3459cb8 100644 --- a/doh.go +++ b/doh.go @@ -2,6 +2,7 @@ package ctrld import ( "context" + "crypto/tls" "encoding/base64" "errors" "fmt" @@ -120,6 +121,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro resp, err = c.Do(req.Clone(retryCtx)) } if err != nil { + err = wrapUrlError(err) if r.isDoH3 { if closer, ok := c.Transport.(io.Closer); ok { closer.Close() @@ -208,3 +210,52 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header { } return header } + +// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer. +// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output. +// If no certificate-related information is available, it simply returns the original error unmodified. +func wrapCertificateVerificationError(err error) error { + var tlsErr *tls.CertificateVerificationError + if errors.As(err, &tlsErr) { + if len(tlsErr.UnverifiedCertificates) > 0 { + cert := tlsErr.UnverifiedCertificates[0] + // Extract a more user-friendly issuer name + var issuer string + var organization string + if len(cert.Issuer.Organization) > 0 { + organization = cert.Issuer.Organization[0] + issuer = organization + } else if cert.Issuer.CommonName != "" { + issuer = cert.Issuer.CommonName + } else { + issuer = cert.Issuer.String() + } + + // Get the organization from the subject field as well + if len(cert.Subject.Organization) > 0 { + organization = cert.Subject.Organization[0] + } + + // Extract the subject information + subjectCN := cert.Subject.CommonName + if subjectCN == "" && len(cert.Subject.Organization) > 0 { + subjectCN = cert.Subject.Organization[0] + } + return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer) + } + } + return err +} + +// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context. +func wrapUrlError(err error) error { + var urlErr *url.Error + if errors.As(err, &urlErr) { + var tlsErr *tls.CertificateVerificationError + if errors.As(urlErr.Err, &tlsErr) { + urlErr.Err = wrapCertificateVerificationError(tlsErr) + return urlErr + } + } + return err +} diff --git a/doh_test.go b/doh_test.go index 8d3e011..92fa79f 100644 --- a/doh_test.go +++ b/doh_test.go @@ -1,8 +1,22 @@ package ctrld import ( + "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "net" + "net/http" + "net/http/httptest" + "net/url" "runtime" + "strings" "testing" + "time" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go/http3" ) func Test_dohOsHeaderValue(t *testing.T) { @@ -21,3 +35,232 @@ func Test_dohOsHeaderValue(t *testing.T) { t.Fatalf("missing decoding value for: %q", runtime.GOOS) } } + +func Test_wrapUrlError(t *testing.T) { + tests := []struct { + name string + err error + wantErr string + }{ + { + name: "No wrapping for non-URL errors", + err: errors.New("plain error"), + wantErr: "plain error", + }, + { + name: "URL error without TLS error", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: errors.New("underlying error"), + }, + wantErr: "Get \"https://example.com\": underlying error", + }, + { + name: "TLS error with missing unverified certificate data", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: &tls.CertificateVerificationError{ + UnverifiedCertificates: nil, + Err: &x509.UnknownAuthorityError{}, + }, + }, + wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority`, + }, + { + name: "TLS error with valid certificate data", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: &tls.CertificateVerificationError{ + UnverifiedCertificates: []*x509.Certificate{ + { + Subject: pkix.Name{ + CommonName: "BadSubjectCN", + Organization: []string{"BadSubjectOrg"}, + }, + Issuer: pkix.Name{ + CommonName: "BadIssuerCN", + Organization: []string{"BadIssuerOrg"}, + }, + }, + }, + Err: &x509.UnknownAuthorityError{}, + }, + }, + wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority: BadSubjectCN, BadSubjectOrg, BadIssuerOrg`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := wrapUrlError(tt.err) + if gotErr.Error() != tt.wantErr { + t.Errorf("wrapCertificateVerificationError() error = %v, want %v", gotErr, tt.wantErr) + } + }) + } +} + +func Test_ClientCertificateVerificationError(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/dns-message") + }) + tlsServer, cert := testTLSServer(t, handler) + tlsServerUrl, err := url.Parse(tlsServer.URL) + if err != nil { + t.Fatal(err) + } + quicServer := newTestQUICServer(t) + http3Server := newTestHTTP3Server(t, handler) + + tests := []struct { + name string + uc *UpstreamConfig + }{ + { + "doh", + &UpstreamConfig{ + Name: "doh", + Type: ResolverTypeDOH, + Endpoint: tlsServer.URL, + Timeout: 1000, + }, + }, + { + "doh3", + &UpstreamConfig{ + Name: "doh3", + Type: ResolverTypeDOH3, + Endpoint: http3Server.addr, + Timeout: 5000, + }, + }, + { + "doq", + &UpstreamConfig{ + Name: "doq", + Type: ResolverTypeDOQ, + Endpoint: quicServer.addr, + Timeout: 5000, + }, + }, + { + "dot", + &UpstreamConfig{ + Name: "dot", + Type: ResolverTypeDOT, + Endpoint: net.JoinHostPort(tlsServerUrl.Hostname(), tlsServerUrl.Port()), + Timeout: 1000, + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tc.uc.Init() + tc.uc.SetupBootstrapIP() + r, err := NewResolver(tc.uc) + if err != nil { + t.Fatal(err) + } + msg := new(dns.Msg) + msg.SetQuestion("verify.controld.com.", dns.TypeA) + msg.RecursionDesired = true + _, err = r.Resolve(context.Background(), msg) + // Verify the error contains the expected certificate information + if err == nil { + t.Fatal("expected certificate verification error, got nil") + } + + // You can check the error contains information about the test certificate + if !strings.Contains(err.Error(), cert.Issuer.CommonName) { + t.Fatalf("error should contain issuer information %q, got: %v", cert.Issuer.CommonName, err) + } + }) + } +} + +// testTLSServer creates an HTTPS test server with a self-signed certificate +// returns the server and its certificate for verification testing +// testTLSServer creates an HTTPS test server with a self-signed certificate +func testTLSServer(t *testing.T, handler http.Handler) (*httptest.Server, *x509.Certificate) { + t.Helper() + + testCert := generateTestCertificate(t) + + // Create a test server + server := httptest.NewUnstartedServer(handler) + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + } + server.StartTLS() + + // Add cleanup + t.Cleanup(server.Close) + + return server, testCert.cert +} + +// testHTTP3Server represents a structure for an HTTP/3 test server with its server instance, TLS certificate, and address. +type testHTTP3Server struct { + server *http3.Server + cert *x509.Certificate + addr string +} + +// newTestHTTP3Server creates and starts a test HTTP/3 server with a given handler and returns the server instance. +func newTestHTTP3Server(t *testing.T, handler http.Handler) *testHTTP3Server { + t.Helper() + + testCert := generateTestCertificate(t) + + // First create a listener to get the actual port + udpAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + t.Fatalf("failed to create UDP listener: %v", err) + } + + // Get the actual address + actualAddr := udpConn.LocalAddr().String() + + // Create TLS config + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + NextProtos: []string{"h3"}, // HTTP/3 protocol identifier + } + + // Create HTTP/3 server + server := &http3.Server{ + Handler: handler, + TLSConfig: tlsConfig, + } + + // Start the server with the existing UDP connection + go func() { + if err := server.Serve(udpConn); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Logf("HTTP/3 server error: %v", err) + } + }() + + h3Server := &testHTTP3Server{ + server: server, + cert: testCert.cert, + addr: actualAddr, + } + + // Add cleanup + t.Cleanup(func() { + server.Close() + udpConn.Close() + }) + + // Wait a bit for the server to be ready + time.Sleep(100 * time.Millisecond) + + return h3Server +} diff --git a/doq.go b/doq.go index 3c3f9e8..0903411 100644 --- a/doq.go +++ b/doq.go @@ -43,7 +43,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls. continue } if err != nil { - return nil, err + return nil, wrapCertificateVerificationError(err) } return answer, nil } diff --git a/doq_test.go b/doq_test.go new file mode 100644 index 0000000..430a22a --- /dev/null +++ b/doq_test.go @@ -0,0 +1,223 @@ +// test_helpers.go +package ctrld + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "strings" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go" +) + +// testCertificate represents a test certificate with its components +type testCertificate struct { + cert *x509.Certificate + tlsCert tls.Certificate + template *x509.Certificate +} + +// generateTestCertificate creates a self-signed certificate for testing +func generateTestCertificate(t *testing.T) *testCertificate { + t.Helper() + + // Generate private key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate private key: %v", err) + } + + // Create certificate template + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + CommonName: "Test CA", + }, + Issuer: pkix.Name{ + Organization: []string{"Test Issuer Org"}, + CommonName: "Test Issuer CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + } + + // Create certificate + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + t.Fatalf("failed to parse certificate: %v", err) + } + + // Create TLS certificate + tlsCert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: privateKey, + } + + return &testCertificate{ + cert: cert, + tlsCert: tlsCert, + template: template, + } +} + +// testQUICServer is a structure representing a test QUIC server for handling connections and streams. +// listener is the QUIC listener used to accept incoming connections. +// cert is the x509 certificate used by the server for authentication. +// addr is the address on which the test server is running. +type testQUICServer struct { + listener *quic.Listener + cert *x509.Certificate + addr string +} + +// newTestQUICServer creates and initializes a test QUIC server with TLS configuration and starts accepting connections. +func newTestQUICServer(t *testing.T) *testQUICServer { + t.Helper() + + testCert := generateTestCertificate(t) + + // Create TLS config + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + NextProtos: []string{"doq"}, + } + + // Create QUIC listener + listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil) + if err != nil { + t.Fatalf("failed to create QUIC listener: %v", err) + } + + server := &testQUICServer{ + listener: listener, + cert: testCert.cert, + addr: listener.Addr().String(), + } + + // Start handling connections + go server.serve(t) + + // Add cleanup + t.Cleanup(func() { + listener.Close() + }) + + return server +} + +// serve handles incoming connections on the QUIC listener and delegates them to connection handlers in separate goroutines. +func (s *testQUICServer) serve(t *testing.T) { + for { + conn, err := s.listener.Accept(context.Background()) + if err != nil { + // Check if the error is due to the listener being closed + if strings.Contains(err.Error(), "server closed") { + return + } + t.Logf("failed to accept connection: %v", err) + continue + } + + go s.handleConnection(t, conn) + } +} + +// handleConnection manages an individual QUIC connection by accepting and handling incoming streams in separate goroutines. +func (s *testQUICServer) handleConnection(t *testing.T, conn quic.Connection) { + for { + stream, err := conn.AcceptStream(context.Background()) + if err != nil { + return + } + + go s.handleStream(t, stream) + } +} + +// handleStream processes a single QUIC stream, reads DNS messages, generates a response, and sends it back to the client. +func (s *testQUICServer) handleStream(t *testing.T, stream quic.Stream) { + defer stream.Close() + + // Read length (2 bytes) + lenBuf := make([]byte, 2) + _, err := stream.Read(lenBuf) + if err != nil { + t.Logf("failed to read message length: %v", err) + return + } + msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1]) + + // Read message + msgBuf := make([]byte, msgLen) + _, err = stream.Read(msgBuf) + if err != nil { + t.Logf("failed to read message: %v", err) + return + } + + // Parse DNS message + msg := new(dns.Msg) + if err := msg.Unpack(msgBuf); err != nil { + t.Logf("failed to unpack DNS message: %v", err) + return + } + + // Create response + response := new(dns.Msg) + response.SetReply(msg) + response.Authoritative = true + + // Add a test answer + if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA { + response.Answer = append(response.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP("192.0.2.1"), // TEST-NET-1 address + }) + } + + // Pack response + respBytes, err := response.Pack() + if err != nil { + t.Logf("failed to pack response: %v", err) + return + } + + // Write length + respLen := uint16(len(respBytes)) + _, err = stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)}) + if err != nil { + t.Logf("failed to write response length: %v", err) + return + } + + // Write response + _, err = stream.Write(respBytes) + if err != nil { + t.Logf("failed to write response: %v", err) + return + } +} diff --git a/dot.go b/dot.go index 67d1ff8..295134c 100644 --- a/dot.go +++ b/dot.go @@ -23,7 +23,6 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - tcpNet, _ := r.uc.netForDNSType(dnsTyp) dnsClient := &dns.Client{ Net: tcpNet, @@ -39,5 +38,5 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro } answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint) - return answer, err + return answer, wrapCertificateVerificationError(err) } diff --git a/internal/net/net.go b/internal/net/net.go index d5bd75e..f4b5586 100644 --- a/internal/net/net.go +++ b/internal/net/net.go @@ -17,10 +17,17 @@ import ( ) const ( - v4BootstrapDNS = "76.76.2.22:53" - v6BootstrapDNS = "[2606:1a40::22]:53" + v4BootstrapDNS = "76.76.2.22:53" + v6BootstrapDNS = "[2606:1a40::22]:53" + v6BootstrapIP = "2606:1a40::22" + defaultHTTPSPort = "443" + defaultHTTPPort = "80" + defaultDNSPort = "53" + probeStackTimeout = 2 * time.Second ) +var commonIPv6Ports = []string{defaultHTTPSPort, defaultHTTPPort, defaultDNSPort} + var Dialer = &net.Dialer{ Resolver: &net.Resolver{ PreferGo: true, @@ -33,8 +40,6 @@ var Dialer = &net.Dialer{ }, } -const probeStackTimeout = 2 * time.Second - var probeStackDialer = &net.Dialer{ Resolver: Dialer.Resolver, Timeout: probeStackTimeout, @@ -50,12 +55,28 @@ func init() { stackOnce.Store(new(sync.Once)) } -func supportIPv6(ctx context.Context) bool { - c, err := probeStackDialer.DialContext(ctx, "tcp6", v6BootstrapDNS) +// supportIPv6 checks for IPv6 connectivity by attempting to connect to predefined ports +// on a specific IPv6 address. +// Returns a boolean indicating if IPv6 is supported and the port on which the connection succeeded. +// If no connection is successful, returns false and an empty string. +func supportIPv6(ctx context.Context) (supported bool, successPort string) { + for _, port := range commonIPv6Ports { + if canConnectToIPv6Port(ctx, port) { + return true, string(port) + } + } + return false, "" +} + +// canConnectToIPv6Port attempts to establish a TCP connection to the specified port +// using IPv6. Returns true if the connection was successful. +func canConnectToIPv6Port(ctx context.Context, port string) bool { + address := net.JoinHostPort(v6BootstrapIP, port) + conn, err := probeStackDialer.DialContext(ctx, "tcp6", address) if err != nil { return false } - c.Close() + _ = conn.Close() return true } @@ -110,7 +131,8 @@ func SupportsIPv6ListenLocal() bool { // IPv6Available is like SupportsIPv6, but always do the check without caching. func IPv6Available(ctx context.Context) bool { - return supportIPv6(ctx) + hasV6, _ := supportIPv6(ctx) + return hasV6 } // IsIPv6 checks if the provided IP is v6. diff --git a/internal/net/net_test.go b/internal/net/net_test.go index d28dbed..7df3e09 100644 --- a/internal/net/net_test.go +++ b/internal/net/net_test.go @@ -12,7 +12,12 @@ func TestProbeStackTimeout(t *testing.T) { go func() { defer close(done) close(started) - supportIPv6(context.Background()) + hasV6, port := supportIPv6(context.Background()) + if hasV6 { + t.Logf("connect to port %s using ipv6: %v", port, hasV6) + } else { + t.Log("ipv6 is not available") + } }() <-started diff --git a/internal/resolvconffile/dns.go b/internal/resolvconffile/dns.go index 3ce0f91..0d532eb 100644 --- a/internal/resolvconffile/dns.go +++ b/internal/resolvconffile/dns.go @@ -6,6 +6,7 @@ import ( "net" "tailscale.com/net/dns/resolvconffile" + "tailscale.com/util/dnsname" ) const resolvconfPath = "/etc/resolv.conf" @@ -22,7 +23,7 @@ func NameServersWithPort() []string { return ns } -func NameServers(_ string) []string { +func NameServers() []string { c, err := resolvconffile.ParseFile(resolvconfPath) if err != nil { return nil @@ -33,3 +34,12 @@ func NameServers(_ string) []string { } return ns } + +// SearchDomains returns the current search domains config in /etc/resolv.conf file. +func SearchDomains() ([]dnsname.FQDN, error) { + c, err := resolvconffile.ParseFile(resolvconfPath) + if err != nil { + return nil, err + } + return c.SearchDomains, nil +} diff --git a/internal/resolvconffile/dns_test.go b/internal/resolvconffile/dns_test.go index ba571af..7f7a64c 100644 --- a/internal/resolvconffile/dns_test.go +++ b/internal/resolvconffile/dns_test.go @@ -9,7 +9,7 @@ import ( ) func TestNameServers(t *testing.T) { - ns := NameServers("") + ns := NameServers() require.NotNil(t, ns) t.Log(ns) } diff --git a/nameservers_unix.go b/nameservers_unix.go index d7af521..d8e6035 100644 --- a/nameservers_unix.go +++ b/nameservers_unix.go @@ -14,7 +14,7 @@ import ( // currentNameserversFromResolvconf returns the current nameservers set from /etc/resolv.conf file. func currentNameserversFromResolvconf() []string { - return resolvconffile.NameServers("") + return resolvconffile.NameServers() } // dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file. @@ -34,7 +34,7 @@ func dnsFromResolvConf() []string { time.Sleep(retryInterval) } - nss := resolvconffile.NameServers("") + nss := resolvconffile.NameServers() var localDNS []string seen := make(map[string]bool) diff --git a/resolver.go b/resolver.go index a44ddb2..27c0108 100644 --- a/resolver.go +++ b/resolver.go @@ -9,12 +9,14 @@ import ( "net/netip" "runtime" "slices" + "strings" "sync" "sync/atomic" "time" "github.com/miekg/dns" "github.com/rs/zerolog" + "golang.org/x/sync/singleflight" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" ) @@ -216,6 +218,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { type osResolver struct { lanServers atomic.Pointer[[]string] publicServers atomic.Pointer[[]string] + group *singleflight.Group + cache *sync.Map } type osResolverResult struct { @@ -273,10 +277,75 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired return dnsClient.ExchangeContext(ctx, msg, server) } +const hotCacheTTL = time.Second + // Resolve resolves DNS queries using pre-configured nameservers. -// Query is sent to all nameservers concurrently, and the first +// The Query is sent to all nameservers concurrently, and the first // success response will be returned. +// +// To guard against unexpected DoS to upstreams, multiple queries of +// the same Qtype to a domain will be shared, so there's only 1 qps +// for each upstream at any time. +// +// Further, a hot cache will be used, so repeated queries will be cached +// for a short period (currently 1 second), reducing unnecessary traffics +// sent to upstreams. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { + if len(msg.Question) == 0 { + return nil, errors.New("no question found") + } + domain := strings.TrimSuffix(msg.Question[0].Name, ".") + qtype := msg.Question[0].Qtype + + // Unique key for the singleflight group. + key := fmt.Sprintf("%s:%d:", domain, qtype) + + // Checking the cache first. + if val, ok := o.cache.Load(key); ok { + if val, ok := val.(*dns.Msg); ok { + Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype]) + res := val.Copy() + SetCacheReply(res, msg, val.Rcode) + return res, nil + } + } + + // Ensure only one DNS query is in flight for the key. + v, err, shared := o.group.Do(key, func() (interface{}, error) { + msg, err := o.resolve(ctx, msg) + if err != nil { + return nil, err + } + // If we got an answer, storing it to the hot cache for hotCacheTTL + // This prevents possible DoS to upstream, ensuring there's only 1 QPS. + o.cache.Store(key, msg) + // Depends on go runtime scheduling, the result may end up in hot cache longer + // than hotCacheTTL duration. However, this is fine since we only want to guard + // against DoS attack. The result will be cleaned from the cache eventually. + time.AfterFunc(hotCacheTTL, func() { + o.removeCache(key) + }) + return msg, nil + }) + if err != nil { + return nil, err + } + + sharedMsg, ok := v.(*dns.Msg) + if !ok { + return nil, fmt.Errorf("invalid answer for key: %s", key) + } + res := sharedMsg.Copy() + SetCacheReply(res, msg, sharedMsg.Rcode) + if shared { + Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype]) + } + + return res, nil +} + +// resolve sends the query to current nameservers. +func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { publicServers := *o.publicServers.Load() var nss []string if p := o.lanServers.Load(); p != nil { @@ -431,6 +500,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error return nil, errors.Join(errs...) } +func (o *osResolver) removeCache(key string) { + o.cache.Delete(key) +} + type legacyResolver struct { uc *UpstreamConfig } @@ -469,11 +542,26 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err // LookupIP looks up domain using current system nameservers settings. // It returns a slice of that host's IPv4 and IPv6 addresses. func LookupIP(domain string) []string { - return lookupIP(domain, -1, defaultNameservers()) + nss := initDefaultOsResolver() + return lookupIP(domain, -1, nss) +} + +// initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet. +// It returns the combined list of LAN and public nameservers currently held by the resolver. +func initDefaultOsResolver() []string { + resolverMutex.Lock() + defer resolverMutex.Unlock() + if or == nil { + ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers") + or = newResolverWithNameserver(defaultNameservers()) + } + nss := *or.lanServers.Load() + nss = append(nss, *or.publicServers.Load()...) + return nss } // lookupIP looks up domain with given timeout and bootstrapDNS. -// If timeout is negative, default timeout 2000 ms will be used. +// If the timeout is negative, default timeout 2000 ms will be used. // It returns nil if bootstrapDNS is nil or empty. func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) { if net.ParseIP(domain) != nil { @@ -577,13 +665,7 @@ func NewBootstrapResolver(servers ...string) Resolver { // // This is useful for doing PTR lookup in LAN network. func NewPrivateResolver() Resolver { - resolverMutex.Lock() - if or == nil { - ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver in NewPrivateResolver") - or = newResolverWithNameserver(defaultNameservers()) - } - nss := *or.lanServers.Load() - resolverMutex.Unlock() + nss := initDefaultOsResolver() resolveConfNss := currentNameserversFromResolvconf() localRfc1918Addrs := Rfc1918Addresses() n := 0 @@ -627,10 +709,10 @@ func NewResolverWithNameserver(nameservers []string) Resolver { // newResolverWithNameserver returns an OS resolver from given nameservers list. // The caller must ensure each server in list is formed "ip:53". func newResolverWithNameserver(nameservers []string) *osResolver { - logger := *ProxyLogger.Load() - - Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers) - r := &osResolver{} + r := &osResolver{ + group: &singleflight.Group{}, + cache: &sync.Map{}, + } var publicNss []string var lanNss []string for _, ns := range slices.Sorted(slices.Values(nameservers)) { diff --git a/resolver_test.go b/resolver_test.go index fb6831b..ebcad16 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -2,8 +2,11 @@ package ctrld import ( "context" + "crypto/rand" + "encoding/hex" "net" "sync" + "sync/atomic" "testing" "time" @@ -16,8 +19,7 @@ func Test_osResolver_Resolve(t *testing.T) { go func() { defer cancel() - resolver := &osResolver{} - resolver.publicServers.Store(&[]string{"127.0.0.127:5353"}) + resolver := newResolverWithNameserver([]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -50,8 +52,7 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) { t.Error("not a LAN query") return } - resolver := &osResolver{} - resolver.publicServers.Store(&[]string{"76.76.2.0:53"}) + resolver := newResolverWithNameserver([]string{"76.76.2.0:53"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true @@ -107,11 +108,9 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { }() // We now create an osResolver which has both a LAN and public nameserver. - resolver := &osResolver{} - // Explicitly store the LAN nameserver. - resolver.lanServers.Store(&[]string{lanAddr}) - // And store the public nameservers. - resolver.publicServers.Store(&publicNS) + nss := []string{lanAddr} + nss = append(nss, publicNS...) + resolver := newResolverWithNameserver(nss) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) @@ -139,6 +138,150 @@ func Test_osResolver_InitializationRace(t *testing.T) { wg.Wait() } +func Test_osResolver_Singleflight(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + n := 10 + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + _, err := or.Resolve(context.Background(), m) + if err != nil { + t.Error(err) + } + }() + } + wg.Wait() + + // All above queries should only make 1 call to server. + if call.Load() != 1 { + t.Fatalf("expected 1 result from singleflight lookup, got %d", call) + } +} + +func Test_osResolver_HotCache(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + + // Make 2 repeated queries to server, should hit hot cache. + for i := 0; i < 2; i++ { + if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + t.Fatal(err) + } + } + if call.Load() != 1 { + t.Fatalf("cache not hit, server was called: %d", call.Load()) + } + + timeoutChan := make(chan struct{}) + time.AfterFunc(5*time.Second, func() { + close(timeoutChan) + }) + + for { + select { + case <-timeoutChan: + t.Fatal("timed out waiting for cache cleaned") + default: + count := 0 + or.cache.Range(func(key, value interface{}) bool { + count++ + return true + }) + if count != 0 { + t.Logf("hot cache is not empty: %d elements", count) + continue + } + } + break + } + + if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { + t.Fatal(err) + } + if call.Load() != 2 { + t.Fatal("cache hit unexpectedly") + } +} + +func Test_Edns0_CacheReply(t *testing.T) { + lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on LAN address: %v", err) + } + call := &atomic.Int64{} + lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) + if err != nil { + t.Fatalf("failed to run LAN test server: %v", err) + } + defer lanServer.Shutdown() + + or := newResolverWithNameserver([]string{lanAddr}) + domain := "controld.com" + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), dns.TypeA) + m.RecursionDesired = true + + do := func() *dns.Msg { + msg := m.Copy() + msg.SetEdns0(4096, true) + cookieOption := new(dns.EDNS0_COOKIE) + cookieOption.Code = dns.EDNS0COOKIE + cookieOption.Cookie = generateEdns0ClientCookie() + msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption) + + answer, err := or.Resolve(context.Background(), msg) + if err != nil { + t.Fatal(err) + } + return answer + } + answer1 := do() + answer2 := do() + // Ensure the cache was hit, so we can check that edns0 cookie must be modified. + if call.Load() != 1 { + t.Fatalf("cache not hit, server was called: %d", call.Load()) + } + cookie1 := getEdns0Cookie(answer1.IsEdns0()) + cookie2 := getEdns0Cookie(answer2.IsEdns0()) + if cookie1 == nil || cookie2 == nil { + t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2) + } + if cookie1.Cookie == cookie2.Cookie { + t.Fatalf("edns0 cookie is not modified: %v", cookie1) + } +} + func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string @@ -208,3 +351,37 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc { w.WriteMsg(m) } } + +func countHandler(call *atomic.Int64) dns.HandlerFunc { + return func(w dns.ResponseWriter, msg *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(msg, dns.RcodeSuccess) + if cookie := getEdns0Cookie(msg.IsEdns0()); cookie != nil { + if m.IsEdns0() == nil { + m.SetEdns0(4096, false) + } + cookieOption := new(dns.EDNS0_COOKIE) + cookieOption.Code = dns.EDNS0COOKIE + cookieOption.Cookie = generateEdns0ServerCookie(cookie.Cookie) + m.IsEdns0().Option = append(m.IsEdns0().Option, cookieOption) + } + w.WriteMsg(m) + call.Add(1) + } +} + +func generateEdns0ClientCookie() string { + cookie := make([]byte, 8) + if _, err := rand.Read(cookie); err != nil { + panic(err) + } + return hex.EncodeToString(cookie) +} + +func generateEdns0ServerCookie(clientCookie string) string { + cookie := make([]byte, 32) + if _, err := rand.Read(cookie); err != nil { + panic(err) + } + return clientCookie + hex.EncodeToString(cookie) +}