mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Merge pull request #239 from Control-D-Inc/release-branch-v1.4.4
Release branch v1.4.4
This commit is contained in:
@@ -1216,13 +1216,18 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti
|
||||
// For Windows server with local Dns server running, we can only try on random local IP.
|
||||
hasLocalDnsServer := hasLocalDnsServerRunning()
|
||||
notRouter := router.Name() == ""
|
||||
isDesktop := ctrld.IsDesktopPlatform()
|
||||
for n, listener := range cfg.Listener {
|
||||
lcc[n] = &listenerConfigCheck{}
|
||||
if listener.IP == "" {
|
||||
listener.IP = "0.0.0.0"
|
||||
if hasLocalDnsServer {
|
||||
// Windows Server lies to us that we could listen on 0.0.0.0:53
|
||||
// even there's a process already done that, stick to local IP only.
|
||||
// Windows Server lies to us that we could listen on 0.0.0.0:53
|
||||
// even there's a process already done that, stick to local IP only.
|
||||
//
|
||||
// For desktop clients, also stick the listener to the local IP only.
|
||||
// Listening on 0.0.0.0 would expose it to the entire local network, potentially
|
||||
// creating security vulnerabilities (such as DNS amplification or abusing).
|
||||
if hasLocalDnsServer || isDesktop {
|
||||
listener.IP = "127.0.0.1"
|
||||
}
|
||||
lcc[n].IP = true
|
||||
|
||||
@@ -500,7 +500,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
continue
|
||||
}
|
||||
answer := cachedValue.Msg.Copy()
|
||||
answer.SetRcode(req.msg, answer.Rcode)
|
||||
ctrld.SetCacheReply(answer, req.msg, answer.Rcode)
|
||||
now := time.Now()
|
||||
if cachedValue.Expire.After(now) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
|
||||
@@ -1042,8 +1042,10 @@ func (p *prog) queryFromSelf(ip string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// needRFC1918Listeners reports whether ctrld need to spawn listener for RFC 1918 addresses.
|
||||
// This is helpful for non-desktop platforms to receive queries from LAN clients.
|
||||
func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
|
||||
return lc.IP == "127.0.0.1" && lc.Port == 53
|
||||
return lc.IP == "127.0.0.1" && lc.Port == 53 && !ctrld.IsDesktopPlatform()
|
||||
}
|
||||
|
||||
// ipFromARPA parses a FQDN arpa domain and return the IP address if valid.
|
||||
|
||||
@@ -47,6 +47,9 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
|
||||
// TODO(cuonglm): use system API
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
// Note that networksetup won't modify search domains settings,
|
||||
// This assignment is just a placeholder to silent linter.
|
||||
_ = searchDomains
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name}
|
||||
args = append(args, nameservers...)
|
||||
@@ -88,7 +91,7 @@ func restoreDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
@@ -50,7 +51,17 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
if err := r.SetDNS(dns.OSConfig{Nameservers: ns}); err != nil {
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
|
||||
}
|
||||
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to set DNS")
|
||||
return err
|
||||
}
|
||||
@@ -83,7 +94,7 @@ func restoreDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
|
||||
@@ -71,6 +71,11 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
|
||||
}
|
||||
trySystemdResolve := false
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
@@ -196,7 +201,8 @@ func restoreDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
|
||||
resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() }
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
return ns
|
||||
}
|
||||
|
||||
@@ -100,6 +100,10 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function.
|
||||
// searchDomains is still implemented for Windows just in case Windows API changes in future versions.
|
||||
_ = searchDomains
|
||||
|
||||
if len(serversV4) == 0 && len(serversV6) == 0 {
|
||||
return errors.New("invalid DNS nameservers")
|
||||
}
|
||||
|
||||
@@ -70,10 +70,17 @@ func ControlSocketName() string {
|
||||
}
|
||||
}
|
||||
|
||||
// logf is a function variable used for logging formatted debug messages with optional arguments.
|
||||
// This is used only when creating a new DNS OS configurator.
|
||||
var logf = func(format string, args ...any) {
|
||||
mainLog.Load().Debug().Msgf(format, args...)
|
||||
}
|
||||
|
||||
// noopLogf is like logf but discards formatted log messages and arguments without any processing.
|
||||
//
|
||||
//lint:ignore U1000 use in newLoopbackOSConfigurator
|
||||
var noopLogf = func(format string, args ...any) {}
|
||||
|
||||
var svcConfig = &service.Config{
|
||||
Name: ctrldServiceName,
|
||||
DisplayName: "Control-D Helper Service",
|
||||
|
||||
@@ -9,15 +9,12 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo"); err == nil {
|
||||
if r, err := newLoopbackOSConfigurator(); err == nil {
|
||||
useSystemdResolved = r.Mode() == "systemd-resolved"
|
||||
}
|
||||
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter.
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -24,12 +24,17 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
oc.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file")
|
||||
}
|
||||
return r.SetDNS(oc)
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter.
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -40,3 +45,8 @@ func shouldWatchResolvconf() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface.
|
||||
func newLoopbackOSConfigurator() (dns.OSConfigurator, error) {
|
||||
return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo")
|
||||
}
|
||||
|
||||
14
cmd/cli/search_domains_unix.go
Normal file
14
cmd/cli/search_domains_unix.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
return resolvconffile.SearchDomains()
|
||||
}
|
||||
43
cmd/cli/search_domains_windows.go
Normal file
43
cmd/cli/search_domains_windows.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err)
|
||||
}
|
||||
|
||||
var sds []dnsname.FQDN
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
|
||||
d, err := dnsname.ToFQDN(a.String())
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String())
|
||||
continue
|
||||
}
|
||||
sds = append(sds, d)
|
||||
}
|
||||
}
|
||||
return sds, nil
|
||||
}
|
||||
@@ -437,8 +437,9 @@ func (uc *UpstreamConfig) UID() string {
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
||||
isControlD := uc.IsControlD()
|
||||
nss := initDefaultOsResolver()
|
||||
for {
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, defaultNameservers())
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss)
|
||||
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
||||
// filtering them out here to prevent weird behavior.
|
||||
if isControlD {
|
||||
|
||||
7
desktop_darwin.go
Normal file
7
desktop_darwin.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return true
|
||||
}
|
||||
9
desktop_others.go
Normal file
9
desktop_others.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return false
|
||||
}
|
||||
7
desktop_windows.go
Normal file
7
desktop_windows.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
30
dns.go
Normal file
30
dns.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// SetCacheReply extracts and stores the necessary data from the message for a cached answer.
|
||||
func SetCacheReply(answer, msg *dns.Msg, code int) {
|
||||
answer.SetRcode(msg, code)
|
||||
cCookie := getEdns0Cookie(msg.IsEdns0())
|
||||
sCookie := getEdns0Cookie(answer.IsEdns0())
|
||||
if cCookie != nil && sCookie != nil {
|
||||
// Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc7873#section-4
|
||||
sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:]
|
||||
}
|
||||
}
|
||||
|
||||
// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present.
|
||||
func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE {
|
||||
if opt == nil {
|
||||
return nil
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if e, ok := o.(*dns.EDNS0_COOKIE); ok {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
51
doh.go
51
doh.go
@@ -2,6 +2,7 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -120,6 +121,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
resp, err = c.Do(req.Clone(retryCtx))
|
||||
}
|
||||
if err != nil {
|
||||
err = wrapUrlError(err)
|
||||
if r.isDoH3 {
|
||||
if closer, ok := c.Transport.(io.Closer); ok {
|
||||
closer.Close()
|
||||
@@ -208,3 +210,52 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header {
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
||||
// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer.
|
||||
// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output.
|
||||
// If no certificate-related information is available, it simply returns the original error unmodified.
|
||||
func wrapCertificateVerificationError(err error) error {
|
||||
var tlsErr *tls.CertificateVerificationError
|
||||
if errors.As(err, &tlsErr) {
|
||||
if len(tlsErr.UnverifiedCertificates) > 0 {
|
||||
cert := tlsErr.UnverifiedCertificates[0]
|
||||
// Extract a more user-friendly issuer name
|
||||
var issuer string
|
||||
var organization string
|
||||
if len(cert.Issuer.Organization) > 0 {
|
||||
organization = cert.Issuer.Organization[0]
|
||||
issuer = organization
|
||||
} else if cert.Issuer.CommonName != "" {
|
||||
issuer = cert.Issuer.CommonName
|
||||
} else {
|
||||
issuer = cert.Issuer.String()
|
||||
}
|
||||
|
||||
// Get the organization from the subject field as well
|
||||
if len(cert.Subject.Organization) > 0 {
|
||||
organization = cert.Subject.Organization[0]
|
||||
}
|
||||
|
||||
// Extract the subject information
|
||||
subjectCN := cert.Subject.CommonName
|
||||
if subjectCN == "" && len(cert.Subject.Organization) > 0 {
|
||||
subjectCN = cert.Subject.Organization[0]
|
||||
}
|
||||
return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context.
|
||||
func wrapUrlError(err error) error {
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
var tlsErr *tls.CertificateVerificationError
|
||||
if errors.As(urlErr.Err, &tlsErr) {
|
||||
urlErr.Err = wrapCertificateVerificationError(tlsErr)
|
||||
return urlErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
243
doh_test.go
243
doh_test.go
@@ -1,8 +1,22 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
func Test_dohOsHeaderValue(t *testing.T) {
|
||||
@@ -21,3 +35,232 @@ func Test_dohOsHeaderValue(t *testing.T) {
|
||||
t.Fatalf("missing decoding value for: %q", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapUrlError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "No wrapping for non-URL errors",
|
||||
err: errors.New("plain error"),
|
||||
wantErr: "plain error",
|
||||
},
|
||||
{
|
||||
name: "URL error without TLS error",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: errors.New("underlying error"),
|
||||
},
|
||||
wantErr: "Get \"https://example.com\": underlying error",
|
||||
},
|
||||
{
|
||||
name: "TLS error with missing unverified certificate data",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: &tls.CertificateVerificationError{
|
||||
UnverifiedCertificates: nil,
|
||||
Err: &x509.UnknownAuthorityError{},
|
||||
},
|
||||
},
|
||||
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority`,
|
||||
},
|
||||
{
|
||||
name: "TLS error with valid certificate data",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: &tls.CertificateVerificationError{
|
||||
UnverifiedCertificates: []*x509.Certificate{
|
||||
{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "BadSubjectCN",
|
||||
Organization: []string{"BadSubjectOrg"},
|
||||
},
|
||||
Issuer: pkix.Name{
|
||||
CommonName: "BadIssuerCN",
|
||||
Organization: []string{"BadIssuerOrg"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Err: &x509.UnknownAuthorityError{},
|
||||
},
|
||||
},
|
||||
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority: BadSubjectCN, BadSubjectOrg, BadIssuerOrg`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotErr := wrapUrlError(tt.err)
|
||||
if gotErr.Error() != tt.wantErr {
|
||||
t.Errorf("wrapCertificateVerificationError() error = %v, want %v", gotErr, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ClientCertificateVerificationError(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/dns-message")
|
||||
})
|
||||
tlsServer, cert := testTLSServer(t, handler)
|
||||
tlsServerUrl, err := url.Parse(tlsServer.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
quicServer := newTestQUICServer(t)
|
||||
http3Server := newTestHTTP3Server(t, handler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
}{
|
||||
{
|
||||
"doh",
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: tlsServer.URL,
|
||||
Timeout: 1000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"doh3",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: ResolverTypeDOH3,
|
||||
Endpoint: http3Server.addr,
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"doq",
|
||||
&UpstreamConfig{
|
||||
Name: "doq",
|
||||
Type: ResolverTypeDOQ,
|
||||
Endpoint: quicServer.addr,
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"dot",
|
||||
&UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: ResolverTypeDOT,
|
||||
Endpoint: net.JoinHostPort(tlsServerUrl.Hostname(), tlsServerUrl.Port()),
|
||||
Timeout: 1000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.SetupBootstrapIP()
|
||||
r, err := NewResolver(tc.uc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("verify.controld.com.", dns.TypeA)
|
||||
msg.RecursionDesired = true
|
||||
_, err = r.Resolve(context.Background(), msg)
|
||||
// Verify the error contains the expected certificate information
|
||||
if err == nil {
|
||||
t.Fatal("expected certificate verification error, got nil")
|
||||
}
|
||||
|
||||
// You can check the error contains information about the test certificate
|
||||
if !strings.Contains(err.Error(), cert.Issuer.CommonName) {
|
||||
t.Fatalf("error should contain issuer information %q, got: %v", cert.Issuer.CommonName, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTLSServer creates an HTTPS test server with a self-signed certificate
|
||||
// returns the server and its certificate for verification testing
|
||||
// testTLSServer creates an HTTPS test server with a self-signed certificate
|
||||
func testTLSServer(t *testing.T, handler http.Handler) (*httptest.Server, *x509.Certificate) {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
|
||||
// Create a test server
|
||||
server := httptest.NewUnstartedServer(handler)
|
||||
server.TLS = &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
}
|
||||
server.StartTLS()
|
||||
|
||||
// Add cleanup
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
return server, testCert.cert
|
||||
}
|
||||
|
||||
// testHTTP3Server represents a structure for an HTTP/3 test server with its server instance, TLS certificate, and address.
|
||||
type testHTTP3Server struct {
|
||||
server *http3.Server
|
||||
cert *x509.Certificate
|
||||
addr string
|
||||
}
|
||||
|
||||
// newTestHTTP3Server creates and starts a test HTTP/3 server with a given handler and returns the server instance.
|
||||
func newTestHTTP3Server(t *testing.T, handler http.Handler) *testHTTP3Server {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
|
||||
// First create a listener to get the actual port
|
||||
udpAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create UDP listener: %v", err)
|
||||
}
|
||||
|
||||
// Get the actual address
|
||||
actualAddr := udpConn.LocalAddr().String()
|
||||
|
||||
// Create TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
NextProtos: []string{"h3"}, // HTTP/3 protocol identifier
|
||||
}
|
||||
|
||||
// Create HTTP/3 server
|
||||
server := &http3.Server{
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
// Start the server with the existing UDP connection
|
||||
go func() {
|
||||
if err := server.Serve(udpConn); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Logf("HTTP/3 server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
h3Server := &testHTTP3Server{
|
||||
server: server,
|
||||
cert: testCert.cert,
|
||||
addr: actualAddr,
|
||||
}
|
||||
|
||||
// Add cleanup
|
||||
t.Cleanup(func() {
|
||||
server.Close()
|
||||
udpConn.Close()
|
||||
})
|
||||
|
||||
// Wait a bit for the server to be ready
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return h3Server
|
||||
}
|
||||
|
||||
2
doq.go
2
doq.go
@@ -43,7 +43,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, wrapCertificateVerificationError(err)
|
||||
}
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
223
doq_test.go
Normal file
223
doq_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
// test_helpers.go
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// testCertificate represents a test certificate with its components
|
||||
type testCertificate struct {
|
||||
cert *x509.Certificate
|
||||
tlsCert tls.Certificate
|
||||
template *x509.Certificate
|
||||
}
|
||||
|
||||
// generateTestCertificate creates a self-signed certificate for testing
|
||||
func generateTestCertificate(t *testing.T) *testCertificate {
|
||||
t.Helper()
|
||||
|
||||
// Generate private key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
// Create certificate template
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
CommonName: "Test CA",
|
||||
},
|
||||
Issuer: pkix.Name{
|
||||
Organization: []string{"Test Issuer Org"},
|
||||
CommonName: "Test Issuer CA",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse certificate: %v", err)
|
||||
}
|
||||
|
||||
// Create TLS certificate
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{derBytes},
|
||||
PrivateKey: privateKey,
|
||||
}
|
||||
|
||||
return &testCertificate{
|
||||
cert: cert,
|
||||
tlsCert: tlsCert,
|
||||
template: template,
|
||||
}
|
||||
}
|
||||
|
||||
// testQUICServer is a structure representing a test QUIC server for handling connections and streams.
|
||||
// listener is the QUIC listener used to accept incoming connections.
|
||||
// cert is the x509 certificate used by the server for authentication.
|
||||
// addr is the address on which the test server is running.
|
||||
type testQUICServer struct {
|
||||
listener *quic.Listener
|
||||
cert *x509.Certificate
|
||||
addr string
|
||||
}
|
||||
|
||||
// newTestQUICServer creates and initializes a test QUIC server with TLS configuration and starts accepting connections.
|
||||
func newTestQUICServer(t *testing.T) *testQUICServer {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
|
||||
// Create TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
NextProtos: []string{"doq"},
|
||||
}
|
||||
|
||||
// Create QUIC listener
|
||||
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create QUIC listener: %v", err)
|
||||
}
|
||||
|
||||
server := &testQUICServer{
|
||||
listener: listener,
|
||||
cert: testCert.cert,
|
||||
addr: listener.Addr().String(),
|
||||
}
|
||||
|
||||
// Start handling connections
|
||||
go server.serve(t)
|
||||
|
||||
// Add cleanup
|
||||
t.Cleanup(func() {
|
||||
listener.Close()
|
||||
})
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// serve handles incoming connections on the QUIC listener and delegates them to connection handlers in separate goroutines.
|
||||
func (s *testQUICServer) serve(t *testing.T) {
|
||||
for {
|
||||
conn, err := s.listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
// Check if the error is due to the listener being closed
|
||||
if strings.Contains(err.Error(), "server closed") {
|
||||
return
|
||||
}
|
||||
t.Logf("failed to accept connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleConnection(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection manages an individual QUIC connection by accepting and handling incoming streams in separate goroutines.
|
||||
func (s *testQUICServer) handleConnection(t *testing.T, conn quic.Connection) {
|
||||
for {
|
||||
stream, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go s.handleStream(t, stream)
|
||||
}
|
||||
}
|
||||
|
||||
// handleStream processes a single QUIC stream, reads DNS messages, generates a response, and sends it back to the client.
|
||||
func (s *testQUICServer) handleStream(t *testing.T, stream quic.Stream) {
|
||||
defer stream.Close()
|
||||
|
||||
// Read length (2 bytes)
|
||||
lenBuf := make([]byte, 2)
|
||||
_, err := stream.Read(lenBuf)
|
||||
if err != nil {
|
||||
t.Logf("failed to read message length: %v", err)
|
||||
return
|
||||
}
|
||||
msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1])
|
||||
|
||||
// Read message
|
||||
msgBuf := make([]byte, msgLen)
|
||||
_, err = stream.Read(msgBuf)
|
||||
if err != nil {
|
||||
t.Logf("failed to read message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse DNS message
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(msgBuf); err != nil {
|
||||
t.Logf("failed to unpack DNS message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create response
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(msg)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add a test answer
|
||||
if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA {
|
||||
response.Answer = append(response.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: net.ParseIP("192.0.2.1"), // TEST-NET-1 address
|
||||
})
|
||||
}
|
||||
|
||||
// Pack response
|
||||
respBytes, err := response.Pack()
|
||||
if err != nil {
|
||||
t.Logf("failed to pack response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write length
|
||||
respLen := uint16(len(respBytes))
|
||||
_, err = stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)})
|
||||
if err != nil {
|
||||
t.Logf("failed to write response length: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write response
|
||||
_, err = stream.Write(respBytes)
|
||||
if err != nil {
|
||||
t.Logf("failed to write response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
3
dot.go
3
dot.go
@@ -23,7 +23,6 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
|
||||
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
|
||||
dnsClient := &dns.Client{
|
||||
Net: tcpNet,
|
||||
@@ -39,5 +38,5 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
}
|
||||
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint)
|
||||
return answer, err
|
||||
return answer, wrapCertificateVerificationError(err)
|
||||
}
|
||||
|
||||
@@ -17,10 +17,17 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
v4BootstrapDNS = "76.76.2.22:53"
|
||||
v6BootstrapDNS = "[2606:1a40::22]:53"
|
||||
v4BootstrapDNS = "76.76.2.22:53"
|
||||
v6BootstrapDNS = "[2606:1a40::22]:53"
|
||||
v6BootstrapIP = "2606:1a40::22"
|
||||
defaultHTTPSPort = "443"
|
||||
defaultHTTPPort = "80"
|
||||
defaultDNSPort = "53"
|
||||
probeStackTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
var commonIPv6Ports = []string{defaultHTTPSPort, defaultHTTPPort, defaultDNSPort}
|
||||
|
||||
var Dialer = &net.Dialer{
|
||||
Resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
@@ -33,8 +40,6 @@ var Dialer = &net.Dialer{
|
||||
},
|
||||
}
|
||||
|
||||
const probeStackTimeout = 2 * time.Second
|
||||
|
||||
var probeStackDialer = &net.Dialer{
|
||||
Resolver: Dialer.Resolver,
|
||||
Timeout: probeStackTimeout,
|
||||
@@ -50,12 +55,28 @@ func init() {
|
||||
stackOnce.Store(new(sync.Once))
|
||||
}
|
||||
|
||||
func supportIPv6(ctx context.Context) bool {
|
||||
c, err := probeStackDialer.DialContext(ctx, "tcp6", v6BootstrapDNS)
|
||||
// supportIPv6 checks for IPv6 connectivity by attempting to connect to predefined ports
|
||||
// on a specific IPv6 address.
|
||||
// Returns a boolean indicating if IPv6 is supported and the port on which the connection succeeded.
|
||||
// If no connection is successful, returns false and an empty string.
|
||||
func supportIPv6(ctx context.Context) (supported bool, successPort string) {
|
||||
for _, port := range commonIPv6Ports {
|
||||
if canConnectToIPv6Port(ctx, port) {
|
||||
return true, string(port)
|
||||
}
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// canConnectToIPv6Port attempts to establish a TCP connection to the specified port
|
||||
// using IPv6. Returns true if the connection was successful.
|
||||
func canConnectToIPv6Port(ctx context.Context, port string) bool {
|
||||
address := net.JoinHostPort(v6BootstrapIP, port)
|
||||
conn, err := probeStackDialer.DialContext(ctx, "tcp6", address)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
c.Close()
|
||||
_ = conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -110,7 +131,8 @@ func SupportsIPv6ListenLocal() bool {
|
||||
|
||||
// IPv6Available is like SupportsIPv6, but always do the check without caching.
|
||||
func IPv6Available(ctx context.Context) bool {
|
||||
return supportIPv6(ctx)
|
||||
hasV6, _ := supportIPv6(ctx)
|
||||
return hasV6
|
||||
}
|
||||
|
||||
// IsIPv6 checks if the provided IP is v6.
|
||||
|
||||
@@ -12,7 +12,12 @@ func TestProbeStackTimeout(t *testing.T) {
|
||||
go func() {
|
||||
defer close(done)
|
||||
close(started)
|
||||
supportIPv6(context.Background())
|
||||
hasV6, port := supportIPv6(context.Background())
|
||||
if hasV6 {
|
||||
t.Logf("connect to port %s using ipv6: %v", port, hasV6)
|
||||
} else {
|
||||
t.Log("ipv6 is not available")
|
||||
}
|
||||
}()
|
||||
|
||||
<-started
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
|
||||
"tailscale.com/net/dns/resolvconffile"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
const resolvconfPath = "/etc/resolv.conf"
|
||||
@@ -22,7 +23,7 @@ func NameServersWithPort() []string {
|
||||
return ns
|
||||
}
|
||||
|
||||
func NameServers(_ string) []string {
|
||||
func NameServers() []string {
|
||||
c, err := resolvconffile.ParseFile(resolvconfPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -33,3 +34,12 @@ func NameServers(_ string) []string {
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// SearchDomains returns the current search domains config in /etc/resolv.conf file.
|
||||
func SearchDomains() ([]dnsname.FQDN, error) {
|
||||
c, err := resolvconffile.ParseFile(resolvconfPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.SearchDomains, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestNameServers(t *testing.T) {
|
||||
ns := NameServers("")
|
||||
ns := NameServers()
|
||||
require.NotNil(t, ns)
|
||||
t.Log(ns)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
// currentNameserversFromResolvconf returns the current nameservers set from /etc/resolv.conf file.
|
||||
func currentNameserversFromResolvconf() []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file.
|
||||
@@ -34,7 +34,7 @@ func dnsFromResolvConf() []string {
|
||||
time.Sleep(retryInterval)
|
||||
}
|
||||
|
||||
nss := resolvconffile.NameServers("")
|
||||
nss := resolvconffile.NameServers()
|
||||
var localDNS []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
|
||||
110
resolver.go
110
resolver.go
@@ -9,12 +9,14 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
@@ -216,6 +218,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
type osResolver struct {
|
||||
lanServers atomic.Pointer[[]string]
|
||||
publicServers atomic.Pointer[[]string]
|
||||
group *singleflight.Group
|
||||
cache *sync.Map
|
||||
}
|
||||
|
||||
type osResolverResult struct {
|
||||
@@ -273,10 +277,75 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired
|
||||
return dnsClient.ExchangeContext(ctx, msg, server)
|
||||
}
|
||||
|
||||
const hotCacheTTL = time.Second
|
||||
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// The Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
//
|
||||
// To guard against unexpected DoS to upstreams, multiple queries of
|
||||
// the same Qtype to a domain will be shared, so there's only 1 qps
|
||||
// for each upstream at any time.
|
||||
//
|
||||
// Further, a hot cache will be used, so repeated queries will be cached
|
||||
// for a short period (currently 1 second), reducing unnecessary traffics
|
||||
// sent to upstreams.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if len(msg.Question) == 0 {
|
||||
return nil, errors.New("no question found")
|
||||
}
|
||||
domain := strings.TrimSuffix(msg.Question[0].Name, ".")
|
||||
qtype := msg.Question[0].Qtype
|
||||
|
||||
// Unique key for the singleflight group.
|
||||
key := fmt.Sprintf("%s:%d:", domain, qtype)
|
||||
|
||||
// Checking the cache first.
|
||||
if val, ok := o.cache.Load(key); ok {
|
||||
if val, ok := val.(*dns.Msg); ok {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
res := val.Copy()
|
||||
SetCacheReply(res, msg, val.Rcode)
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure only one DNS query is in flight for the key.
|
||||
v, err, shared := o.group.Do(key, func() (interface{}, error) {
|
||||
msg, err := o.resolve(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If we got an answer, storing it to the hot cache for hotCacheTTL
|
||||
// This prevents possible DoS to upstream, ensuring there's only 1 QPS.
|
||||
o.cache.Store(key, msg)
|
||||
// Depends on go runtime scheduling, the result may end up in hot cache longer
|
||||
// than hotCacheTTL duration. However, this is fine since we only want to guard
|
||||
// against DoS attack. The result will be cleaned from the cache eventually.
|
||||
time.AfterFunc(hotCacheTTL, func() {
|
||||
o.removeCache(key)
|
||||
})
|
||||
return msg, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sharedMsg, ok := v.(*dns.Msg)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid answer for key: %s", key)
|
||||
}
|
||||
res := sharedMsg.Copy()
|
||||
SetCacheReply(res, msg, sharedMsg.Rcode)
|
||||
if shared {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// resolve sends the query to current nameservers.
|
||||
func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
publicServers := *o.publicServers.Load()
|
||||
var nss []string
|
||||
if p := o.lanServers.Load(); p != nil {
|
||||
@@ -431,6 +500,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (o *osResolver) removeCache(key string) {
|
||||
o.cache.Delete(key)
|
||||
}
|
||||
|
||||
type legacyResolver struct {
|
||||
uc *UpstreamConfig
|
||||
}
|
||||
@@ -469,11 +542,26 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err
|
||||
// LookupIP looks up domain using current system nameservers settings.
|
||||
// It returns a slice of that host's IPv4 and IPv6 addresses.
|
||||
func LookupIP(domain string) []string {
|
||||
return lookupIP(domain, -1, defaultNameservers())
|
||||
nss := initDefaultOsResolver()
|
||||
return lookupIP(domain, -1, nss)
|
||||
}
|
||||
|
||||
// initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet.
|
||||
// It returns the combined list of LAN and public nameservers currently held by the resolver.
|
||||
func initDefaultOsResolver() []string {
|
||||
resolverMutex.Lock()
|
||||
defer resolverMutex.Unlock()
|
||||
if or == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers")
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
nss := *or.lanServers.Load()
|
||||
nss = append(nss, *or.publicServers.Load()...)
|
||||
return nss
|
||||
}
|
||||
|
||||
// lookupIP looks up domain with given timeout and bootstrapDNS.
|
||||
// If timeout is negative, default timeout 2000 ms will be used.
|
||||
// If the timeout is negative, default timeout 2000 ms will be used.
|
||||
// It returns nil if bootstrapDNS is nil or empty.
|
||||
func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) {
|
||||
if net.ParseIP(domain) != nil {
|
||||
@@ -577,13 +665,7 @@ func NewBootstrapResolver(servers ...string) Resolver {
|
||||
//
|
||||
// This is useful for doing PTR lookup in LAN network.
|
||||
func NewPrivateResolver() Resolver {
|
||||
resolverMutex.Lock()
|
||||
if or == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver in NewPrivateResolver")
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
nss := *or.lanServers.Load()
|
||||
resolverMutex.Unlock()
|
||||
nss := initDefaultOsResolver()
|
||||
resolveConfNss := currentNameserversFromResolvconf()
|
||||
localRfc1918Addrs := Rfc1918Addresses()
|
||||
n := 0
|
||||
@@ -627,10 +709,10 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
// newResolverWithNameserver returns an OS resolver from given nameservers list.
|
||||
// The caller must ensure each server in list is formed "ip:53".
|
||||
func newResolverWithNameserver(nameservers []string) *osResolver {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers)
|
||||
r := &osResolver{}
|
||||
r := &osResolver{
|
||||
group: &singleflight.Group{},
|
||||
cache: &sync.Map{},
|
||||
}
|
||||
var publicNss []string
|
||||
var lanNss []string
|
||||
for _, ns := range slices.Sorted(slices.Values(nameservers)) {
|
||||
|
||||
195
resolver_test.go
195
resolver_test.go
@@ -2,8 +2,11 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,8 +19,7 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServers.Store(&[]string{"127.0.0.127:5353"})
|
||||
resolver := newResolverWithNameserver([]string{"127.0.0.127:5353"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
@@ -50,8 +52,7 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) {
|
||||
t.Error("not a LAN query")
|
||||
return
|
||||
}
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServers.Store(&[]string{"76.76.2.0:53"})
|
||||
resolver := newResolverWithNameserver([]string{"76.76.2.0:53"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
@@ -107,11 +108,9 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
}()
|
||||
|
||||
// We now create an osResolver which has both a LAN and public nameserver.
|
||||
resolver := &osResolver{}
|
||||
// Explicitly store the LAN nameserver.
|
||||
resolver.lanServers.Store(&[]string{lanAddr})
|
||||
// And store the public nameservers.
|
||||
resolver.publicServers.Store(&publicNS)
|
||||
nss := []string{lanAddr}
|
||||
nss = append(nss, publicNS...)
|
||||
resolver := newResolverWithNameserver(nss)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
@@ -139,6 +138,150 @@ func Test_osResolver_InitializationRace(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func Test_osResolver_Singleflight(t *testing.T) {
|
||||
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on LAN address: %v", err)
|
||||
}
|
||||
call := &atomic.Int64{}
|
||||
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run LAN test server: %v", err)
|
||||
}
|
||||
defer lanServer.Shutdown()
|
||||
|
||||
or := newResolverWithNameserver([]string{lanAddr})
|
||||
domain := "controld.com"
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
_, err := or.Resolve(context.Background(), m)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// All above queries should only make 1 call to server.
|
||||
if call.Load() != 1 {
|
||||
t.Fatalf("expected 1 result from singleflight lookup, got %d", call)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_HotCache(t *testing.T) {
|
||||
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on LAN address: %v", err)
|
||||
}
|
||||
call := &atomic.Int64{}
|
||||
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run LAN test server: %v", err)
|
||||
}
|
||||
defer lanServer.Shutdown()
|
||||
|
||||
or := newResolverWithNameserver([]string{lanAddr})
|
||||
domain := "controld.com"
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
|
||||
// Make 2 repeated queries to server, should hit hot cache.
|
||||
for i := 0; i < 2; i++ {
|
||||
if _, err := or.Resolve(context.Background(), m.Copy()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if call.Load() != 1 {
|
||||
t.Fatalf("cache not hit, server was called: %d", call.Load())
|
||||
}
|
||||
|
||||
timeoutChan := make(chan struct{})
|
||||
time.AfterFunc(5*time.Second, func() {
|
||||
close(timeoutChan)
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutChan:
|
||||
t.Fatal("timed out waiting for cache cleaned")
|
||||
default:
|
||||
count := 0
|
||||
or.cache.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count != 0 {
|
||||
t.Logf("hot cache is not empty: %d elements", count)
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if _, err := or.Resolve(context.Background(), m.Copy()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if call.Load() != 2 {
|
||||
t.Fatal("cache hit unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Edns0_CacheReply(t *testing.T) {
|
||||
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on LAN address: %v", err)
|
||||
}
|
||||
call := &atomic.Int64{}
|
||||
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run LAN test server: %v", err)
|
||||
}
|
||||
defer lanServer.Shutdown()
|
||||
|
||||
or := newResolverWithNameserver([]string{lanAddr})
|
||||
domain := "controld.com"
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
|
||||
do := func() *dns.Msg {
|
||||
msg := m.Copy()
|
||||
msg.SetEdns0(4096, true)
|
||||
cookieOption := new(dns.EDNS0_COOKIE)
|
||||
cookieOption.Code = dns.EDNS0COOKIE
|
||||
cookieOption.Cookie = generateEdns0ClientCookie()
|
||||
msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption)
|
||||
|
||||
answer, err := or.Resolve(context.Background(), msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return answer
|
||||
}
|
||||
answer1 := do()
|
||||
answer2 := do()
|
||||
// Ensure the cache was hit, so we can check that edns0 cookie must be modified.
|
||||
if call.Load() != 1 {
|
||||
t.Fatalf("cache not hit, server was called: %d", call.Load())
|
||||
}
|
||||
cookie1 := getEdns0Cookie(answer1.IsEdns0())
|
||||
cookie2 := getEdns0Cookie(answer2.IsEdns0())
|
||||
if cookie1 == nil || cookie2 == nil {
|
||||
t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2)
|
||||
}
|
||||
if cookie1.Cookie == cookie2.Cookie {
|
||||
t.Fatalf("edns0 cookie is not modified: %v", cookie1)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -208,3 +351,37 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc {
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
|
||||
func countHandler(call *atomic.Int64) dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, dns.RcodeSuccess)
|
||||
if cookie := getEdns0Cookie(msg.IsEdns0()); cookie != nil {
|
||||
if m.IsEdns0() == nil {
|
||||
m.SetEdns0(4096, false)
|
||||
}
|
||||
cookieOption := new(dns.EDNS0_COOKIE)
|
||||
cookieOption.Code = dns.EDNS0COOKIE
|
||||
cookieOption.Cookie = generateEdns0ServerCookie(cookie.Cookie)
|
||||
m.IsEdns0().Option = append(m.IsEdns0().Option, cookieOption)
|
||||
}
|
||||
w.WriteMsg(m)
|
||||
call.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func generateEdns0ClientCookie() string {
|
||||
cookie := make([]byte, 8)
|
||||
if _, err := rand.Read(cookie); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hex.EncodeToString(cookie)
|
||||
}
|
||||
|
||||
func generateEdns0ServerCookie(clientCookie string) string {
|
||||
cookie := make([]byte, 32)
|
||||
if _, err := rand.Read(cookie); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return clientCookie + hex.EncodeToString(cookie)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user