Compare commits

...

16 Commits

Author SHA1 Message Date
Cuong Manh Le
36a7423634 refactor: extract empty string filtering to reusable function
- Add filterEmptyStrings utility function for consistent string filtering
- Replace inline slices.DeleteFunc calls with filterEmptyStrings
- Apply filtering to osArgs in addition to command args
- Improves code readability and reduces duplication
- Uses slices.DeleteFunc internally for efficient filtering
2025-07-15 23:09:54 +07:00
Cuong Manh Le
e616091249 cmd/cli: ignore empty positional argument for start command
The validation was added during v1.4.0 release, but causing one-liner
install failed unexpectedly.
2025-07-15 21:57:36 +07:00
Cuong Manh Le
0948161529 Avoiding Windows runners file locking issue 2025-07-15 20:59:57 +07:00
Cuong Manh Le
ce29b5d217 refactor: split selfUpgradeCheck into version check and upgrade execution
- Move version checking logic to shouldUpgrade for testability
- Move upgrade command execution to performUpgrade
- selfUpgradeCheck now composes these two for clarity
- Update and expand tests: focus on logic, not side effects
- Improves maintainability, testability, and separation of concerns
2025-07-15 19:12:23 +07:00
Cuong Manh Le
de24fa293e internal/router: support Ubios 4.3+
This change improves compatibility with newer UniFi OS versions while
maintaining backward compatibility with UniFi OS 4.2 and earlier.
The refactoring also reduces code duplication and improves maintainability
by centralizing dnsmasq configuration path logic.
2025-07-15 19:11:13 +07:00
Cuong Manh Le
6663925c4d internal/router: support Merlin Guest Network Pro VLAN
By looking for any additional dnsmasq configuration files under
/tmp/etc, and handling them like default one.
2025-07-15 19:10:10 +07:00
Cuong Manh Le
b9ece6d7b9 Merge pull request #239 from Control-D-Inc/release-branch-v1.4.4
Release branch v1.4.4
2025-06-16 16:45:11 +07:00
Cuong Manh Le
c4efa1ab97 Initializing default os resolver during upstream bootstrap
Since calling defaultNameservers may block the whole bootstrap process
if there's no valid DNS servers available.
2025-06-12 16:22:52 +07:00
Cuong Manh Le
7cea5305e1 all: fix a regression causing invalid reloading timeout
In v1.4.3, ControlD bootstrap DNS is used again for bootstrapping
process. When this happened, the default system nameservers will be
retrieved first, then ControlD DNS will be used if none available.

However, getting default system nameservers process may take longer than
reloading command timeout, causing invalid error message printed.

To fix this, ensuring default system nameservers is retrieved once.
2025-06-10 19:42:26 +07:00
Cuong Manh Le
a20fbf95de all: enhanced TLS certificate verification error messages
Added more descriptive error messages for TLS certificate verification
failures across DoH, DoT, DoQ, and DoH3 protocols. The error messages
now include:

- Certificate subject information
- Issuer organization details
- Common name of the certificate

This helps users and developers better understand certificate validation
failures by providing specific details about the untrusted certificate,
rather than just a generic "unknown authority" message.

Example error message change:
Before: "certificate signed by unknown authority"
After: "certificate signed by unknown authority: TestCA, TestOrg, TestIssuerOrg"
2025-06-10 19:42:00 +07:00
Cuong Manh Le
628c4302aa cmd/cli: preserve search domains when reverting resolv.conf
Fixes search domains not being preserved when the resolv.conf file is
reverted to its previous state. This ensures that important domain
search configuration is maintained during DNS configuration changes.

The search domains handling was missing in setResolvConf function,
which is responsible for restoring DNS settings.
2025-06-04 18:36:51 +07:00
Cuong Manh Le
8dc34f8bf5 internal/net: improve IPv6 support detection with multiple common ports
Changed the IPv6 support detection to try multiple common ports (HTTP/HTTPS) instead of
just testing against a DNS port. The function now returns both the IPv6 support status
and the successful port that confirmed the connectivity. This makes the IPv6 detection
more reliable by not depending solely on DNS port availability.

Previously, the function only tested connectivity to a DNS port (53) over IPv6.
Now it tries to connect to commonly available ports like HTTP (80) and HTTPS (443)
until it finds a working one, making the detection more robust in environments where
certain ports might be blocked.
2025-06-04 16:29:28 +07:00
Cuong Manh Le
b4faf82f76 all: set edns0 cookie for shared message
For cached or singleflight messages, the edns0 cookie is currently
shared among all of them, causing mismatch cookie warning from clients.
The ctrld proxy should re-set client cookies for each request
separately, even though they use the same shared answer.
2025-05-27 18:09:16 +07:00
Cuong Manh Le
a983dfaee2 all: optimizing multiple queries to upstreams
To guard ctrld from possible DoS to remote upstreams, this commit
implements following things:

 - Optimizing multiple queries with the same domain and qtype to use
   singleflight group, so there's only 1 query to remote upstreams at
   any time.
 - Adding a hot cache with 1 second TTL, so repeated queries will re-use
   the result from cache if existed, preventing unnecessary requests to
   remote upstreams.
2025-05-23 21:09:15 +07:00
Cuong Manh Le
62f73bcaa2 all: preserve search domains settings
So bare hostname will be resolved as expected when ctrld is running.
2025-05-15 17:00:59 +07:00
Cuong Manh Le
00e9d2bdd3 all: do not listen on 0.0.0.0 on desktop clients
Since this may create security vulnerabilities such as DNS amplification
or abusing because the listener was exposed to the entire local network.
2025-05-15 16:59:24 +07:00
37 changed files with 1547 additions and 117 deletions

View File

@@ -1216,13 +1216,18 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, noti
// For Windows server with local Dns server running, we can only try on random local IP.
hasLocalDnsServer := hasLocalDnsServerRunning()
notRouter := router.Name() == ""
isDesktop := ctrld.IsDesktopPlatform()
for n, listener := range cfg.Listener {
lcc[n] = &listenerConfigCheck{}
if listener.IP == "" {
listener.IP = "0.0.0.0"
if hasLocalDnsServer {
// Windows Server lies to us that we could listen on 0.0.0.0:53
// even there's a process already done that, stick to local IP only.
// Windows Server lies to us that we could listen on 0.0.0.0:53
// even there's a process already done that, stick to local IP only.
//
// For desktop clients, also stick the listener to the local IP only.
// Listening on 0.0.0.0 would expose it to the entire local network, potentially
// creating security vulnerabilities (such as DNS amplification or abusing).
if hasLocalDnsServer || isDesktop {
listener.IP = "127.0.0.1"
}
lcc[n].IP = true

View File

@@ -13,6 +13,7 @@ import (
"os/exec"
"path/filepath"
"runtime"
"slices"
"sort"
"strconv"
"strings"
@@ -206,6 +207,7 @@ func initStartCmd() *cobra.Command {
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
Args: func(cmd *cobra.Command, args []string) error {
args = filterEmptyStrings(args)
if len(args) > 0 {
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
@@ -219,6 +221,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
sc := &service.Config{}
*sc = *svcConfig
osArgs := os.Args[2:]
osArgs = filterEmptyStrings(osArgs)
if os.Args[1] == "service" {
osArgs = os.Args[3:]
}
@@ -566,6 +569,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
Args: func(cmd *cobra.Command, args []string) error {
args = filterEmptyStrings(args)
if len(args) > 0 {
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
@@ -1381,3 +1385,11 @@ func initServicesCmd(commands ...*cobra.Command) *cobra.Command {
return serviceCmd
}
// filterEmptyStrings removes empty strings from a slice of strings.
// It returns a new slice containing only non-empty strings.
func filterEmptyStrings(slice []string) []string {
return slices.DeleteFunc(slice, func(s string) bool {
return s == ""
})
}

View File

@@ -500,7 +500,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
continue
}
answer := cachedValue.Msg.Copy()
answer.SetRcode(req.msg, answer.Rcode)
ctrld.SetCacheReply(answer, req.msg, answer.Rcode)
now := time.Now()
if cachedValue.Expire.After(now) {
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
@@ -1042,8 +1042,10 @@ func (p *prog) queryFromSelf(ip string) bool {
return false
}
// needRFC1918Listeners reports whether ctrld need to spawn listener for RFC 1918 addresses.
// This is helpful for non-desktop platforms to receive queries from LAN clients.
func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
return lc.IP == "127.0.0.1" && lc.Port == 53
return lc.IP == "127.0.0.1" && lc.Port == 53 && !ctrld.IsDesktopPlatform()
}
// ipFromARPA parses a FQDN arpa domain and return the IP address if valid.

View File

@@ -47,6 +47,9 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
// TODO(cuonglm): use system API
func setDNS(iface *net.Interface, nameservers []string) error {
// Note that networksetup won't modify search domains settings,
// This assignment is just a placeholder to silent linter.
_ = searchDomains
cmd := "networksetup"
args := []string{"-setdnsservers", iface.Name}
args = append(args, nameservers...)
@@ -88,7 +91,7 @@ func restoreDNS(iface *net.Interface) (err error) {
}
func currentDNS(_ *net.Interface) []string {
return resolvconffile.NameServers("")
return resolvconffile.NameServers()
}
// currentStaticDNS returns the current static DNS settings of given interface.

View File

@@ -7,6 +7,7 @@ import (
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/util/dnsname"
"github.com/Control-D-Inc/ctrld/internal/dns"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
@@ -50,7 +51,17 @@ func setDNS(iface *net.Interface, nameservers []string) error {
ns = append(ns, netip.MustParseAddr(nameserver))
}
if err := r.SetDNS(dns.OSConfig{Nameservers: ns}); err != nil {
osConfig := dns.OSConfig{
Nameservers: ns,
SearchDomains: []dnsname.FQDN{},
}
if sds, err := searchDomains(); err == nil {
osConfig.SearchDomains = sds
} else {
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
}
if err := r.SetDNS(osConfig); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to set DNS")
return err
}
@@ -83,7 +94,7 @@ func restoreDNS(iface *net.Interface) (err error) {
}
func currentDNS(_ *net.Interface) []string {
return resolvconffile.NameServers("")
return resolvconffile.NameServers()
}
// currentStaticDNS returns the current static DNS settings of given interface.

View File

@@ -71,6 +71,11 @@ func setDNS(iface *net.Interface, nameservers []string) error {
Nameservers: ns,
SearchDomains: []dnsname.FQDN{},
}
if sds, err := searchDomains(); err == nil {
osConfig.SearchDomains = sds
} else {
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
}
trySystemdResolve := false
if err := r.SetDNS(osConfig); err != nil {
if strings.Contains(err.Error(), "Rejected send message") &&
@@ -196,7 +201,8 @@ func restoreDNS(iface *net.Interface) (err error) {
}
func currentDNS(iface *net.Interface) []string {
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() }
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
if ns := fn(iface.Name); len(ns) > 0 {
return ns
}

View File

@@ -100,6 +100,10 @@ func setDNS(iface *net.Interface, nameservers []string) error {
}
}
// Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function.
// searchDomains is still implemented for Windows just in case Windows API changes in future versions.
_ = searchDomains
if len(serversV4) == 0 && len(serversV6) == 0 {
return errors.New("invalid DNS nameservers")
}

View File

@@ -35,6 +35,7 @@ import (
"github.com/Control-D-Inc/ctrld/internal/controld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
"github.com/Control-D-Inc/ctrld/internal/router"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
)
const (
@@ -70,10 +71,17 @@ func ControlSocketName() string {
}
}
// logf is a function variable used for logging formatted debug messages with optional arguments.
// This is used only when creating a new DNS OS configurator.
var logf = func(format string, args ...any) {
mainLog.Load().Debug().Msgf(format, args...)
}
// noopLogf is like logf but discards formatted log messages and arguments without any processing.
//
//lint:ignore U1000 use in newLoopbackOSConfigurator
var noopLogf = func(format string, args ...any) {}
var svcConfig = &service.Config{
Name: ctrldServiceName,
DisplayName: "Control-D Helper Service",
@@ -321,7 +329,7 @@ func (p *prog) apiConfigReload() {
// Performing self-upgrade check for production version.
if isStable {
selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger)
_ = selfUpgradeCheck(resolverConfig.Ctrld.VersionTarget, curVer, &logger)
}
if resolverConfig.DeactivationPin != nil {
@@ -600,6 +608,12 @@ func (p *prog) setupClientInfoDiscover(selfIP string) {
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
if leaseFiles := dnsmasq.AdditionalLeaseFiles(); len(leaseFiles) > 0 {
mainLog.Load().Debug().Msgf("watching additional lease files: %v", leaseFiles)
for _, leaseFile := range leaseFiles {
p.ciTable.AddLeaseFile(leaseFile, ctrld.Dnsmasq)
}
}
}
// runClientInfoDiscover runs the client info discover.
@@ -1460,14 +1474,15 @@ func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
}
}
// selfUpgradeCheck checks if the version target vt is greater
// than the current one cv, perform self-upgrade then.
// shouldUpgrade checks if the version target vt is greater than the current one cv.
// Major version upgrades are not allowed to prevent breaking changes.
//
// The callers must ensure curVer and logger are non-nil.
func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) {
// Returns true if upgrade is allowed, false otherwise.
func shouldUpgrade(vt string, cv *semver.Version, logger *zerolog.Logger) bool {
if vt == "" {
logger.Debug().Msg("no version target set, skipped checking self-upgrade")
return
return false
}
vts := vt
if !strings.HasPrefix(vts, "v") {
@@ -1476,28 +1491,58 @@ func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) {
targetVer, err := semver.NewVersion(vts)
if err != nil {
logger.Warn().Err(err).Msgf("invalid target version, skipped self-upgrade: %s", vt)
return
return false
}
// Prevent major version upgrades to avoid breaking changes
if targetVer.Major() != cv.Major() {
logger.Warn().
Str("target", vt).
Str("current", cv.String()).
Msgf("major version upgrade not allowed (target: %d, current: %d), skipped self-upgrade", targetVer.Major(), cv.Major())
return false
}
if !targetVer.GreaterThan(cv) {
logger.Debug().
Str("target", vt).
Str("current", cv.String()).
Msgf("target version is not greater than current one, skipped self-upgrade")
return
return false
}
return true
}
// performUpgrade executes the self-upgrade command.
// Returns true if upgrade was initiated successfully, false otherwise.
func performUpgrade(vt string) bool {
exe, err := os.Executable()
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to get executable path, skipped self-upgrade")
return
return false
}
cmd := exec.Command(exe, "upgrade", "prod", "-vv")
cmd.SysProcAttr = sysProcAttrForDetachedChildProcess()
if err := cmd.Start(); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to start self-upgrade")
return
return false
}
mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vts)
mainLog.Load().Debug().Msgf("self-upgrade triggered, version target: %s", vt)
return true
}
// selfUpgradeCheck checks if the version target vt is greater
// than the current one cv, perform self-upgrade then.
// Major version upgrades are not allowed to prevent breaking changes.
//
// The callers must ensure curVer and logger are non-nil.
// Returns true if upgrade is allowed and should proceed, false otherwise.
func selfUpgradeCheck(vt string, cv *semver.Version, logger *zerolog.Logger) bool {
if shouldUpgrade(vt, cv, logger) {
return performUpgrade(vt)
}
return false
}
// leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow

View File

@@ -9,15 +9,12 @@ import (
"strings"
"github.com/kardianos/service"
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"github.com/Control-D-Inc/ctrld/internal/dns"
"github.com/Control-D-Inc/ctrld/internal/router"
)
func init() {
if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo"); err == nil {
if r, err := newLoopbackOSConfigurator(); err == nil {
useSystemdResolved = r.Mode() == "systemd-resolved"
}
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911

View File

@@ -1,11 +1,15 @@
package cli
import (
"runtime"
"testing"
"time"
"github.com/Control-D-Inc/ctrld"
"github.com/Masterminds/semver/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/Control-D-Inc/ctrld"
)
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
@@ -55,3 +59,215 @@ func Test_prog_dnsWatchdogInterval(t *testing.T) {
})
}
}
func Test_shouldUpgrade(t *testing.T) {
// Helper function to create a version
makeVersion := func(v string) *semver.Version {
ver, err := semver.NewVersion(v)
if err != nil {
t.Fatalf("failed to create version %s: %v", v, err)
}
return ver
}
tests := []struct {
name string
versionTarget string
currentVersion *semver.Version
shouldUpgrade bool
description string
}{
{
name: "empty version target",
versionTarget: "",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: false,
description: "should skip upgrade when version target is empty",
},
{
name: "invalid version target",
versionTarget: "invalid-version",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: false,
description: "should skip upgrade when version target is invalid",
},
{
name: "same version",
versionTarget: "v1.0.0",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: false,
description: "should skip upgrade when target version equals current version",
},
{
name: "older version",
versionTarget: "v1.0.0",
currentVersion: makeVersion("v1.1.0"),
shouldUpgrade: false,
description: "should skip upgrade when target version is older than current version",
},
{
name: "patch upgrade allowed",
versionTarget: "v1.0.1",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: true,
description: "should allow patch version upgrade within same major version",
},
{
name: "minor upgrade allowed",
versionTarget: "v1.1.0",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: true,
description: "should allow minor version upgrade within same major version",
},
{
name: "major upgrade blocked",
versionTarget: "v2.0.0",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: false,
description: "should block major version upgrade",
},
{
name: "major downgrade blocked",
versionTarget: "v1.0.0",
currentVersion: makeVersion("v2.0.0"),
shouldUpgrade: false,
description: "should block major version downgrade",
},
{
name: "version without v prefix",
versionTarget: "1.0.1",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: true,
description: "should handle version target without v prefix",
},
{
name: "complex version upgrade allowed",
versionTarget: "v1.5.3",
currentVersion: makeVersion("v1.4.2"),
shouldUpgrade: true,
description: "should allow complex version upgrade within same major version",
},
{
name: "complex major upgrade blocked",
versionTarget: "v3.1.0",
currentVersion: makeVersion("v2.5.3"),
shouldUpgrade: false,
description: "should block complex major version upgrade",
},
{
name: "pre-release version upgrade allowed",
versionTarget: "v1.0.1-beta.1",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: true,
description: "should allow pre-release version upgrade within same major version",
},
{
name: "pre-release major upgrade blocked",
versionTarget: "v2.0.0-alpha.1",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: false,
description: "should block pre-release major version upgrade",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Create test logger
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
// Call the function and capture the result
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger)
// Assert the expected result
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
})
}
}
func Test_selfUpgradeCheck(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipped due to Windows file locking issue on Github Action runners")
}
// Helper function to create a version
makeVersion := func(v string) *semver.Version {
ver, err := semver.NewVersion(v)
if err != nil {
t.Fatalf("failed to create version %s: %v", v, err)
}
return ver
}
tests := []struct {
name string
versionTarget string
currentVersion *semver.Version
shouldUpgrade bool
description string
}{
{
name: "upgrade allowed",
versionTarget: "v1.0.1",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: true,
description: "should allow upgrade and attempt to perform it",
},
{
name: "upgrade blocked",
versionTarget: "v2.0.0",
currentVersion: makeVersion("v1.0.0"),
shouldUpgrade: false,
description: "should block upgrade and not attempt to perform it",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Create test logger
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
// Call the function and capture the result
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger)
// Assert the expected result
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
})
}
}
func Test_performUpgrade(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipped due to Windows file locking issue on Github Action runners")
}
tests := []struct {
name string
versionTarget string
expectedResult bool
description string
}{
{
name: "valid version target",
versionTarget: "v1.0.1",
expectedResult: true,
description: "should attempt to perform upgrade with valid version target",
},
{
name: "empty version target",
versionTarget: "",
expectedResult: true,
description: "should attempt to perform upgrade even with empty version target",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Call the function and capture the result
result := performUpgrade(tc.versionTarget)
assert.Equal(t, tc.expectedResult, result, tc.description)
})
}
}

View File

@@ -13,9 +13,9 @@ import (
"github.com/Control-D-Inc/ctrld/internal/dns"
)
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter.
r, err := newLoopbackOSConfigurator()
if err != nil {
return err
}
@@ -24,12 +24,17 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
Nameservers: ns,
SearchDomains: []dnsname.FQDN{},
}
if sds, err := searchDomains(); err == nil {
oc.SearchDomains = sds
} else {
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file")
}
return r.SetDNS(oc)
}
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
func shouldWatchResolvconf() bool {
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter.
r, err := newLoopbackOSConfigurator()
if err != nil {
return false
}
@@ -40,3 +45,8 @@ func shouldWatchResolvconf() bool {
return false
}
}
// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface.
func newLoopbackOSConfigurator() (dns.OSConfigurator, error) {
return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo")
}

View File

@@ -0,0 +1,14 @@
//go:build unix
package cli
import (
"tailscale.com/util/dnsname"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
// searchDomains returns the current search domains config.
func searchDomains() ([]dnsname.FQDN, error) {
return resolvconffile.SearchDomains()
}

View File

@@ -0,0 +1,43 @@
package cli
import (
"fmt"
"syscall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/util/dnsname"
)
// searchDomains returns the current search domains config.
func searchDomains() ([]dnsname.FQDN, error) {
flags := winipcfg.GAAFlagIncludeGateways |
winipcfg.GAAFlagIncludePrefix
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
if err != nil {
return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err)
}
var sds []dnsname.FQDN
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
continue
}
// Skip if software loopback or other non-physical types
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
continue
}
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
d, err := dnsname.ToFQDN(a.String())
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String())
continue
}
sds = append(sds, d)
}
}
return sds, nil
}

View File

@@ -437,8 +437,9 @@ func (uc *UpstreamConfig) UID() string {
func (uc *UpstreamConfig) SetupBootstrapIP() {
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
isControlD := uc.IsControlD()
nss := initDefaultOsResolver()
for {
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, defaultNameservers())
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss)
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
// filtering them out here to prevent weird behavior.
if isControlD {

7
desktop_darwin.go Normal file
View File

@@ -0,0 +1,7 @@
package ctrld
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
// currently defined as macOS or Windows workstation.
func IsDesktopPlatform() bool {
return true
}

9
desktop_others.go Normal file
View File

@@ -0,0 +1,9 @@
//go:build !windows && !darwin
package ctrld
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
// currently defined as macOS or Windows workstation.
func IsDesktopPlatform() bool {
return false
}

7
desktop_windows.go Normal file
View File

@@ -0,0 +1,7 @@
package ctrld
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
// currently defined as macOS or Windows workstation.
func IsDesktopPlatform() bool {
return isWindowsWorkStation()
}

30
dns.go Normal file
View File

@@ -0,0 +1,30 @@
package ctrld
import (
"github.com/miekg/dns"
)
// SetCacheReply extracts and stores the necessary data from the message for a cached answer.
func SetCacheReply(answer, msg *dns.Msg, code int) {
answer.SetRcode(msg, code)
cCookie := getEdns0Cookie(msg.IsEdns0())
sCookie := getEdns0Cookie(answer.IsEdns0())
if cCookie != nil && sCookie != nil {
// Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes.
// See https://datatracker.ietf.org/doc/html/rfc7873#section-4
sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:]
}
}
// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present.
func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE {
if opt == nil {
return nil
}
for _, o := range opt.Option {
if e, ok := o.(*dns.EDNS0_COOKIE); ok {
return e
}
}
return nil
}

51
doh.go
View File

@@ -2,6 +2,7 @@ package ctrld
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
@@ -120,6 +121,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
resp, err = c.Do(req.Clone(retryCtx))
}
if err != nil {
err = wrapUrlError(err)
if r.isDoH3 {
if closer, ok := c.Transport.(io.Closer); ok {
closer.Close()
@@ -208,3 +210,52 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header {
}
return header
}
// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer.
// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output.
// If no certificate-related information is available, it simply returns the original error unmodified.
func wrapCertificateVerificationError(err error) error {
var tlsErr *tls.CertificateVerificationError
if errors.As(err, &tlsErr) {
if len(tlsErr.UnverifiedCertificates) > 0 {
cert := tlsErr.UnverifiedCertificates[0]
// Extract a more user-friendly issuer name
var issuer string
var organization string
if len(cert.Issuer.Organization) > 0 {
organization = cert.Issuer.Organization[0]
issuer = organization
} else if cert.Issuer.CommonName != "" {
issuer = cert.Issuer.CommonName
} else {
issuer = cert.Issuer.String()
}
// Get the organization from the subject field as well
if len(cert.Subject.Organization) > 0 {
organization = cert.Subject.Organization[0]
}
// Extract the subject information
subjectCN := cert.Subject.CommonName
if subjectCN == "" && len(cert.Subject.Organization) > 0 {
subjectCN = cert.Subject.Organization[0]
}
return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer)
}
}
return err
}
// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context.
func wrapUrlError(err error) error {
var urlErr *url.Error
if errors.As(err, &urlErr) {
var tlsErr *tls.CertificateVerificationError
if errors.As(urlErr.Err, &tlsErr) {
urlErr.Err = wrapCertificateVerificationError(tlsErr)
return urlErr
}
}
return err
}

View File

@@ -1,8 +1,22 @@
package ctrld
import (
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"net"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/quic-go/quic-go/http3"
)
func Test_dohOsHeaderValue(t *testing.T) {
@@ -21,3 +35,232 @@ func Test_dohOsHeaderValue(t *testing.T) {
t.Fatalf("missing decoding value for: %q", runtime.GOOS)
}
}
func Test_wrapUrlError(t *testing.T) {
tests := []struct {
name string
err error
wantErr string
}{
{
name: "No wrapping for non-URL errors",
err: errors.New("plain error"),
wantErr: "plain error",
},
{
name: "URL error without TLS error",
err: &url.Error{
Op: "Get",
URL: "https://example.com",
Err: errors.New("underlying error"),
},
wantErr: "Get \"https://example.com\": underlying error",
},
{
name: "TLS error with missing unverified certificate data",
err: &url.Error{
Op: "Get",
URL: "https://example.com",
Err: &tls.CertificateVerificationError{
UnverifiedCertificates: nil,
Err: &x509.UnknownAuthorityError{},
},
},
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority`,
},
{
name: "TLS error with valid certificate data",
err: &url.Error{
Op: "Get",
URL: "https://example.com",
Err: &tls.CertificateVerificationError{
UnverifiedCertificates: []*x509.Certificate{
{
Subject: pkix.Name{
CommonName: "BadSubjectCN",
Organization: []string{"BadSubjectOrg"},
},
Issuer: pkix.Name{
CommonName: "BadIssuerCN",
Organization: []string{"BadIssuerOrg"},
},
},
},
Err: &x509.UnknownAuthorityError{},
},
},
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority: BadSubjectCN, BadSubjectOrg, BadIssuerOrg`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotErr := wrapUrlError(tt.err)
if gotErr.Error() != tt.wantErr {
t.Errorf("wrapCertificateVerificationError() error = %v, want %v", gotErr, tt.wantErr)
}
})
}
}
func Test_ClientCertificateVerificationError(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/dns-message")
})
tlsServer, cert := testTLSServer(t, handler)
tlsServerUrl, err := url.Parse(tlsServer.URL)
if err != nil {
t.Fatal(err)
}
quicServer := newTestQUICServer(t)
http3Server := newTestHTTP3Server(t, handler)
tests := []struct {
name string
uc *UpstreamConfig
}{
{
"doh",
&UpstreamConfig{
Name: "doh",
Type: ResolverTypeDOH,
Endpoint: tlsServer.URL,
Timeout: 1000,
},
},
{
"doh3",
&UpstreamConfig{
Name: "doh3",
Type: ResolverTypeDOH3,
Endpoint: http3Server.addr,
Timeout: 5000,
},
},
{
"doq",
&UpstreamConfig{
Name: "doq",
Type: ResolverTypeDOQ,
Endpoint: quicServer.addr,
Timeout: 5000,
},
},
{
"dot",
&UpstreamConfig{
Name: "dot",
Type: ResolverTypeDOT,
Endpoint: net.JoinHostPort(tlsServerUrl.Hostname(), tlsServerUrl.Port()),
Timeout: 1000,
},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init()
tc.uc.SetupBootstrapIP()
r, err := NewResolver(tc.uc)
if err != nil {
t.Fatal(err)
}
msg := new(dns.Msg)
msg.SetQuestion("verify.controld.com.", dns.TypeA)
msg.RecursionDesired = true
_, err = r.Resolve(context.Background(), msg)
// Verify the error contains the expected certificate information
if err == nil {
t.Fatal("expected certificate verification error, got nil")
}
// You can check the error contains information about the test certificate
if !strings.Contains(err.Error(), cert.Issuer.CommonName) {
t.Fatalf("error should contain issuer information %q, got: %v", cert.Issuer.CommonName, err)
}
})
}
}
// testTLSServer creates an HTTPS test server with a self-signed certificate
// returns the server and its certificate for verification testing
// testTLSServer creates an HTTPS test server with a self-signed certificate
func testTLSServer(t *testing.T, handler http.Handler) (*httptest.Server, *x509.Certificate) {
t.Helper()
testCert := generateTestCertificate(t)
// Create a test server
server := httptest.NewUnstartedServer(handler)
server.TLS = &tls.Config{
Certificates: []tls.Certificate{testCert.tlsCert},
}
server.StartTLS()
// Add cleanup
t.Cleanup(server.Close)
return server, testCert.cert
}
// testHTTP3Server represents a structure for an HTTP/3 test server with its server instance, TLS certificate, and address.
type testHTTP3Server struct {
server *http3.Server
cert *x509.Certificate
addr string
}
// newTestHTTP3Server creates and starts a test HTTP/3 server with a given handler and returns the server instance.
func newTestHTTP3Server(t *testing.T, handler http.Handler) *testHTTP3Server {
t.Helper()
testCert := generateTestCertificate(t)
// First create a listener to get the actual port
udpAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
t.Fatalf("failed to create UDP listener: %v", err)
}
// Get the actual address
actualAddr := udpConn.LocalAddr().String()
// Create TLS config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{testCert.tlsCert},
NextProtos: []string{"h3"}, // HTTP/3 protocol identifier
}
// Create HTTP/3 server
server := &http3.Server{
Handler: handler,
TLSConfig: tlsConfig,
}
// Start the server with the existing UDP connection
go func() {
if err := server.Serve(udpConn); err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Logf("HTTP/3 server error: %v", err)
}
}()
h3Server := &testHTTP3Server{
server: server,
cert: testCert.cert,
addr: actualAddr,
}
// Add cleanup
t.Cleanup(func() {
server.Close()
udpConn.Close()
})
// Wait a bit for the server to be ready
time.Sleep(100 * time.Millisecond)
return h3Server
}

2
doq.go
View File

@@ -43,7 +43,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
continue
}
if err != nil {
return nil, err
return nil, wrapCertificateVerificationError(err)
}
return answer, nil
}

223
doq_test.go Normal file
View File

@@ -0,0 +1,223 @@
// test_helpers.go
package ctrld
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
// testCertificate represents a test certificate with its components
type testCertificate struct {
cert *x509.Certificate
tlsCert tls.Certificate
template *x509.Certificate
}
// generateTestCertificate creates a self-signed certificate for testing
func generateTestCertificate(t *testing.T) *testCertificate {
t.Helper()
// Generate private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate private key: %v", err)
}
// Create certificate template
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
CommonName: "Test CA",
},
Issuer: pkix.Name{
Organization: []string{"Test Issuer Org"},
CommonName: "Test Issuer CA",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"localhost"},
}
// Create certificate
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("failed to create certificate: %v", err)
}
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
t.Fatalf("failed to parse certificate: %v", err)
}
// Create TLS certificate
tlsCert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: privateKey,
}
return &testCertificate{
cert: cert,
tlsCert: tlsCert,
template: template,
}
}
// testQUICServer is a structure representing a test QUIC server for handling connections and streams.
// listener is the QUIC listener used to accept incoming connections.
// cert is the x509 certificate used by the server for authentication.
// addr is the address on which the test server is running.
type testQUICServer struct {
listener *quic.Listener
cert *x509.Certificate
addr string
}
// newTestQUICServer creates and initializes a test QUIC server with TLS configuration and starts accepting connections.
func newTestQUICServer(t *testing.T) *testQUICServer {
t.Helper()
testCert := generateTestCertificate(t)
// Create TLS config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{testCert.tlsCert},
NextProtos: []string{"doq"},
}
// Create QUIC listener
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
if err != nil {
t.Fatalf("failed to create QUIC listener: %v", err)
}
server := &testQUICServer{
listener: listener,
cert: testCert.cert,
addr: listener.Addr().String(),
}
// Start handling connections
go server.serve(t)
// Add cleanup
t.Cleanup(func() {
listener.Close()
})
return server
}
// serve handles incoming connections on the QUIC listener and delegates them to connection handlers in separate goroutines.
func (s *testQUICServer) serve(t *testing.T) {
for {
conn, err := s.listener.Accept(context.Background())
if err != nil {
// Check if the error is due to the listener being closed
if strings.Contains(err.Error(), "server closed") {
return
}
t.Logf("failed to accept connection: %v", err)
continue
}
go s.handleConnection(t, conn)
}
}
// handleConnection manages an individual QUIC connection by accepting and handling incoming streams in separate goroutines.
func (s *testQUICServer) handleConnection(t *testing.T, conn quic.Connection) {
for {
stream, err := conn.AcceptStream(context.Background())
if err != nil {
return
}
go s.handleStream(t, stream)
}
}
// handleStream processes a single QUIC stream, reads DNS messages, generates a response, and sends it back to the client.
func (s *testQUICServer) handleStream(t *testing.T, stream quic.Stream) {
defer stream.Close()
// Read length (2 bytes)
lenBuf := make([]byte, 2)
_, err := stream.Read(lenBuf)
if err != nil {
t.Logf("failed to read message length: %v", err)
return
}
msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1])
// Read message
msgBuf := make([]byte, msgLen)
_, err = stream.Read(msgBuf)
if err != nil {
t.Logf("failed to read message: %v", err)
return
}
// Parse DNS message
msg := new(dns.Msg)
if err := msg.Unpack(msgBuf); err != nil {
t.Logf("failed to unpack DNS message: %v", err)
return
}
// Create response
response := new(dns.Msg)
response.SetReply(msg)
response.Authoritative = true
// Add a test answer
if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA {
response.Answer = append(response.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.ParseIP("192.0.2.1"), // TEST-NET-1 address
})
}
// Pack response
respBytes, err := response.Pack()
if err != nil {
t.Logf("failed to pack response: %v", err)
return
}
// Write length
respLen := uint16(len(respBytes))
_, err = stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)})
if err != nil {
t.Logf("failed to write response length: %v", err)
return
}
// Write response
_, err = stream.Write(respBytes)
if err != nil {
t.Logf("failed to write response: %v", err)
return
}
}

3
dot.go
View File

@@ -23,7 +23,6 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: tcpNet,
@@ -39,5 +38,5 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
}
answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint)
return answer, err
return answer, wrapCertificateVerificationError(err)
}

View File

@@ -17,10 +17,17 @@ import (
)
const (
v4BootstrapDNS = "76.76.2.22:53"
v6BootstrapDNS = "[2606:1a40::22]:53"
v4BootstrapDNS = "76.76.2.22:53"
v6BootstrapDNS = "[2606:1a40::22]:53"
v6BootstrapIP = "2606:1a40::22"
defaultHTTPSPort = "443"
defaultHTTPPort = "80"
defaultDNSPort = "53"
probeStackTimeout = 2 * time.Second
)
var commonIPv6Ports = []string{defaultHTTPSPort, defaultHTTPPort, defaultDNSPort}
var Dialer = &net.Dialer{
Resolver: &net.Resolver{
PreferGo: true,
@@ -33,8 +40,6 @@ var Dialer = &net.Dialer{
},
}
const probeStackTimeout = 2 * time.Second
var probeStackDialer = &net.Dialer{
Resolver: Dialer.Resolver,
Timeout: probeStackTimeout,
@@ -50,12 +55,28 @@ func init() {
stackOnce.Store(new(sync.Once))
}
func supportIPv6(ctx context.Context) bool {
c, err := probeStackDialer.DialContext(ctx, "tcp6", v6BootstrapDNS)
// supportIPv6 checks for IPv6 connectivity by attempting to connect to predefined ports
// on a specific IPv6 address.
// Returns a boolean indicating if IPv6 is supported and the port on which the connection succeeded.
// If no connection is successful, returns false and an empty string.
func supportIPv6(ctx context.Context) (supported bool, successPort string) {
for _, port := range commonIPv6Ports {
if canConnectToIPv6Port(ctx, port) {
return true, string(port)
}
}
return false, ""
}
// canConnectToIPv6Port attempts to establish a TCP connection to the specified port
// using IPv6. Returns true if the connection was successful.
func canConnectToIPv6Port(ctx context.Context, port string) bool {
address := net.JoinHostPort(v6BootstrapIP, port)
conn, err := probeStackDialer.DialContext(ctx, "tcp6", address)
if err != nil {
return false
}
c.Close()
_ = conn.Close()
return true
}
@@ -110,7 +131,8 @@ func SupportsIPv6ListenLocal() bool {
// IPv6Available is like SupportsIPv6, but always do the check without caching.
func IPv6Available(ctx context.Context) bool {
return supportIPv6(ctx)
hasV6, _ := supportIPv6(ctx)
return hasV6
}
// IsIPv6 checks if the provided IP is v6.

View File

@@ -12,7 +12,12 @@ func TestProbeStackTimeout(t *testing.T) {
go func() {
defer close(done)
close(started)
supportIPv6(context.Background())
hasV6, port := supportIPv6(context.Background())
if hasV6 {
t.Logf("connect to port %s using ipv6: %v", port, hasV6)
} else {
t.Log("ipv6 is not available")
}
}()
<-started

View File

@@ -6,6 +6,7 @@ import (
"net"
"tailscale.com/net/dns/resolvconffile"
"tailscale.com/util/dnsname"
)
const resolvconfPath = "/etc/resolv.conf"
@@ -22,7 +23,7 @@ func NameServersWithPort() []string {
return ns
}
func NameServers(_ string) []string {
func NameServers() []string {
c, err := resolvconffile.ParseFile(resolvconfPath)
if err != nil {
return nil
@@ -33,3 +34,12 @@ func NameServers(_ string) []string {
}
return ns
}
// SearchDomains returns the current search domains config in /etc/resolv.conf file.
func SearchDomains() ([]dnsname.FQDN, error) {
c, err := resolvconffile.ParseFile(resolvconfPath)
if err != nil {
return nil, err
}
return c.SearchDomains, nil
}

View File

@@ -9,7 +9,7 @@ import (
)
func TestNameServers(t *testing.T) {
ns := NameServers("")
ns := NameServers()
require.NotNil(t, ns)
t.Log(ns)
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"io"
"os"
"path/filepath"
"strings"
)
@@ -28,3 +29,62 @@ func interfaceNameFromReader(r io.Reader) (string, error) {
}
return "", errors.New("not found")
}
// AdditionalConfigFiles returns a list of Dnsmasq configuration files found in the "/tmp/etc" directory.
func AdditionalConfigFiles() []string {
if paths, err := filepath.Glob("/tmp/etc/dnsmasq-*.conf"); err == nil {
return paths
}
return nil
}
// AdditionalLeaseFiles returns a list of lease file paths corresponding to the Dnsmasq configuration files.
func AdditionalLeaseFiles() []string {
cfgFiles := AdditionalConfigFiles()
if len(cfgFiles) == 0 {
return nil
}
leaseFiles := make([]string, 0, len(cfgFiles))
for _, cfgFile := range cfgFiles {
if leaseFile := leaseFileFromConfigFileName(cfgFile); leaseFile != "" {
leaseFiles = append(leaseFiles, leaseFile)
} else {
leaseFiles = append(leaseFiles, defaultLeaseFileFromConfigPath(cfgFile))
}
}
return leaseFiles
}
// leaseFileFromConfigFileName retrieves the DHCP lease file path by reading and parsing the provided configuration file.
func leaseFileFromConfigFileName(cfgFile string) string {
if f, err := os.Open(cfgFile); err == nil {
return leaseFileFromReader(f)
}
return ""
}
// leaseFileFromReader parses the given io.Reader for the "dhcp-leasefile" configuration and returns its value as a string.
func leaseFileFromReader(r io.Reader) string {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") {
continue
}
before, after, found := strings.Cut(line, "=")
if !found {
continue
}
if before == "dhcp-leasefile" {
return after
}
}
return ""
}
// defaultLeaseFileFromConfigPath generates the default lease file path based on the provided configuration file path.
func defaultLeaseFileFromConfigPath(path string) string {
name := filepath.Base(path)
return filepath.Join("/var/lib/misc", strings.TrimSuffix(name, ".conf")+".leases")
}

View File

@@ -1,6 +1,7 @@
package dnsmasq
import (
"io"
"strings"
"testing"
)
@@ -44,3 +45,49 @@ interface=eth0
})
}
}
func Test_leaseFileFromReader(t *testing.T) {
tests := []struct {
name string
in io.Reader
expected string
}{
{
"default",
strings.NewReader(`
dhcp-script=/sbin/dhcpc_lease
dhcp-leasefile=/var/lib/misc/dnsmasq-1.leases
script-arp
`),
"/var/lib/misc/dnsmasq-1.leases",
},
{
"non-default",
strings.NewReader(`
dhcp-script=/sbin/dhcpc_lease
dhcp-leasefile=/tmp/var/lib/misc/dnsmasq-1.leases
script-arp
`),
"/tmp/var/lib/misc/dnsmasq-1.leases",
},
{
"missing",
strings.NewReader(`
dhcp-script=/sbin/dhcpc_lease
script-arp
`),
"",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := leaseFileFromReader(tc.in); got != tc.expected {
t.Errorf("leaseFileFromReader() = %v, want %v", got, tc.expected)
}
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"html/template"
"net"
"os"
"path/filepath"
"strings"
@@ -26,9 +27,13 @@ max-cache-ttl=0
{{- end}}
`
const MerlinConfPath = "/tmp/etc/dnsmasq.conf"
const MerlinJffsConfPath = "/jffs/configs/dnsmasq.conf"
const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf"
const (
MerlinConfPath = "/tmp/etc/dnsmasq.conf"
MerlinJffsConfDir = "/jffs/configs"
MerlinJffsConfPath = "/jffs/configs/dnsmasq.conf"
MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf"
)
const MerlinPostConfMarker = `# GENERATED BY ctrld - EOF`
const MerlinPostConfTmpl = `# GENERATED BY ctrld - DO NOT MODIFY
@@ -159,3 +164,27 @@ func FirewallaSelfInterfaces() []*net.Interface {
}
return ifaces
}
const (
ubios43ConfPath = "/run/dnsmasq.dhcp.conf.d"
ubios42ConfPath = "/run/dnsmasq.conf.d"
ubios43PidFile = "/run/dnsmasq-main.pid"
ubios42PidFile = "/run/dnsmasq.pid"
UbiosConfName = "zzzctrld.conf"
)
// UbiosConfPath returns the appropriate configuration path based on the system's directory structure.
func UbiosConfPath() string {
if st, _ := os.Stat(ubios43ConfPath); st != nil && st.IsDir() {
return ubios43ConfPath
}
return ubios42ConfPath
}
// UbiosPidFile returns the appropriate dnsmasq pid file based on the system's directory structure.
func UbiosPidFile() string {
if st, _ := os.Stat(ubios43PidFile); st != nil && !st.IsDir() {
return ubios43PidFile
}
return ubios42PidFile
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/kardianos/service"
@@ -181,7 +182,7 @@ func ContentFilteringEnabled() bool {
// DnsShieldEnabled reports whether DNS Shield is enabled.
// See: https://community.ui.com/releases/UniFi-OS-Dream-Machines-3-2-7/251dfc1e-f4dd-4264-a080-3be9d8b9e02b
func DnsShieldEnabled() bool {
buf, err := os.ReadFile("/var/run/dnsmasq.conf.d/dns.conf")
buf, err := os.ReadFile(filepath.Join(dnsmasq.UbiosConfPath(), "dns.conf"))
if err != nil {
return false
}

View File

@@ -6,6 +6,7 @@ import (
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"unicode"
@@ -20,10 +21,18 @@ import (
const Name = "merlin"
// nvramKvMap is a map of NVRAM key-value pairs used to configure and manage Merlin-specific settings.
var nvramKvMap = map[string]string{
"dnspriv_enable": "0", // Ensure Merlin native DoT disabled.
}
// dnsmasqConfig represents configuration paths for dnsmasq operations in Merlin firmware.
type dnsmasqConfig struct {
confPath string
jffsConfPath string
}
// Merlin represents a configuration handler for setting up and managing ctrld on Merlin routers.
type Merlin struct {
cfg *ctrld.Config
}
@@ -33,18 +42,22 @@ func New(cfg *ctrld.Config) *Merlin {
return &Merlin{cfg: cfg}
}
// ConfigureService configures the service based on the provided configuration. It returns an error if the configuration fails.
func (m *Merlin) ConfigureService(config *service.Config) error {
return nil
}
// Install sets up the necessary configurations and services required for the Merlin instance to function properly.
func (m *Merlin) Install(_ *service.Config) error {
return nil
}
// Uninstall removes the ctrld-related configurations and services from the Merlin router and reverts to the original state.
func (m *Merlin) Uninstall(_ *service.Config) error {
return nil
}
// PreRun prepares the Merlin instance for operation by waiting for essential services and directories to become available.
func (m *Merlin) PreRun() error {
// Wait NTP ready.
_ = m.Cleanup()
@@ -66,6 +79,7 @@ func (m *Merlin) PreRun() error {
return nil
}
// Setup initializes and configures the Merlin instance for use, including setting up dnsmasq and necessary nvram settings.
func (m *Merlin) Setup() error {
if m.cfg.FirstListener().IsDirectDnsListener() {
return nil
@@ -79,35 +93,10 @@ func (m *Merlin) Setup() error {
return err
}
// Copy current dnsmasq config to /jffs/configs/dnsmasq.conf,
// Then we will run postconf script on this file.
//
// Normally, adding postconf script is enough. However, we see
// reports on some Merlin devices that postconf scripts does not
// work, but manipulating the config directly via /jffs/configs does.
src, err := os.Open(dnsmasq.MerlinConfPath)
if err != nil {
return fmt.Errorf("failed to open dnsmasq config: %w", err)
}
defer src.Close()
dst, err := os.Create(dnsmasq.MerlinJffsConfPath)
if err != nil {
return fmt.Errorf("failed to create %s: %w", dnsmasq.MerlinJffsConfPath, err)
}
defer dst.Close()
if _, err := io.Copy(dst, src); err != nil {
return fmt.Errorf("failed to copy current dnsmasq config: %w", err)
}
if err := dst.Close(); err != nil {
return fmt.Errorf("failed to save %s: %w", dnsmasq.MerlinJffsConfPath, err)
}
// Run postconf script on /jffs/configs/dnsmasq.conf directly.
cmd := exec.Command("/bin/sh", dnsmasq.MerlinPostConfPath, dnsmasq.MerlinJffsConfPath)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to run post conf: %s: %w", string(out), err)
for _, cfg := range getDnsmasqConfigs() {
if err := m.setupDnsmasq(cfg); err != nil {
return fmt.Errorf("failed to setup dnsmasq: config: %s, error: %w", cfg.confPath, err)
}
}
// Restart dnsmasq service.
@@ -122,6 +111,7 @@ func (m *Merlin) Setup() error {
return nil
}
// Cleanup restores the original dnsmasq and nvram configurations and restarts dnsmasq if necessary.
func (m *Merlin) Cleanup() error {
if m.cfg.FirstListener().IsDirectDnsListener() {
return nil
@@ -143,9 +133,11 @@ func (m *Merlin) Cleanup() error {
if err := os.WriteFile(dnsmasq.MerlinPostConfPath, merlinParsePostConf(buf), 0750); err != nil {
return err
}
// Remove /jffs/configs/dnsmasq.conf file.
if err := os.Remove(dnsmasq.MerlinJffsConfPath); err != nil && !os.IsNotExist(err) {
return err
for _, cfg := range getDnsmasqConfigs() {
if err := m.cleanupDnsmasqJffs(cfg); err != nil {
return fmt.Errorf("failed to cleanup jffs dnsmasq: config: %s, error: %w", cfg.confPath, err)
}
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
@@ -154,6 +146,54 @@ func (m *Merlin) Cleanup() error {
return nil
}
// setupDnsmasq sets up dnsmasq configuration by writing postconf, copying configuration, and running a postconf script.
func (m *Merlin) setupDnsmasq(cfg *dnsmasqConfig) error {
src, err := os.Open(cfg.confPath)
if os.IsNotExist(err) {
return nil // nothing to do if conf file does not exist.
}
if err != nil {
return fmt.Errorf("failed to open dnsmasq config: %w", err)
}
defer src.Close()
// Copy current dnsmasq config to cfg.jffsConfPath,
// Then we will run postconf script on this file.
//
// Normally, adding postconf script is enough. However, we see
// reports on some Merlin devices that postconf scripts does not
// work, but manipulating the config directly via /jffs/configs does.
dst, err := os.Create(cfg.jffsConfPath)
if err != nil {
return fmt.Errorf("failed to create %s: %w", cfg.jffsConfPath, err)
}
defer dst.Close()
if _, err := io.Copy(dst, src); err != nil {
return fmt.Errorf("failed to copy current dnsmasq config: %w", err)
}
if err := dst.Close(); err != nil {
return fmt.Errorf("failed to save %s: %w", cfg.jffsConfPath, err)
}
// Run postconf script on cfg.jffsConfPath directly.
cmd := exec.Command("/bin/sh", dnsmasq.MerlinPostConfPath, cfg.jffsConfPath)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to run post conf: %s: %w", string(out), err)
}
return nil
}
// cleanupDnsmasqJffs removes the JFFS configuration file specified in the given dnsmasqConfig, if it exists.
func (m *Merlin) cleanupDnsmasqJffs(cfg *dnsmasqConfig) error {
// Remove cfg.jffsConfPath file.
if err := os.Remove(cfg.jffsConfPath); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
// writeDnsmasqPostconf writes the requireddnsmasqConfigs post-configuration for dnsmasq to enable custom DNS settings with ctrld.
func (m *Merlin) writeDnsmasqPostconf() error {
buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath)
// Already setup.
@@ -179,6 +219,8 @@ func (m *Merlin) writeDnsmasqPostconf() error {
return os.WriteFile(dnsmasq.MerlinPostConfPath, []byte(data), 0750)
}
// restartDNSMasq restarts the dnsmasq service by executing the appropriate system command using "service".
// Returns an error if the command fails or if there is an issue processing the command output.
func restartDNSMasq() error {
if out, err := exec.Command("service", "restart_dnsmasq").CombinedOutput(); err != nil {
return fmt.Errorf("restart_dnsmasq: %s, %w", string(out), err)
@@ -186,6 +228,22 @@ func restartDNSMasq() error {
return nil
}
// getDnsmasqConfigs retrieves a list of dnsmasqConfig containing configuration and JFFS paths for dnsmasq operations.
func getDnsmasqConfigs() []*dnsmasqConfig {
cfgs := []*dnsmasqConfig{
{dnsmasq.MerlinConfPath, dnsmasq.MerlinJffsConfPath},
}
for _, path := range dnsmasq.AdditionalConfigFiles() {
jffsConfPath := filepath.Join(dnsmasq.MerlinJffsConfDir, filepath.Base(path))
cfgs = append(cfgs, &dnsmasqConfig{path, jffsConfPath})
}
return cfgs
}
// merlinParsePostConf parses the dnsmasq post configuration by removing content after the MerlinPostConfMarker, if present.
// If no marker is found, the original buffer is returned unmodified.
// Returns nil if the input buffer is empty.
func merlinParsePostConf(buf []byte) []byte {
if len(buf) == 0 {
return nil
@@ -197,6 +255,7 @@ func merlinParsePostConf(buf []byte) []byte {
return buf
}
// waitDirExists waits until the specified directory exists, polling its existence every second.
func waitDirExists(dir string) {
for {
if _, err := os.Stat(dir); !os.IsNotExist(err) {

View File

@@ -13,14 +13,13 @@ import (
"time"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
)
// This is a copy of https://github.com/kardianos/service/blob/v1.2.1/service_sysv_linux.go,
// with modification for supporting ubios v1 init system.
// Keep in sync with ubios.ubiosDNSMasqConfigPath
const ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
type ubiosSvc struct {
i service.Interface
platform string
@@ -86,7 +85,7 @@ func (s *ubiosSvc) Install() error {
}{
s.Config,
path,
ubiosDNSMasqConfigPath,
filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName),
}
if err := s.template().Execute(f, to); err != nil {

View File

@@ -3,6 +3,7 @@ package ubios
import (
"bytes"
"os"
"path/filepath"
"strconv"
"github.com/kardianos/service"
@@ -12,19 +13,19 @@ import (
"github.com/Control-D-Inc/ctrld/internal/router/edgeos"
)
const (
Name = "ubios"
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf"
)
const Name = "ubios"
type Ubios struct {
cfg *ctrld.Config
cfg *ctrld.Config
dnsmasqConfPath string
}
// New returns a router.Router for configuring/setup/run ctrld on Ubios routers.
func New(cfg *ctrld.Config) *Ubios {
return &Ubios{cfg: cfg}
return &Ubios{
cfg: cfg,
dnsmasqConfPath: filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName),
}
}
func (u *Ubios) ConfigureService(config *service.Config) error {
@@ -59,7 +60,7 @@ func (u *Ubios) Setup() error {
if err != nil {
return err
}
if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil {
if err := os.WriteFile(u.dnsmasqConfPath, []byte(data), 0600); err != nil {
return err
}
// Restart dnsmasq service.
@@ -74,7 +75,7 @@ func (u *Ubios) Cleanup() error {
return nil
}
// Remove the custom dnsmasq config
if err := os.Remove(ubiosDNSMasqConfigPath); err != nil {
if err := os.Remove(u.dnsmasqConfPath); err != nil {
return err
}
// Restart dnsmasq service.
@@ -85,7 +86,7 @@ func (u *Ubios) Cleanup() error {
}
func restartDNSMasq() error {
buf, err := os.ReadFile("/run/dnsmasq.pid")
buf, err := os.ReadFile(dnsmasq.UbiosPidFile())
if err != nil {
return err
}

View File

@@ -14,7 +14,7 @@ import (
// currentNameserversFromResolvconf returns the current nameservers set from /etc/resolv.conf file.
func currentNameserversFromResolvconf() []string {
return resolvconffile.NameServers("")
return resolvconffile.NameServers()
}
// dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file.
@@ -34,7 +34,7 @@ func dnsFromResolvConf() []string {
time.Sleep(retryInterval)
}
nss := resolvconffile.NameServers("")
nss := resolvconffile.NameServers()
var localDNS []string
seen := make(map[string]bool)

View File

@@ -9,12 +9,14 @@ import (
"net/netip"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"github.com/rs/zerolog"
"golang.org/x/sync/singleflight"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
)
@@ -216,6 +218,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
type osResolver struct {
lanServers atomic.Pointer[[]string]
publicServers atomic.Pointer[[]string]
group *singleflight.Group
cache *sync.Map
}
type osResolverResult struct {
@@ -273,10 +277,75 @@ func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desired
return dnsClient.ExchangeContext(ctx, msg, server)
}
const hotCacheTTL = time.Second
// Resolve resolves DNS queries using pre-configured nameservers.
// Query is sent to all nameservers concurrently, and the first
// The Query is sent to all nameservers concurrently, and the first
// success response will be returned.
//
// To guard against unexpected DoS to upstreams, multiple queries of
// the same Qtype to a domain will be shared, so there's only 1 qps
// for each upstream at any time.
//
// Further, a hot cache will be used, so repeated queries will be cached
// for a short period (currently 1 second), reducing unnecessary traffics
// sent to upstreams.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if len(msg.Question) == 0 {
return nil, errors.New("no question found")
}
domain := strings.TrimSuffix(msg.Question[0].Name, ".")
qtype := msg.Question[0].Qtype
// Unique key for the singleflight group.
key := fmt.Sprintf("%s:%d:", domain, qtype)
// Checking the cache first.
if val, ok := o.cache.Load(key); ok {
if val, ok := val.(*dns.Msg); ok {
Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype])
res := val.Copy()
SetCacheReply(res, msg, val.Rcode)
return res, nil
}
}
// Ensure only one DNS query is in flight for the key.
v, err, shared := o.group.Do(key, func() (interface{}, error) {
msg, err := o.resolve(ctx, msg)
if err != nil {
return nil, err
}
// If we got an answer, storing it to the hot cache for hotCacheTTL
// This prevents possible DoS to upstream, ensuring there's only 1 QPS.
o.cache.Store(key, msg)
// Depends on go runtime scheduling, the result may end up in hot cache longer
// than hotCacheTTL duration. However, this is fine since we only want to guard
// against DoS attack. The result will be cleaned from the cache eventually.
time.AfterFunc(hotCacheTTL, func() {
o.removeCache(key)
})
return msg, nil
})
if err != nil {
return nil, err
}
sharedMsg, ok := v.(*dns.Msg)
if !ok {
return nil, fmt.Errorf("invalid answer for key: %s", key)
}
res := sharedMsg.Copy()
SetCacheReply(res, msg, sharedMsg.Rcode)
if shared {
Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype])
}
return res, nil
}
// resolve sends the query to current nameservers.
func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
publicServers := *o.publicServers.Load()
var nss []string
if p := o.lanServers.Load(); p != nil {
@@ -431,6 +500,10 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
return nil, errors.Join(errs...)
}
func (o *osResolver) removeCache(key string) {
o.cache.Delete(key)
}
type legacyResolver struct {
uc *UpstreamConfig
}
@@ -469,11 +542,26 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err
// LookupIP looks up domain using current system nameservers settings.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func LookupIP(domain string) []string {
return lookupIP(domain, -1, defaultNameservers())
nss := initDefaultOsResolver()
return lookupIP(domain, -1, nss)
}
// initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet.
// It returns the combined list of LAN and public nameservers currently held by the resolver.
func initDefaultOsResolver() []string {
resolverMutex.Lock()
defer resolverMutex.Unlock()
if or == nil {
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers")
or = newResolverWithNameserver(defaultNameservers())
}
nss := *or.lanServers.Load()
nss = append(nss, *or.publicServers.Load()...)
return nss
}
// lookupIP looks up domain with given timeout and bootstrapDNS.
// If timeout is negative, default timeout 2000 ms will be used.
// If the timeout is negative, default timeout 2000 ms will be used.
// It returns nil if bootstrapDNS is nil or empty.
func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) {
if net.ParseIP(domain) != nil {
@@ -577,13 +665,7 @@ func NewBootstrapResolver(servers ...string) Resolver {
//
// This is useful for doing PTR lookup in LAN network.
func NewPrivateResolver() Resolver {
resolverMutex.Lock()
if or == nil {
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver in NewPrivateResolver")
or = newResolverWithNameserver(defaultNameservers())
}
nss := *or.lanServers.Load()
resolverMutex.Unlock()
nss := initDefaultOsResolver()
resolveConfNss := currentNameserversFromResolvconf()
localRfc1918Addrs := Rfc1918Addresses()
n := 0
@@ -627,10 +709,10 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
// newResolverWithNameserver returns an OS resolver from given nameservers list.
// The caller must ensure each server in list is formed "ip:53".
func newResolverWithNameserver(nameservers []string) *osResolver {
logger := *ProxyLogger.Load()
Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers)
r := &osResolver{}
r := &osResolver{
group: &singleflight.Group{},
cache: &sync.Map{},
}
var publicNss []string
var lanNss []string
for _, ns := range slices.Sorted(slices.Values(nameservers)) {

View File

@@ -2,8 +2,11 @@ package ctrld
import (
"context"
"crypto/rand"
"encoding/hex"
"net"
"sync"
"sync/atomic"
"testing"
"time"
@@ -16,8 +19,7 @@ func Test_osResolver_Resolve(t *testing.T) {
go func() {
defer cancel()
resolver := &osResolver{}
resolver.publicServers.Store(&[]string{"127.0.0.127:5353"})
resolver := newResolverWithNameserver([]string{"127.0.0.127:5353"})
m := new(dns.Msg)
m.SetQuestion("controld.com.", dns.TypeA)
m.RecursionDesired = true
@@ -50,8 +52,7 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) {
t.Error("not a LAN query")
return
}
resolver := &osResolver{}
resolver.publicServers.Store(&[]string{"76.76.2.0:53"})
resolver := newResolverWithNameserver([]string{"76.76.2.0:53"})
m := new(dns.Msg)
m.SetQuestion("controld.com.", dns.TypeA)
m.RecursionDesired = true
@@ -107,11 +108,9 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
}()
// We now create an osResolver which has both a LAN and public nameserver.
resolver := &osResolver{}
// Explicitly store the LAN nameserver.
resolver.lanServers.Store(&[]string{lanAddr})
// And store the public nameservers.
resolver.publicServers.Store(&publicNS)
nss := []string{lanAddr}
nss = append(nss, publicNS...)
resolver := newResolverWithNameserver(nss)
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
@@ -139,6 +138,150 @@ func Test_osResolver_InitializationRace(t *testing.T) {
wg.Wait()
}
func Test_osResolver_Singleflight(t *testing.T) {
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on LAN address: %v", err)
}
call := &atomic.Int64{}
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call))
if err != nil {
t.Fatalf("failed to run LAN test server: %v", err)
}
defer lanServer.Shutdown()
or := newResolverWithNameserver([]string{lanAddr})
domain := "controld.com"
n := 10
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
m.RecursionDesired = true
_, err := or.Resolve(context.Background(), m)
if err != nil {
t.Error(err)
}
}()
}
wg.Wait()
// All above queries should only make 1 call to server.
if call.Load() != 1 {
t.Fatalf("expected 1 result from singleflight lookup, got %d", call)
}
}
func Test_osResolver_HotCache(t *testing.T) {
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on LAN address: %v", err)
}
call := &atomic.Int64{}
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call))
if err != nil {
t.Fatalf("failed to run LAN test server: %v", err)
}
defer lanServer.Shutdown()
or := newResolverWithNameserver([]string{lanAddr})
domain := "controld.com"
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
m.RecursionDesired = true
// Make 2 repeated queries to server, should hit hot cache.
for i := 0; i < 2; i++ {
if _, err := or.Resolve(context.Background(), m.Copy()); err != nil {
t.Fatal(err)
}
}
if call.Load() != 1 {
t.Fatalf("cache not hit, server was called: %d", call.Load())
}
timeoutChan := make(chan struct{})
time.AfterFunc(5*time.Second, func() {
close(timeoutChan)
})
for {
select {
case <-timeoutChan:
t.Fatal("timed out waiting for cache cleaned")
default:
count := 0
or.cache.Range(func(key, value interface{}) bool {
count++
return true
})
if count != 0 {
t.Logf("hot cache is not empty: %d elements", count)
continue
}
}
break
}
if _, err := or.Resolve(context.Background(), m.Copy()); err != nil {
t.Fatal(err)
}
if call.Load() != 2 {
t.Fatal("cache hit unexpectedly")
}
}
func Test_Edns0_CacheReply(t *testing.T) {
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on LAN address: %v", err)
}
call := &atomic.Int64{}
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call))
if err != nil {
t.Fatalf("failed to run LAN test server: %v", err)
}
defer lanServer.Shutdown()
or := newResolverWithNameserver([]string{lanAddr})
domain := "controld.com"
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
m.RecursionDesired = true
do := func() *dns.Msg {
msg := m.Copy()
msg.SetEdns0(4096, true)
cookieOption := new(dns.EDNS0_COOKIE)
cookieOption.Code = dns.EDNS0COOKIE
cookieOption.Cookie = generateEdns0ClientCookie()
msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption)
answer, err := or.Resolve(context.Background(), msg)
if err != nil {
t.Fatal(err)
}
return answer
}
answer1 := do()
answer2 := do()
// Ensure the cache was hit, so we can check that edns0 cookie must be modified.
if call.Load() != 1 {
t.Fatalf("cache not hit, server was called: %d", call.Load())
}
cookie1 := getEdns0Cookie(answer1.IsEdns0())
cookie2 := getEdns0Cookie(answer2.IsEdns0())
if cookie1 == nil || cookie2 == nil {
t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2)
}
if cookie1.Cookie == cookie2.Cookie {
t.Fatalf("edns0 cookie is not modified: %v", cookie1)
}
}
func Test_upstreamTypeFromEndpoint(t *testing.T) {
tests := []struct {
name string
@@ -208,3 +351,37 @@ func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc {
w.WriteMsg(m)
}
}
func countHandler(call *atomic.Int64) dns.HandlerFunc {
return func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, dns.RcodeSuccess)
if cookie := getEdns0Cookie(msg.IsEdns0()); cookie != nil {
if m.IsEdns0() == nil {
m.SetEdns0(4096, false)
}
cookieOption := new(dns.EDNS0_COOKIE)
cookieOption.Code = dns.EDNS0COOKIE
cookieOption.Cookie = generateEdns0ServerCookie(cookie.Cookie)
m.IsEdns0().Option = append(m.IsEdns0().Option, cookieOption)
}
w.WriteMsg(m)
call.Add(1)
}
}
func generateEdns0ClientCookie() string {
cookie := make([]byte, 8)
if _, err := rand.Read(cookie); err != nil {
panic(err)
}
return hex.EncodeToString(cookie)
}
func generateEdns0ServerCookie(clientCookie string) string {
cookie := make([]byte, 32)
if _, err := rand.Read(cookie); err != nil {
panic(err)
}
return clientCookie + hex.EncodeToString(cookie)
}