Merge pull request #218 from Control-D-Inc/release-branch-v1.4.1

Release branch v1.4.1
This commit is contained in:
Cuong Manh Le
2025-03-07 08:25:38 +07:00
committed by GitHub
42 changed files with 1501 additions and 407 deletions

View File

@@ -19,7 +19,7 @@ jobs:
with:
go-version: ${{ matrix.go }}
- run: "go test -race ./..."
- uses: dominikh/staticcheck-action@v1.2.0
- uses: dominikh/staticcheck-action@v1.3.1
with:
version: "2024.1.1"
install-go: false

5
cmd/cli/cgo.go Normal file
View File

@@ -0,0 +1,5 @@
//go:build cgo
package cli
// cgoEnabled reports whether the binary was built with cgo enabled; this file
// is only compiled when the cgo build constraint is satisfied.
const cgoEnabled = true

View File

@@ -61,7 +61,7 @@ var (
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
defaultConfigFile = "ctrld.toml"
rootCertPool *x509.CertPool
errSelfCheckNoAnswer = errors.New("no answer from ctrld listener")
errSelfCheckNoAnswer = errors.New("no response from ctrld listener. You can try to re-launch with flag --skip_self_checks")
)
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"}
@@ -222,10 +222,16 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
lc := &logConn{conn: conn}
consoleWriter.Out = io.MultiWriter(os.Stdout, lc)
p.logConn = lc
} else {
mainLog.Load().Warn().Err(err).Msgf("unable to create log ipc connection")
}
} else {
mainLog.Load().Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath)
}
notifyExitToLogServer := func() {
_, _ = p.logConn.Write([]byte(msgExit))
if p.logConn != nil {
_, _ = p.logConn.Write([]byte(msgExit))
}
}
if daemon && runtime.GOOS == "windows" {
@@ -266,10 +272,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
// Log config do not have thing to validate, so it's safe to init log here,
// so it's able to log information in processCDFlags.
logWriters := initLogging()
// Initializing internal logging after global logging.
p.initInternalLogging(logWriters)
p.initLogging(true)
mainLog.Load().Info().Msgf("starting ctrld %s", curVersion())
mainLog.Load().Info().Msgf("os: %s", osVersion())
@@ -322,7 +325,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
}
}
updated := updateListenerConfig(&cfg)
updated := updateListenerConfig(&cfg, notifyExitToLogServer)
if cdUID != "" {
processLogAndCacheFlags()
@@ -418,7 +421,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
if err := p.router.Cleanup(); err != nil {
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
}
p.resetDNS()
// restore static DNS settings or DHCP
p.resetDNS(false, true)
})
}
}
@@ -484,7 +488,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
}
nop := zerolog.Nop()
_, _ = tryUpdateListenerConfig(&cfg, &nop, true)
_, _ = tryUpdateListenerConfig(&cfg, &nop, func() {}, true)
addExtraSplitDnsRule(&cfg)
if err := writeConfigFile(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
@@ -645,11 +649,15 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) {
// Fetch config, unmarshal to cfg.
if resolverConfig.Ctrld.CustomConfig != "" {
logger.Info().Msg("using defined custom config of Control-D resolver")
if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil {
var cfgErr error
if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil {
setListenerDefaultValue(cfg)
return resolverConfig, nil
setNetworkDefaultValue(cfg)
if cfgErr = validateConfig(cfg); cfgErr == nil {
return resolverConfig, nil
}
}
mainLog.Load().Err(err).Msg("disregarding invalid custom config")
mainLog.Load().Warn().Err(err).Msg("disregarding invalid custom config")
}
bootstrapIP := func(endpoint string) string {
@@ -666,11 +674,7 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) {
}
return ""
}
cfg.Network = make(map[string]*ctrld.NetworkConfig)
cfg.Network["0"] = &ctrld.NetworkConfig{
Name: "Network 0",
Cidrs: []string{"0.0.0.0/0"},
}
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
BootstrapIP: bootstrapIP(resolverConfig.DOH),
@@ -693,6 +697,7 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) {
// Set default value.
setListenerDefaultValue(cfg)
setNetworkDefaultValue(cfg)
return resolverConfig, nil
}
@@ -706,7 +711,21 @@ func setListenerDefaultValue(cfg *ctrld.Config) {
}
}
// setNetworkDefaultValue sets the default value for cfg.Network if none existed.
// The default is a single network "0" named "Network 0" covering 0.0.0.0/0.
func setNetworkDefaultValue(cfg *ctrld.Config) {
	// Keep whatever networks are already configured.
	if len(cfg.Network) > 0 {
		return
	}
	defaultNetwork := &ctrld.NetworkConfig{
		Name:  "Network 0",
		Cidrs: []string{"0.0.0.0/0"},
	}
	cfg.Network = map[string]*ctrld.NetworkConfig{"0": defaultNetwork}
}
// validateCdRemoteConfig validates the custom config from ControlD if defined.
// This only validates the config syntax. To validate the config rules, call
// validateConfig with the cfg after calling this function.
func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error {
if rc.Ctrld.CustomConfig == "" {
return nil
@@ -783,7 +802,13 @@ func defaultIfaceName() string {
if oi := osinfo.New(); strings.Contains(oi.String(), "Microsoft") {
return "lo"
}
mainLog.Load().Fatal().Err(err).Msg("failed to get default route interface")
// On linux, it could be either resolvconf or systemd which is managing DNS settings,
// so the interface name does not matter if there's no default route interface.
if runtime.GOOS == "linux" {
return "lo"
}
mainLog.Load().Debug().Err(err).Msg("no default route interface found")
return ""
}
return dri
}
@@ -843,10 +868,12 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo
}
mainLog.Load().Debug().Msg("ctrld listener is ready")
mainLog.Load().Debug().Msg("performing self-check")
lc := cfg.FirstListener()
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
mainLog.Load().Debug().Msgf("performing listener test, sending queries to %s", addr)
if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil {
return false, status, err
}
@@ -860,7 +887,7 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo
func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain string) error {
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
bo.LogLongerThan = 500 * time.Millisecond
maxAttempts := 20
maxAttempts := 10
c := new(dns.Client)
var (
@@ -876,7 +903,7 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri
m := new(dns.Msg)
m.SetQuestion(domain+".", dns.TypeA)
m.RecursionDesired = true
r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, addr)
r, _, exErr := exchangeContextWithTimeout(c, 5*time.Second, m, addr)
if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 {
mainLog.Load().Debug().Msgf("%s self-check against %q succeeded", scope, domain)
return nil
@@ -1030,16 +1057,25 @@ func uninstall(p *prog, s service.Service) {
mainLog.Load().Warn().Err(err).Msg("post uninstallation failed, please check system/service log for details error")
return
}
p.resetDNS()
// restore static DNS settings or DHCP
p.resetDNS(false, true)
// if present restore the original DNS settings
if netIface, err := netInterface(p.runningIface); err == nil {
if err := restoreDNS(netIface); err != nil {
mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface")
} else {
mainLog.Load().Debug().Msg("Restored DNS on interface successfully")
// Iterate over all physical interfaces and restore DNS if a saved static config exists.
withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error {
file := savedStaticDnsSettingsFilePath(i)
if _, err := os.Stat(file); err == nil {
if err := restoreDNS(i); err != nil {
mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)
} else {
mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name)
err = os.Remove(file)
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("Could not remove saved static DNS file for interface %s", i.Name)
}
}
}
}
return nil
})
if router.Name() != "" {
mainLog.Load().Debug().Msg("Router cleanup")
@@ -1146,8 +1182,8 @@ func mobileListenerIp() string {
// updateListenerConfig updates the config for listeners if not defined,
// or defined but invalid to be used, e.g: using loopback address other
// than 127.0.0.1 with systemd-resolved.
func updateListenerConfig(cfg *ctrld.Config) bool {
updated, _ := tryUpdateListenerConfig(cfg, nil, true)
func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool {
updated, _ := tryUpdateListenerConfig(cfg, nil, notifyToLogServerFunc, true)
if addExtraSplitDnsRule(cfg) {
updated = true
}
@@ -1157,13 +1193,14 @@ func updateListenerConfig(cfg *ctrld.Config) bool {
// tryUpdateListenerConfig tries updating listener config with a working one.
// If fatal is true and a listen address conflicts, the function logs a
// fatal error.
func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fatal bool) (updated, ok bool) {
func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, notifyFunc func(), fatal bool) (updated, ok bool) {
ok = true
lcc := make(map[string]*listenerConfigCheck)
cdMode := cdUID != ""
nextdnsMode := nextdns != ""
// For Windows server with local Dns server running, we can only try on random local IP.
hasLocalDnsServer := hasLocalDnsServerRunning()
notRouter := router.Name() == ""
for n, listener := range cfg.Listener {
lcc[n] = &listenerConfigCheck{}
if listener.IP == "" {
@@ -1193,6 +1230,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
}
updated = updated || lcc[n].IP || lcc[n].Port
}
il := mainLog.Load()
if infoLogger != nil {
il = infoLogger
@@ -1277,10 +1315,17 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
tryOldIPPort5354 = false
tryPort5354 = false
}
// if not running on a router, we should not try to listen on any port other than 53
// if we do, this will break the dns resolution for the system.
if notRouter {
tryOldIPPort5354 = false
tryPort5354 = false
}
attempts := 0
maxAttempts := 10
for {
if attempts == maxAttempts {
notifyFunc()
logMsg(mainLog.Load().Fatal(), n, "could not find available listen ip and port")
}
addr := net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))
@@ -1288,8 +1333,12 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
if err == nil {
break
}
logMsg(il.Info(), n, "error listening on address: %s, error: %v", addr, err)
if !check.IP && !check.Port {
if fatal {
notifyFunc()
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
}
ok = false
@@ -1348,14 +1397,17 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
} else {
listener.IP = oldIP
}
if check.Port {
// if we are not running on a router, we should not try to listen on any port other than 53
// if we do, this will break the dns resolution for the system.
if check.Port && !notRouter {
listener.Port = randomPort()
} else {
listener.Port = oldPort
}
if listener.IP == oldIP && listener.Port == oldPort {
if fatal {
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
notifyFunc()
logMsg(mainLog.Load().Fatal(), n, "could not listen on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
}
ok = false
break
@@ -1393,6 +1445,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
}
}
if !found {
notifyFunc()
logMsg(mainLog.Load().Fatal(), n, "could not use %q as DNS nameserver with systemd resolved", listener.IP)
}
}
@@ -1597,7 +1650,7 @@ func doGenerateNextDNSConfig(uid string) error {
}
mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile)
generateNextDNSConfig(uid)
updateListenerConfig(&cfg)
updateListenerConfig(&cfg, func() {})
return writeConfigFile(&cfg)
}
@@ -1739,9 +1792,14 @@ func goArm() string {
// upgradeUrl returns the url for downloading new ctrld binary.
func upgradeUrl(baseUrl string) string {
dlPath := fmt.Sprintf("%s-%s/ctrld", runtime.GOOS, runtime.GOARCH)
// Use arm version set during build time, v5 binary can be run on higher arm version system.
if armVersion := goArm(); armVersion != "" {
dlPath = fmt.Sprintf("%s-%sv%s/ctrld", runtime.GOOS, runtime.GOARCH, armVersion)
}
// linux/amd64 has nocgo version, to support systems that missing some libc (like openwrt).
if !cgoEnabled && runtime.GOOS == "linux" && runtime.GOARCH == "amd64" {
dlPath = fmt.Sprintf("%s-%s-nocgo/ctrld", runtime.GOOS, runtime.GOARCH)
}
dlUrl := fmt.Sprintf("%s/%s", baseUrl, dlPath)
if runtime.GOOS == "windows" {
dlUrl += ".exe"
@@ -1768,50 +1826,6 @@ func runningIface(s service.Service) *ifaceResponse {
return nil
}
// resetDnsNoLog resets DNS via p.resetDNS, with logging globally disabled for
// the duration of the call unless verbose is 3 or higher.
func resetDnsNoLog(p *prog) {
	// For debugging purpose, still emit log.
	if verbose >= 3 {
		p.resetDNS()
		return
	}
	// Normally, disable log to prevent annoying users, then put the
	// previous global level back once the reset is done.
	savedLevel := zerolog.GlobalLevel()
	zerolog.SetGlobalLevel(zerolog.Disabled)
	p.resetDNS()
	zerolog.SetGlobalLevel(savedLevel)
}
// resetDnsTask returns a task which performs the reset DNS operation.
//
// The task is a no-op when the package-level iface is empty. Otherwise it
// temporarily overrides iface (with ir.Name when ir is non-nil, else "auto"),
// records the chosen interface on p, resets DNS when isCtrldInstalled is true,
// and finally restores the previous iface value. The task always returns nil.
func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task {
return task{func() error {
if iface == "" {
mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask")
return nil
}
// Always reset DNS first, ensuring DNS setting is in a good state.
// resetDNS must use the "iface" value of current running ctrld
// process to reset what setDNS has done properly.
oldIface := iface
iface = "auto"
p.requiredMultiNICsConfig = requiredMultiNICsConfig()
// Prefer the interface information supplied by the caller, when available.
if ir != nil {
iface = ir.Name
p.requiredMultiNICsConfig = ir.All
}
p.runningIface = iface
if isCtrldInstalled {
mainLog.Load().Debug().Msg("restore system DNS settings")
// Refuse to touch DNS while the service reports it is still running;
// the status error is deliberately ignored here.
if status, _ := s.Status(); status == service.StatusRunning {
mainLog.Load().Fatal().Msg("reset DNS while ctrld still running is not safe")
}
resetDnsNoLog(p)
}
// Put the package-level iface back to what it was before the override.
iface = oldIface
return nil
// NOTE(review): the second field is presumably the "abort on error" flag,
// going by how other tasks are declared — confirm against the task type.
}, false, "Reset DNS"}
}
// doValidateCdRemoteConfig fetches and validates custom config for cdUID.
func doValidateCdRemoteConfig(cdUID string, fatal bool) error {
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
@@ -1825,10 +1839,22 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) error {
return err
}
}
// return earlier if there's no custom config.
if rc.Ctrld.CustomConfig == "" {
return nil
}
// validateCdRemoteConfig clobbers v, saving it here to restore later.
oldV := v
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
if errors.As(err, &viper.ConfigParseError{}) {
var cfgErr error
remoteCfg := &ctrld.Config{}
if cfgErr = validateCdRemoteConfig(rc, remoteCfg); cfgErr == nil {
setListenerDefaultValue(remoteCfg)
setNetworkDefaultValue(remoteCfg)
cfgErr = validateConfig(remoteCfg)
} else {
if errors.As(cfgErr, &viper.ConfigParseError{}) {
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
tmpDir := os.TempDir()
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
@@ -1844,12 +1870,14 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) error {
}
// If we could not log details error, emit what we have already got.
if !errorLogged {
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
mainLog.Load().Error().Msgf("failed to parse custom config: %v", cfgErr)
}
}
} else {
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
}
}
if cfgErr != nil {
mainLog.Load().Warn().Msg("disregarding invalid custom config")
}
v = oldV
@@ -1863,8 +1891,8 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool {
logger.Warn().Err(err).Msg("failed to create new service")
return false
}
p.resetDNS()
// restore static DNS settings or DHCP
p.resetDNS(false, true)
tasks := []task{{s.Uninstall, true, "Uninstall"}}
if doTasks(tasks) {

View File

@@ -205,7 +205,13 @@ func initStartCmd() *cobra.Command {
Long: `Install and start the ctrld service
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
Args: cobra.NoArgs,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) > 0 {
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
checkStrFlagEmpty(cmd, cdUidFlagName)
checkStrFlagEmpty(cmd, cdOrgFlagName)
@@ -242,6 +248,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
os.Exit(deactivationPinInvalidExitCode)
}
currentIface = runningIface(s)
mainLog.Load().Debug().Msgf("current interface on start: %v", currentIface)
}
ctx, cancel := context.WithCancel(context.Background())
@@ -339,13 +346,17 @@ NOTE: running "ctrld start" without any arguments will start already installed c
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
}
// if already running, dont restart
if isCtrldRunning {
mainLog.Load().Notice().Msg("service is already running")
return
}
initInteractiveLogging()
tasks := []task{
{s.Stop, false, "Stop"},
resetDnsTask(p, s, isCtrldInstalled, currentIface),
{func() error {
// Save current DNS so we can restore later.
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
return err
}
@@ -355,7 +366,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
}, false, "Save current DNS"},
{func() error {
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
}, false, "Configure Windows service failure actions"},
}, false, "Configure service failure actions"},
{s.Start, true, "Start"},
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
}
@@ -424,10 +435,10 @@ NOTE: running "ctrld start" without any arguments will start already installed c
{s.Stop, false, "Stop"},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
resetDnsTask(p, s, isCtrldInstalled, currentIface),
//resetDnsTask(p, s, isCtrldInstalled, currentIface),
{func() error {
// Save current DNS so we can restore later.
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
return err
}
@@ -451,6 +462,9 @@ NOTE: running "ctrld start" without any arguments will start already installed c
return
}
// add a small delay to ensure the service is started and did not crash
time.Sleep(1 * time.Second)
ok, status, err := selfCheckStatus(ctx, s, sockDir)
switch {
case ok && status == service.StatusRunning:
@@ -550,6 +564,13 @@ NOTE: running "ctrld start" without any arguments will start already installed c
Long: `Quick start service and configure DNS on interface
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) > 0 {
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
if len(os.Args) == 2 {
startOnly = true
@@ -608,16 +629,21 @@ func initStopCmd() *cobra.Command {
}
if doTasks([]task{{s.Stop, true, "Stop"}}) {
p.router.Cleanup()
p.resetDNS()
// restore static DNS settings or DHCP
p.resetDNS(false, true)
// restore DNS settings
if netIface, err := netInterface(p.runningIface); err == nil {
if err := restoreDNS(netIface); err != nil {
mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface")
} else {
mainLog.Load().Debug().Msg("Restored DNS on interface successfully")
// Iterate over all physical interfaces and restore static DNS if a saved static config exists.
withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error {
file := savedStaticDnsSettingsFilePath(i)
if _, err := os.Stat(file); err == nil {
if err := restoreDNS(i); err != nil {
mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)
} else {
mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name)
}
}
}
return nil
})
if router.WaitProcessExited() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
@@ -714,7 +740,8 @@ func initRestartCmd() *cobra.Command {
{s.Stop, true, "Stop"},
{func() error {
p.router.Cleanup()
p.resetDNS()
// restore static DNS settings or DHCP
p.resetDNS(false, true)
return nil
}, false, "Cleanup"},
{func() error {
@@ -994,13 +1021,13 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
if os.IsNotExist(err) {
continue
}
mainLog.Load().Warn().Err(err).Msg("failed to remove file")
mainLog.Load().Warn().Err(err).Msgf("failed to remove file: %s", file)
} else {
mainLog.Load().Debug().Msgf("file removed: %s", file)
}
}
if err := selfDeleteExe(); err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to remove file")
mainLog.Load().Warn().Err(err).Msg("failed to delete ctrld binary")
} else {
if !supportedSelfDelete {
mainLog.Load().Debug().Msgf("file removed: %s", bin)
@@ -1044,9 +1071,16 @@ func initInterfacesCmd() *cobra.Command {
Short: "List network interfaces of the host",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error {
fmt.Printf("Index : %d\n", i.Index)
fmt.Printf("Name : %s\n", i.Name)
var status string
if i.Flags&net.FlagUp != 0 {
status = "Up"
} else {
status = "Down"
}
fmt.Printf("Status: %s\n", status)
addrs, _ := i.Addrs()
for i, ipaddr := range addrs {
if i == 0 {
@@ -1242,7 +1276,8 @@ func initUpgradeCmd() *cobra.Command {
}
dlUrl := upgradeUrl(baseUrl)
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl)
resp, err := http.Get(dlUrl)
resp, err := getWithRetry(dlUrl)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to download binary")
}
@@ -1266,7 +1301,8 @@ func initUpgradeCmd() *cobra.Command {
{s.Stop, true, "Stop"},
{func() error {
p.router.Cleanup()
p.resetDNS()
// restore static DNS settings or DHCP
p.resetDNS(false, true)
return nil
}, false, "Cleanup"},
{func() error {

View File

@@ -79,33 +79,81 @@ func (s *controlServer) register(pattern string, handler http.Handler) {
func (p *prog) registerControlServerHandler() {
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
mainLog.Load().Debug().Msg("handling list clients request")
clients := p.ciTable.ListClients()
mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
sort.Slice(clients, func(i, j int) bool {
return clients[i].IP.Less(clients[j].IP)
})
mainLog.Load().Debug().Msg("sorted clients by IP address")
if p.metricsQueryStats.Load() {
for _, client := range clients {
mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts")
for idx, client := range clients {
mainLog.Load().Debug().
Int("index", idx).
Str("ip", client.IP.String()).
Str("mac", client.Mac).
Str("hostname", client.Hostname).
Msg("processing client metrics")
client.IncludeQueryCount = true
dm := &dto.Metric{}
if statsClientQueriesCount.MetricVec == nil {
mainLog.Load().Debug().
Str("client_ip", client.IP.String()).
Msg("skipping metrics collection: MetricVec is nil")
continue
}
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
client.IP.String(),
client.Mac,
client.Hostname,
)
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("could not get metrics for client: %v", client)
mainLog.Load().Debug().
Err(err).
Str("client_ip", client.IP.String()).
Str("mac", client.Mac).
Str("hostname", client.Hostname).
Msg("failed to get metrics for client")
continue
}
if err := m.Write(dm); err == nil {
if err := m.Write(dm); err == nil && dm.Counter != nil {
client.QueryCount = int64(dm.Counter.GetValue())
mainLog.Load().Debug().
Str("client_ip", client.IP.String()).
Int64("query_count", client.QueryCount).
Msg("successfully collected query count")
} else if err != nil {
mainLog.Load().Debug().
Err(err).
Str("client_ip", client.IP.String()).
Msg("failed to write metric")
}
}
} else {
mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts")
}
if err := json.NewEncoder(w).Encode(&clients); err != nil {
mainLog.Load().Error().
Err(err).
Int("client_count", len(clients)).
Msg("failed to encode clients response")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
mainLog.Load().Debug().
Int("client_count", len(clients)).
Msg("successfully sent clients list response")
}))
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
select {
@@ -250,7 +298,7 @@ func (p *prog) registerControlServerHandler() {
}
}))
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
if time.Since(p.internalLogSent) < logSentInterval {
if time.Since(p.internalLogSent) < logWriterSentInterval {
w.WriteHeader(http.StatusServiceUnavailable)
return
}

View File

@@ -20,12 +20,12 @@ import (
"golang.org/x/sync/errgroup"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/controld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
"github.com/Control-D-Inc/ctrld/internal/router"
)
const (
@@ -435,14 +435,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
upstreams = []string{upstreamOS}
}
if p.isAdDomainQuery(req.msg) {
ctrld.Log(ctx, mainLog.Load().Debug(),
"AD domain query detected for %s in domain %s",
req.msg.Question[0].Name, p.adDomain)
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
upstreams = []string{upstreamOS}
// For OS resolver, local addresses are ignored to prevent possible looping.
// However, on Active Directory Domain Controller, where it has local DNS server
// running and listening on local addresses, these local addresses must be used
// as nameservers, so queries for ADDC could be resolved as expected.
if p.isAdDomainQuery(req.msg) {
ctrld.Log(ctx, mainLog.Load().Debug(),
"AD domain query detected for %s in domain %s",
req.msg.Question[0].Name, p.adDomain)
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
upstreams = []string{upstreamOSLocal}
}
}
res := &proxyResponse{}
@@ -458,7 +461,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
} else {
switch {
case isSrvLookup(req.msg):
case isSrvLanLookup(req.msg):
upstreams = []string{upstreamOS}
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
ctx = ctrld.LanQueryCtx(ctx)
@@ -620,7 +623,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
// if we have no healthy upstreams, trigger recovery flow
if p.recoverOnUpstreamFailure() {
if p.leakOnUpstreamFailure() {
if p.um.countHealthy(upstreams) == 0 {
p.recoveryCancelMu.Lock()
if p.recoveryCancel == nil {
@@ -639,19 +642,20 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
} else {
mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
}
}
// attempt query to OS resolver while as a retry catch all
if upstreams[0] != upstreamOS {
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
if answer != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
res.answer = answer
res.upstream = osUpstreamConfig.Endpoint
return res
// attempt query to OS resolver while as a retry catch all
// we dont want this to happen if leakOnUpstreamFailure is false
if upstreams[0] != upstreamOS {
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
if answer != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
res.answer = answer
res.upstream = osUpstreamConfig.Endpoint
return res
}
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
}
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
}
answer := new(dns.Msg)
@@ -1108,21 +1112,27 @@ func isLanHostnameQuery(m *dns.Msg) bool {
default:
return false
}
name := strings.TrimSuffix(q.Name, ".")
return isLanHostname(q.Name)
}
// isSrvLanLookup reports whether m is an SRV query whose first question
// names a LAN hostname. A nil message or one without questions yields false.
func isSrvLanLookup(m *dns.Msg) bool {
	if m == nil {
		return false
	}
	if len(m.Question) == 0 {
		return false
	}
	question := m.Question[0]
	if question.Qtype != dns.TypeSRV {
		return false
	}
	return isLanHostname(question.Name)
}
// isLanHostname reports whether name is a LAN hostname.
//
// After dropping one trailing dot, a name counts as a LAN hostname when it
// has no dot at all (single label), or when it ends with ".domain", ".lan"
// or ".local". The comparison is case-sensitive.
func isLanHostname(name string) bool {
	host := strings.TrimSuffix(name, ".")
	// Single-label names are always treated as LAN hostnames.
	if !strings.Contains(host, ".") {
		return true
	}
	for _, suffix := range []string{".domain", ".lan", ".local"} {
		if strings.HasSuffix(host, suffix) {
			return true
		}
	}
	return false
}
// isSrvLookup reports whether the first question of m is an SRV query.
// A nil message or one without questions yields false.
func isSrvLookup(m *dns.Msg) bool {
	if m != nil && len(m.Question) > 0 {
		return m.Question[0].Qtype == dns.TypeSRV
	}
	return false
}
// isWanClient reports whether the input is a WAN address.
func isWanClient(na net.Addr) bool {
var ip netip.Addr
@@ -1177,7 +1187,10 @@ func FlushDNSCache() error {
// monitorNetworkChanges starts monitoring for network interface changes
func (p *prog) monitorNetworkChanges(ctx context.Context) error {
mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
mon, err := netmon.New(func(format string, args ...any) {
// Always fetch the latest logger (and inject the prefix)
mainLog.Load().Printf("netmon: "+format, args...)
})
if err != nil {
return fmt.Errorf("creating network monitor: %w", err)
}
@@ -1248,8 +1261,16 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
}
}
// if the default route changed, set changed to true
if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface {
changed = true
mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface)
}
if !changed {
mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected")
// check if the default IPs are still on an interface that is up
ValidateDefaultLocalIPsFromDelta(delta.New)
return
}
@@ -1260,6 +1281,13 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
// Get IPs from default route interface in new state
selfIP := defaultRouteIP()
// Ensure that selfIP is an IPv4 address.
// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil {
mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP)
selfIP = ""
}
var ipv6 string
if delta.New.DefaultRouteInterface != "" {
@@ -1295,7 +1323,8 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
}
}
if ip := net.ParseIP(selfIP); ip != nil {
// Only set the IPv4 default if selfIP is a valid IPv4 address.
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil {
ctrld.SetDefaultLocalIPv4(ip)
if !isMobile() && p.ciTable != nil {
p.ciTable.SetSelfIP(selfIP)
@@ -1306,7 +1335,8 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
}
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
if p.recoverOnUpstreamFailure() {
// we only trigger recovery flow for network changes on non router devices
if router.Name() == "" {
p.handleRecovery(RecoveryReasonNetworkChange)
}
})
@@ -1438,7 +1468,10 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
// Immediately remove our DNS settings from the interface.
// set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
p.recoveryRunning.Store(true)
p.resetDNS()
// we do not want to restore any static DNS settings
// we must try to get the DHCP values, any static DNS settings
// will be appended to nameservers from the saved interface values
p.resetDNS(false, false)
// For an OS failure, reinitialize OS resolver nameservers immediately.
if reason == RecoveryReasonOSFailure {
@@ -1504,12 +1537,14 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string
go func(name string, uc *ctrld.UpstreamConfig) {
defer wg.Done()
mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
attempts := 0
for {
select {
case <-ctx.Done():
mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
return
default:
attempts++
// checkUpstreamOnce will reset any failure counters on success.
if err := p.checkUpstreamOnce(name, uc); err == nil {
mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
@@ -1523,6 +1558,18 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string
}
mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
time.Sleep(checkUpstreamBackoffSleep)
// if this is the upstreamOS and it's the 3rd attempt (or multiple of 3),
// we should try to reinit the OS resolver to ensure we can recover
if name == upstreamOS && attempts%3 == 0 {
mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
ns := ctrld.InitializeOsResolver(true)
if len(ns) == 0 {
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
} else {
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
}
}
}
}
}(name, uc)
@@ -1556,3 +1603,32 @@ func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.U
}
return upstreams
}
// ValidateDefaultLocalIPsFromDelta verifies that the stored default local IPv4
// and IPv6 addresses are still assigned to some interface in newState (callers
// pass delta.New). A stored default whose textual form is not found among the
// interface addresses is reset by calling its setter with nil.
func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) {
	v4, v6 := ctrld.GetDefaultLocalIPv4(), ctrld.GetDefaultLocalIPv6()

	// Index every address of every interface by its string representation.
	present := make(map[string]bool)
	for _, ifacePrefixes := range newState.InterfaceIPs {
		for i := range ifacePrefixes {
			present[ifacePrefixes[i].Addr().String()] = true
		}
	}

	// Drop the IPv4 default when it vanished from the new state.
	if v4 != nil {
		if stillActive := present[v4.String()]; !stillActive {
			mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", v4)
			ctrld.SetDefaultLocalIPv4(nil)
		}
	}

	// Same treatment for the IPv6 default.
	if v6 != nil {
		if stillActive := present[v6.String()]; !stillActive {
			mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", v6)
			ctrld.SetDefaultLocalIPv6(nil)
		}
	}
}

View File

@@ -418,20 +418,21 @@ func Test_isPrivatePtrLookup(t *testing.T) {
}
}
func Test_isSrvLookup(t *testing.T) {
func Test_isSrvLanLookup(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
isSrvLookup bool
}{
{"SRV", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isSrvLookup(tc.msg); tc.isSrvLookup != got {
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
}
})

View File

@@ -1,5 +1,12 @@
package cli
import (
"fmt"
"net"
"net/http"
"time"
)
// AppCallback provides hooks for injecting certain functionalities
// from mobile platforms to main ctrld cli.
type AppCallback struct {
@@ -17,3 +24,55 @@ type AppConfig struct {
Verbose int
LogPath string
}
const (
// defaultHTTPTimeout is the overall request timeout that doWithRetry
// passes to httpClientWithFallback.
defaultHTTPTimeout = 30 * time.Second
// defaultMaxRetries is the number of attempts getWithRetry asks
// doWithRetry to make.
defaultMaxRetries = 3
)
// httpClientWithFallback returns an HTTP client with the given overall request
// timeout and a dual-stack dialer tuned for near-immediate address family
// fallback.
//
// The dialer's FallbackDelay is the time Go waits for the IPv6 attempt before
// also trying IPv4 (300ms when left at zero). Setting it to 1ms makes the IPv4
// attempt start almost at once, so a broken IPv6 path does not stall requests.
//
// The transport sets IdleConnTimeout explicitly: a zero value means idle
// keep-alive connections are kept forever, which would leak connections since
// callers in this file build a new client (and transport) per request.
func httpClientWithFallback(timeout time.Duration) *http.Client {
	dialer := &net.Dialer{
		Timeout:       10 * time.Second,
		KeepAlive:     30 * time.Second,
		FallbackDelay: 1 * time.Millisecond, // start the IPv4 attempt almost immediately
	}
	return &http.Client{
		Timeout: timeout,
		Transport: &http.Transport{
			DialContext: dialer.DialContext,
			// Same values as http.DefaultTransport; bound the TLS handshake
			// and reap idle connections instead of keeping them indefinitely.
			TLSHandshakeTimeout: 10 * time.Second,
			IdleConnTimeout:     90 * time.Second,
		},
	}
}
// doWithRetry sends req up to maxRetries times and returns the first response
// obtained without a transport-level error. The HTTP status code is not
// inspected, and the caller is responsible for closing the response body.
//
// Between attempts it waits with a linear backoff (2s before the second
// attempt, 3s before the third, and so on). The wait is aborted when the
// request's context is done.
//
// If the request carries a body and provides GetBody, the body is rewound
// before each retry, because the previous attempt may already have consumed
// and closed it. Requests with a body but no GetBody are retried as-is.
//
// When every attempt fails, the returned error wraps the last attempt's error.
func doWithRetry(req *http.Request, maxRetries int) (*http.Response, error) {
	if maxRetries <= 0 {
		// No attempt can be made; report it without formatting a nil error.
		return nil, fmt.Errorf("no attempt made to %s %s: max retries is %d", req.Method, req.URL, maxRetries)
	}

	var lastErr error
	client := httpClientWithFallback(defaultHTTPTimeout)
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			// Linear backoff that honors cancellation of the request context.
			backoff := time.NewTimer(time.Second * time.Duration(attempt+1))
			select {
			case <-req.Context().Done():
				backoff.Stop()
				return nil, fmt.Errorf("giving up on %s %s after %d attempts: %w (last error: %v)",
					req.Method, req.URL, attempt, req.Context().Err(), lastErr)
			case <-backoff.C:
			}

			// Obtain a fresh copy of the body for the retry.
			if req.GetBody != nil {
				body, err := req.GetBody()
				if err != nil {
					return nil, fmt.Errorf("rewinding request body for %s %s: %w", req.Method, req.URL, err)
				}
				req.Body = body
			}
		}

		resp, err := client.Do(req)
		if err == nil {
			return resp, nil
		}
		lastErr = err
		mainLog.Load().Debug().Err(err).
			Str("method", req.Method).
			Str("url", req.URL.String()).
			Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries)
	}
	return nil, fmt.Errorf("failed after %d attempts to %s %s: %w", maxRetries, req.Method, req.URL, lastErr)
}
// getWithRetry builds a GET request for url and sends it through doWithRetry
// with defaultMaxRetries attempts. An error from building the request is
// returned unchanged.
func getWithRetry(url string) (*http.Response, error) {
	request, buildErr := http.NewRequest(http.MethodGet, url, nil)
	if buildErr == nil {
		return doWithRetry(request, defaultMaxRetries)
	}
	return nil, buildErr
}

View File

@@ -16,13 +16,12 @@ import (
)
const (
logWriterSize = 1024 * 1024 * 5 // 5 MB
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
logWriterInitialSize = 32 * 1024 // 32 KB
logSentInterval = time.Minute
logStartEndMarker = "\n\n=== INIT_END ===\n\n"
logLogEndMarker = "\n\n=== LOG_END ===\n\n"
logWarnEndMarker = "\n\n=== WARN_END ===\n\n"
logWriterSize = 1024 * 1024 * 5 // 5 MB
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
logWriterInitialSize = 32 * 1024 // 32 KB
logWriterSentInterval = time.Minute
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
)
type logViewResponse struct {
@@ -69,13 +68,23 @@ func (lw *logWriter) Write(p []byte) (int, error) {
// If writing p causes overflows, discard old data.
if lw.buf.Len()+len(p) > lw.size {
buf := lw.buf.Bytes()
buf = buf[:logWriterInitialSize]
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
buf = buf[:idx]
haveEndMarker := false
// If there's init end marker already, preserve the data til the marker.
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
buf = buf[:idx+len(logWriterInitEndMarker)]
haveEndMarker = true
} else {
// Otherwise, preserve the initial size data.
buf = buf[:logWriterInitialSize]
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
buf = buf[:idx]
}
}
lw.buf.Reset()
lw.buf.Write(buf)
lw.buf.WriteString(logStartEndMarker) // indicate that the log was truncated.
if !haveEndMarker {
lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated.
}
}
// If p is bigger than buffer size, truncate p by half until its size is smaller.
for len(p)+lw.buf.Len() > lw.size {
@@ -84,6 +93,15 @@ func (lw *logWriter) Write(p []byte) (int, error) {
return lw.buf.Write(p)
}
// initLogging sets the global zerolog timestamp format, builds the log writers
// via initLoggingWithBackup(backup), and hands them to initInternalLogging.
func (p *prog) initLogging(backup bool) {
	zerolog.TimeFieldFormat = time.RFC3339 + ".000"

	// Global logging is set up first; internal logging is initialized
	// afterwards with the writers it produced.
	p.initInternalLogging(initLoggingWithBackup(backup))
}
// initInternalLogging performs internal logging if there's no log enabled.
func (p *prog) initInternalLogging(writers []io.Writer) {
if !p.needInternalLogging() {
@@ -92,7 +110,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) {
p.initInternalLogWriterOnce.Do(func() {
mainLog.Load().Notice().Msg("internal logging enabled")
p.internalLogWriter = newLogWriter()
p.internalLogSent = time.Now().Add(-logSentInterval)
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
p.internalWarnLogWriter = newSmallLogWriter()
})
p.mu.Lock()
@@ -158,7 +176,7 @@ func (p *prog) logReader() (*logReader, error) {
wlwReader := bytes.NewReader(wlw.buf.Bytes())
wlwSize := wlw.buf.Len()
wlw.mu.Unlock()
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logLogEndMarker)), wlwReader)
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
lr := &logReader{r: io.NopCloser(reader)}
lr.size = int64(lwSize + wlwSize)
if lr.size == 0 {

View File

@@ -16,7 +16,7 @@ func Test_logWriter_Write(t *testing.T) {
t.Fatalf("unexpected buf content: %v", lw.buf.String())
}
newData := "B"
halfData := strings.Repeat("A", len(data)/2) + logStartEndMarker
halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker
lw.Write([]byte(newData))
if lw.buf.String() != halfData+newData {
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
@@ -47,3 +47,39 @@ func Test_logWriter_ConcurrentWrite(t *testing.T) {
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
}
}
// Test_logWriter_MarkerInitEnd fills a 64 KB logWriter exactly, overflows it
// twice, and checks that the init end marker appears exactly once, directly
// after the initial run of data that preceded the newline.
func Test_logWriter_MarkerInitEnd(t *testing.T) {
	const (
		size        = 64 * 1024
		paddingSize = 10
	)
	lw := &logWriter{size: size}
	lw.buf.Grow(lw.size)

	// Initial data: half of the size, minus len(end marker) and the padding.
	initLen := size/2 - len(logWriterInitEndMarker) - paddingSize
	initData := strings.Repeat("A", initLen)
	// Remainder that makes the log exactly full:
	// len(end marker) + padding - 1 (for the newline separator) + size/2.
	filler := strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2))
	// The newline marks the end of the partial init data.
	data := initData + "\n" + filler

	lw.Write([]byte(data))
	if got := lw.buf.String(); got != data {
		t.Fatalf("unexpected buf content: %v", got)
	}

	// Overflow the buffer twice.
	lw.Write([]byte("B"))
	lw.Write([]byte(strings.Repeat("B", 256*1024)))

	content := lw.buf.String()
	firstIdx := strings.Index(content, logWriterInitEndMarker)
	lastIdx := strings.LastIndex(content, logWriterInitEndMarker)
	switch {
	case firstIdx == -1 || lastIdx == -1:
		// The init end marker must be present.
		t.Fatalf("missing init end marker: %s", content)
	case firstIdx != lastIdx:
		// The init end marker must appear only once.
		t.Fatalf("log init end marker appears more than once: %s", content)
	}

	// The init data must be immediately followed by the marker.
	if want := initData + logWriterInitEndMarker; !strings.Contains(content, want) {
		t.Fatalf("unexpected log content: %s", content)
	}
}

View File

@@ -88,24 +88,21 @@ func initConsoleLogging() {
multi := zerolog.MultiLevelWriter(consoleWriter)
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
mainLog.Store(&l)
switch {
case silent:
zerolog.SetGlobalLevel(zerolog.NoLevel)
case verbose == 1:
ctrld.ProxyLogger.Store(&l)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
case verbose > 1:
ctrld.ProxyLogger.Store(&l)
zerolog.SetGlobalLevel(zerolog.DebugLevel)
default:
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
}
}
// initLogging initializes global logging setup.
// It sets zerolog's time field format to RFC 3339 with a ".000" suffix, then
// returns the writers produced by initLoggingWithBackup(true).
func initLogging() []io.Writer {
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
return initLoggingWithBackup(true)
}
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
// to be used for all interactive commands.
//

5
cmd/cli/nocgo.go Normal file
View File

@@ -0,0 +1,5 @@
//go:build !cgo
package cli
const cgoEnabled = false

View File

@@ -72,34 +72,25 @@ func setDNS(iface *net.Interface, nameservers []string) error {
SearchDomains: []dnsname.FQDN{},
}
trySystemdResolve := false
for i := 0; i < maxSetDNSAttempts; i++ {
if err := r.SetDNS(osConfig); err != nil {
if strings.Contains(err.Error(), "Rejected send message") &&
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
trySystemdResolve = true
break
}
// This error happens on read-only file system, which causes ctrld failed to create backup
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
// The error format is controlled by us, so checking for error string is fine.
// See: ../../internal/dns/direct.go:L278
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
return nil
}
return err
if err := r.SetDNS(osConfig); err != nil {
if strings.Contains(err.Error(), "Rejected send message") &&
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
trySystemdResolve = true
goto systemdResolve
}
if useSystemdResolved {
if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil {
mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out))
}
}
currentNS := currentDNS(iface)
if isSubSet(nameservers, currentNS) {
// This error happens on read-only file system, which causes ctrld failed to create backup
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
// The error format is controlled by us, so checking for error string is fine.
// See: ../../internal/dns/direct.go:L278
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
return nil
}
return err
}
systemdResolve:
if trySystemdResolve {
// Stop systemd-networkd and retry setting DNS.
if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil {
@@ -119,8 +110,8 @@ func setDNS(iface *net.Interface, nameservers []string) error {
}
time.Sleep(time.Second)
}
mainLog.Load().Debug().Msg("DNS was not set for some reason")
}
mainLog.Load().Debug().Msg("DNS was not set for some reason")
return nil
}
@@ -169,6 +160,7 @@ func resetDNS(iface *net.Interface) (err error) {
}
// TODO(cuonglm): handle DHCPv6 properly.
mainLog.Load().Debug().Msg("checking for IPv6 availability")
if ctrldnet.IPv6Available(ctx) {
c := client6.NewClient()
conversation, err := c.Exchange(iface.Name)
@@ -188,6 +180,8 @@ func resetDNS(iface *net.Interface) (err error) {
}
}
}
} else {
mainLog.Load().Debug().Msg("IPv6 is not available")
}
return ignoringEINTR(func() error {

View File

@@ -43,21 +43,42 @@ func setDNS(iface *net.Interface, nameservers []string) error {
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
// Configuring the Dns server to forward queries to ctrld instead.
if hasLocalDnsServerRunning() {
mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders")
file := absHomeDir(windowsForwardersFilename)
oldForwardersContent, _ := os.ReadFile(file)
mainLog.Load().Debug().Msgf("Using forwarders file: %s", file)
oldForwardersContent, err := os.ReadFile(file)
if err != nil {
mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file")
} else {
mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent))
}
hasLocalIPv6Listener := needLocalIPv6Listener()
mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status")
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
if !hasLocalIPv6Listener {
return false
}
return s == "::1"
})
mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list")
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
} else {
mainLog.Load().Debug().Msg("Successfully wrote new forwarders file")
}
oldForwarders := strings.Split(string(oldForwardersContent), ",")
mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders")
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
} else {
mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders")
}
}
})
@@ -147,15 +168,32 @@ func restoreDNS(iface *net.Interface) (err error) {
}
}
for _, ns := range [][]string{v4ns, v6ns} {
if len(ns) == 0 {
continue
}
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
err = setDNS(iface, ns)
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return fmt.Errorf("restoreDNS: %w", err)
}
if err != nil {
return err
if len(v4ns) > 0 {
mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
if err := setDNS(iface, v4ns); err != nil {
return fmt.Errorf("restoreDNS (IPv4): %w", err)
}
} else {
mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name)
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
}
}
if len(v6ns) > 0 {
mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
if err := setDNS(iface, v6ns); err != nil {
return fmt.Errorf("restoreDNS (IPv6): %w", err)
}
} else {
mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name)
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
}
}
}
@@ -180,43 +218,69 @@ func currentDNS(iface *net.Interface) []string {
return ns
}
// currentStaticDNS returns the current static DNS settings of given interface.
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
// like "NameServer" and "ProfileNameServer".
func currentStaticDNS(iface *net.Interface) ([]string, error) {
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err)
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
}
guid, err := luid.GUID()
if err != nil {
return nil, fmt.Errorf("luid.GUID: %w", err)
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
}
var ns []string
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
found := false
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
for _, path := range keyPaths {
interfaceKeyPath := path + guid.String()
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
if err != nil {
return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err)
mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
continue
}
for _, key := range []string{"NameServer", "ProfileNameServer"} {
if found {
continue
}
value, _, err := k.GetStringValue(key)
if err != nil && !errors.Is(err, registry.ErrNotExist) {
return nil, fmt.Errorf("%s: %w", key, err)
}
if len(value) > 0 {
found = true
for _, e := range strings.Split(value, ",") {
ns = append(ns, strings.TrimRight(e, "\x00"))
func() {
defer k.Close()
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
value, _, err := k.GetStringValue(keyName)
if err != nil && !errors.Is(err, registry.ErrNotExist) {
mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName)
continue
}
if len(value) > 0 {
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
parsed := parseDNSServers(value)
for _, pns := range parsed {
if !slices.Contains(ns, pns) {
ns = append(ns, pns)
}
}
}
}
}
}()
}
if len(ns) == 0 {
mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name)
}
return ns, nil
}
// parseDNSServers splits a DNS server list into individual server strings.
//
// Entries may be separated by commas, spaces, tabs, or line breaks. NUL
// characters are treated as separators too, so a NUL-terminated value does not
// leave a trailing "\x00" on the last server. Each entry is additionally
// trimmed of surrounding whitespace, and empty entries are dropped.
//
// It returns nil when val contains no server.
func parseDNSServers(val string) []string {
	fields := strings.FieldsFunc(val, func(r rune) bool {
		switch r {
		case ',', ' ', '\t', '\r', '\n', '\x00':
			return true
		}
		return false
	})
	var servers []string
	for _, f := range fields {
		// Strip any remaining (less common) whitespace around the entry.
		if trimmed := strings.TrimSpace(f); trimmed != "" {
			servers = append(servers, trimmed)
		}
	}
	return servers
}
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
// and also removing old forwarders if provided.
func addDnsServerForwarders(nameservers, old []string) error {

View File

@@ -43,7 +43,7 @@ const (
ctrldControlUnixSockMobile = "cd.sock"
upstreamPrefix = "upstream."
upstreamOS = upstreamPrefix + "os"
upstreamPrivate = upstreamPrefix + "private"
upstreamOSLocal = upstreamOS + ".local"
dnsWatchdogDefaultInterval = 20 * time.Second
ctrldServiceName = "ctrld"
)
@@ -120,6 +120,7 @@ type prog struct {
runningIface string
requiredMultiNICsConfig bool
adDomain string
runningOnDomainController bool
selfUninstallMu sync.Mutex
refusedQueryCount int
@@ -268,7 +269,7 @@ func (p *prog) preRun() {
if runtime.GOOS == "darwin" {
p.onStopped = append(p.onStopped, func() {
if !service.Interactive() {
p.resetDNS()
p.resetDNS(false, true)
}
})
}
@@ -276,7 +277,12 @@ func (p *prog) preRun() {
func (p *prog) postRun() {
if !service.Interactive() {
p.resetDNS()
if runtime.GOOS == "windows" {
isDC, roleInt := isRunningOnDomainController()
p.runningOnDomainController = isDC
mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
}
p.resetDNS(false, false)
ns := ctrld.InitializeOsResolver(false)
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
p.setDNS()
@@ -345,14 +351,19 @@ func (p *prog) apiConfigReload() {
if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated || forced {
lastUpdated = time.Now().Unix()
cfg := &ctrld.Config{}
if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil {
var cfgErr error
if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil {
setListenerDefaultValue(cfg)
setNetworkDefaultValue(cfg)
cfgErr = validateConfig(cfg)
}
if cfgErr != nil {
logger.Warn().Err(err).Msg("skipping invalid custom config")
if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil {
logger.Error().Err(err).Msg("could not mark custom last update failed")
}
return
}
setListenerDefaultValue(cfg)
logger.Debug().Msg("custom config changes detected, reloading...")
p.apiReloadCh <- cfg
} else {
@@ -560,13 +571,12 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
if !reload {
// Stop writing log to unix socket.
consoleWriter.Out = os.Stdout
logWriters := initLoggingWithBackup(false)
p.initLogging(false)
if p.logConn != nil {
_ = p.logConn.Close()
}
go p.apiConfigReload()
p.postRun()
p.initInternalLogging(logWriters)
}
wg.Wait()
}
@@ -647,16 +657,74 @@ func (p *prog) setDNS() {
if cfg.Listener == nil {
return
}
if p.runningIface == "" {
return
}
// allIfaces tracks whether we should set DNS for all physical interfaces.
allIfaces := p.requiredMultiNICsConfig
lc := cfg.FirstListener()
if lc == nil {
return
}
ns := lc.IP
switch {
case lc.IsDirectDnsListener():
// If ctrld is direct listener, use 127.0.0.1 as nameserver.
ns = "127.0.0.1"
case lc.Port != 53:
ns = "127.0.0.1"
if resolver := router.LocalResolverIP(); resolver != "" {
ns = resolver
}
default:
// If we ever reach here, it means ctrld is running on lc.IP port 53,
// so we could just use lc.IP as nameserver.
}
nameservers := []string{ns}
if needRFC1918Listeners(lc) {
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
}
if needLocalIPv6Listener() {
nameservers = append(nameservers, "::1")
}
slices.Sort(nameservers)
netIfaceName := ""
netIface := p.setDnsForRunningIface(nameservers)
if netIface != nil {
netIfaceName = netIface.Name
}
setDnsOK = true
if p.requiredMultiNICsConfig {
withEachPhysicalInterfaces(netIfaceName, "set DNS", func(i *net.Interface) error {
return setDnsIgnoreUnusableInterface(i, nameservers)
})
}
// resolvconf file is only useful when we have default route interface,
// then set DNS on this interface will push change to /etc/resolv.conf file.
if netIface != nil && shouldWatchResolvconf() {
servers := make([]netip.Addr, len(nameservers))
for i := range nameservers {
servers[i] = netip.MustParseAddr(nameservers[i])
}
p.dnsWg.Add(1)
go func() {
defer p.dnsWg.Done()
p.watchResolvConf(netIface, servers, setResolvConf)
}()
}
if p.dnsWatchdogEnabled() {
p.dnsWg.Add(1)
go func() {
defer p.dnsWg.Done()
p.dnsWatchdog(netIface, nameservers)
}()
}
}
func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.Interface) {
if p.runningIface == "" {
return
}
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
const maxDNSRetryAttempts = 3
@@ -690,59 +758,14 @@ func (p *prog) setDNS() {
return
}
runningIface = netIface
logger.Debug().Msg("setting DNS for interface")
ns := lc.IP
switch {
case lc.IsDirectDnsListener():
// If ctrld is direct listener, use 127.0.0.1 as nameserver.
ns = "127.0.0.1"
case lc.Port != 53:
ns = "127.0.0.1"
if resolver := router.LocalResolverIP(); resolver != "" {
ns = resolver
}
default:
// If we ever reach here, it means ctrld is running on lc.IP port 53,
// so we could just use lc.IP as nameserver.
}
nameservers := []string{ns}
if needRFC1918Listeners(lc) {
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
}
if needLocalIPv6Listener() {
nameservers = append(nameservers, "::1")
}
slices.Sort(nameservers)
if err := setDNS(netIface, nameservers); err != nil {
logger.Error().Err(err).Msgf("could not set DNS for interface")
return
}
setDnsOK = true
logger.Debug().Msg("setting DNS successfully")
if allIfaces {
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
return setDnsIgnoreUnusableInterface(i, nameservers)
})
}
if shouldWatchResolvconf() {
servers := make([]netip.Addr, len(nameservers))
for i := range nameservers {
servers[i] = netip.MustParseAddr(nameservers[i])
}
p.dnsWg.Add(1)
go func() {
defer p.dnsWg.Done()
p.watchResolvConf(netIface, servers, setResolvConf)
}()
}
if p.dnsWatchdogEnabled() {
p.dnsWg.Add(1)
go func() {
defer p.dnsWg.Done()
p.dnsWatchdog(netIface, nameservers, allIfaces)
}()
}
return
}
// dnsWatchdogEnabled reports whether DNS watchdog is enabled.
@@ -765,12 +788,12 @@ func (p *prog) dnsWatchdogDuration() time.Duration {
// dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings.
// This only works when a deactivation pin is set.
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces bool) {
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
if !requiredMultiNICsConfig() {
return
}
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
logger.Debug().Msg("start DNS settings watchdog")
mainLog.Load().Debug().Msg("start DNS settings watchdog")
ns := nameservers
slices.Sort(ns)
@@ -788,14 +811,56 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
return
}
if dnsChanged(iface, ns) {
logger.Debug().Msg("DNS settings were changed, re-applying settings")
mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings")
// Check if the interface already has static DNS servers configured.
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
staticDNS, err := currentStaticDNS(iface)
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name)
} else if len(staticDNS) > 0 {
//filter out loopback addresses
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
return net.ParseIP(s).IsLoopback()
})
// if we have a static config and no saved IPs already, save them
if len(staticDNS) > 0 && len(savedStaticNameservers(iface)) == 0 {
// Save these static DNS values so that they can be restored later.
if err := saveCurrentStaticDNS(iface); err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name)
}
}
}
if err := setDNS(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
}
}
if allIfaces {
withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error {
if p.requiredMultiNICsConfig {
ifaceName := ""
if iface != nil {
ifaceName = iface.Name
}
withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error {
if dnsChanged(i, ns) {
// Check if the interface already has static DNS servers configured.
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
staticDNS, err := currentStaticDNS(i)
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name)
} else if len(staticDNS) > 0 {
//filter out loopback addresses
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
return net.ParseIP(s).IsLoopback()
})
// if we have a static config and no saved IPs already, save them
if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 {
// Save these static DNS values so that they can be restored later.
if err := saveCurrentStaticDNS(i); err != nil {
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name)
}
}
}
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
} else {
@@ -809,33 +874,78 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
}
}
func (p *prog) resetDNS() {
// resetDNS resets DNS on the running interface via resetDNSForRunningIface
// and, when p.requiredMultiNICsConfig is set, also on each physical interface
// through withEachPhysicalInterfaces, passing the running interface's name
// (empty if none was returned) as the interface to exclude.
func (p *prog) resetDNS(isStart bool, restoreStatic bool) {
	var excluded string
	running := p.resetDNSForRunningIface(isStart, restoreStatic)
	if running != nil {
		excluded = running.Name
	}

	// See corresponding comments in (*prog).setDNS function.
	if !p.requiredMultiNICsConfig {
		return
	}
	withEachPhysicalInterfaces(excluded, "reset DNS", resetDnsIgnoreUnusableInterface)
}
// resetDNSForRunningIface performs a DNS reset on the running interface.
// The parameter isStart indicates whether this is being called as part of a start (or restart)
// command. When true, we check if the current static DNS configuration already differs from the
// service listener (127.0.0.1). If so, we assume that an admin has manually changed the interface's
// static DNS settings and we do not override them using the potentially out-of-date saved file.
// Otherwise, we restore the saved configuration (if any) or reset to DHCP.
func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) {
if p.runningIface == "" {
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
return
}
// See corresponding comments in (*prog).setDNS function.
allIfaces := p.requiredMultiNICsConfig
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
netIface, err := netInterface(p.runningIface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
return
}
runningIface = netIface
if err := restoreNetworkManager(); err != nil {
logger.Error().Err(err).Msg("could not restore NetworkManager")
return
}
logger.Debug().Msg("Restoring DNS for interface")
if err := resetDNS(netIface); err != nil {
logger.Error().Err(err).Msgf("could not reset DNS")
return
// If starting, check the current static DNS configuration.
if isStart {
current, err := currentStaticDNS(netIface)
if err != nil {
logger.Warn().Err(err).Msg("unable to obtain current static DNS configuration; proceeding to restore saved config")
} else if len(current) > 0 {
// If any static DNS value is not our own listener, assume an admin override.
hasManualConfig := false
for _, ns := range current {
if ns != "127.0.0.1" && ns != "::1" {
hasManualConfig = true
break
}
}
if hasManualConfig {
logger.Debug().Msgf("Detected manual DNS configuration on interface %q: %v; not overriding with saved configuration", netIface.Name, current)
return
}
}
}
logger.Debug().Msg("Restoring DNS successfully")
if allIfaces {
withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDnsIgnoreUnusableInterface)
// Default logic: if there is a saved static DNS configuration, restore it.
saved := savedStaticNameservers(netIface)
if len(saved) > 0 && restoreStatic {
logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved)
if err := setDNS(netIface, saved); err != nil {
logger.Error().Err(err).Msgf("failed to restore static DNS config on interface %q", netIface.Name)
return
}
} else {
logger.Debug().Msgf("No saved static DNS config for interface %q; resetting to DHCP", netIface.Name)
if err := resetDNS(netIface); err != nil {
logger.Error().Err(err).Msgf("failed to reset DNS to DHCP on interface %q", netIface.Name)
return
}
}
return
}
func (p *prog) logInterfacesState() {
@@ -985,12 +1095,6 @@ func findWorkingInterface(currentIface string) string {
return currentIface
}
// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure.
func (p *prog) recoverOnUpstreamFailure() bool {
// Default is false on routers, since this recovery flow is only useful for devices that move between networks.
return router.Name() == ""
}
func randomLocalIP() string {
n := rand.Intn(254-2) + 2
return fmt.Sprintf("127.0.0.%d", n)
@@ -1192,7 +1296,7 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.
// TODO: investigate whether we should report this error?
if err := f(netIface); err == nil {
if context != "" {
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name)
}
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
@@ -1215,19 +1319,38 @@ var errSaveCurrentStaticDNSNotSupported = errors.New("saving current DNS is not
// saveCurrentStaticDNS saves the current static DNS settings for restoring later.
// Only works on Windows and Mac.
func saveCurrentStaticDNS(iface *net.Interface) error {
if iface == nil {
mainLog.Load().Debug().Msg("could not save current static DNS settings for nil interface")
return nil
}
switch runtime.GOOS {
case "windows", "darwin":
default:
return errSaveCurrentStaticDNSNotSupported
}
file := savedStaticDnsSettingsFilePath(iface)
ns, _ := currentStaticDNS(iface)
ns, err := currentStaticDNS(iface)
if err != nil {
mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name)
return err
}
if len(ns) == 0 {
mainLog.Load().Debug().Msgf("no static DNS settings for %q, removing old static DNS settings file", iface.Name)
_ = os.Remove(file) // removing old static DNS settings
return nil
}
//filter out loopback addresses
ns = slices.DeleteFunc(ns, func(s string) bool {
return net.ParseIP(s).IsLoopback()
})
//if we now have no static DNS settings and the file already exists
// return and do not save the file
if len(ns) == 0 {
mainLog.Load().Debug().Msgf("loopback on %q, skipping saving static DNS settings", iface.Name)
return nil
}
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file")
mainLog.Load().Warn().Err(err).Msgf("could not remove old static DNS settings file: %s", file)
}
nss := strings.Join(ns, ",")
mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss)
@@ -1241,6 +1364,9 @@ func saveCurrentStaticDNS(iface *net.Interface) error {
// savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface.
func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
if iface == nil {
return ""
}
return absHomeDir(".dns_" + iface.Name)
}
@@ -1248,6 +1374,10 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
//
//lint:ignore U1000 use in os_windows.go and os_darwin.go
func savedStaticNameservers(iface *net.Interface) []string {
if iface == nil {
mainLog.Load().Debug().Msg("could not get saved static DNS settings for nil interface")
return nil
}
file := savedStaticDnsSettingsFilePath(iface)
if data, _ := os.ReadFile(file); len(data) > 0 {
saveValues := strings.Split(string(data), ",")
@@ -1265,8 +1395,13 @@ func savedStaticNameservers(iface *net.Interface) []string {
}
// dnsChanged reports whether DNS settings for given interface was changed.
// It returns false for a nil iface.
//
// The caller must sort the nameservers before calling this function.
func dnsChanged(iface *net.Interface, nameservers []string) bool {
if iface == nil {
return false
}
curNameservers, _ := currentStaticDNS(iface)
slices.Sort(curNameservers)
if !slices.Equal(curNameservers, nameservers) {
@@ -1286,3 +1421,36 @@ func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
selfUninstall(p, logger)
}
}
// leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow
// when upstream failures occur.
//
// An explicit Service.LeakOnUpstreamFailure value in the config always wins.
// Without one, the answer is true unless ctrld runs on a router or on a
// domain controller.
func (p *prog) leakOnUpstreamFailure() bool {
	if configured := p.cfg.Service.LeakOnUpstreamFailure; configured != nil {
		return *configured
	}
	// Default is false on routers, since this leaking is only useful for
	// devices that move between networks. It is also false when running on
	// ADDC, where we should not leak on upstream failure.
	return router.Name() == "" && !p.runningOnDomainController
}
// Domain controller role values from Win32_ComputerSystem
// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem
const (
	// BackupDomainController is the DomainRole value that marks the machine
	// as a backup domain controller.
	BackupDomainController = 4
	// PrimaryDomainController is the DomainRole value that marks the machine
	// as a primary domain controller.
	PrimaryDomainController = 5
)
// isRunningOnDomainController reports whether the current machine is a domain
// controller, together with the raw DomainRole value that was read.
//
// Only Windows is queried; every other OS yields (false, 0) without doing any
// work.
func isRunningOnDomainController() (bool, int) {
	if runtime.GOOS == "windows" {
		return isRunningOnDomainControllerWindows()
	}
	return false, 0
}

View File

@@ -13,6 +13,7 @@ import (
"tailscale.com/health"
"github.com/Control-D-Inc/ctrld/internal/dns"
"github.com/Control-D-Inc/ctrld/internal/router"
)
func init() {
@@ -39,6 +40,9 @@ func setDependencies(svc *service.Config) {
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
}
}
if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 {
svc.Dependencies = append(svc.Dependencies, routerDeps...)
}
}
func setWorkingDirectory(svc *service.Config, dir string) {

View File

@@ -1,4 +1,4 @@
//go:build !linux && !freebsd && !darwin
//go:build !linux && !freebsd && !darwin && !windows
package cli

14
cmd/cli/prog_windows.go Normal file
View File

@@ -0,0 +1,14 @@
package cli
import "github.com/kardianos/service"
// setDependencies makes the ctrld service depend on the "DNS" service when
// hasLocalDnsServerRunning reports a local DNS server. Any dependencies
// already present on svc are replaced, not extended.
func setDependencies(svc *service.Config) {
	if !hasLocalDnsServerRunning() {
		return
	}
	svc.Dependencies = []string{"DNS"}
}
// setWorkingDirectory records dir as the working directory in the service
// config.
func setWorkingDirectory(svc *service.Config, dir string) {
	// WorkingDirectory is not supported on Windows.
	// NOTE(review): the value is assigned anyway despite the statement above;
	// confirm whether the Windows service manager honors or ignores it.
	svc.WorkingDirectory = dir
}

View File

@@ -6,10 +6,12 @@ import (
"fmt"
"os"
"os/exec"
"runtime"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld/internal/router"
"github.com/Control-D-Inc/ctrld/internal/router/openwrt"
)
// newService wraps service.New call to return service.Service
@@ -167,7 +169,11 @@ func doTasks(tasks []task) bool {
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
return false
}
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
// if this is darwin stop command, dont print debug
// since launchctl complains on every start
if runtime.GOOS != "darwin" || task.Name != "Stop" {
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
}
}
}
return true
@@ -188,6 +194,13 @@ func checkHasElevatedPrivilege() {
func unixSystemVServiceStatus() (service.Status, error) {
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
if err != nil {
// Specific case for openwrt >= 24.10, it returns non-success code
// for above status command, which may not right.
if router.Name() == openwrt.Name {
if string(bytes.ToLower(bytes.TrimSpace(out))) == "inactive" {
return service.StatusStopped, nil
}
}
return service.StatusUnknown, nil
}

View File

@@ -18,3 +18,5 @@ func openLogFile(path string, flags int) (*os.File, error) {
func hasLocalDnsServerRunning() bool { return false }
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
// isRunningOnDomainControllerWindows is a stub that always reports
// (false, 0); presumably compiled only where the Windows implementation is
// not built - confirm against this file's build constraints.
func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 }

View File

@@ -2,12 +2,18 @@ package cli
import (
"os"
"reflect"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"unsafe"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc/mgr"
)
@@ -165,3 +171,57 @@ func hasLocalDnsServerRunning() bool {
}
}
}
// isRunningOnDomainControllerWindows queries the DomainRole property of the
// first Win32_ComputerSystem instance via WMI.
//
// It returns whether that role is BackupDomainController or
// PrimaryDomainController, plus the role value itself. Any failure along the
// way (query error, no rows, missing or unparsable property) is logged at
// debug level and reported as (false, 0).
func isRunningOnDomainControllerWindows() (bool, int) {
	whost := host.NewWmiLocalHost()
	q := query.NewWmiQuery("Win32_ComputerSystem")
	instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
	if err != nil {
		mainLog.Load().Debug().Err(err).Msg("WMI query failed")
		return false, 0
	}
	if instances == nil {
		mainLog.Load().Debug().Msg("WMI query returned nil instances")
		return false, 0
	}
	defer instances.Close()

	if len(instances) == 0 {
		mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem")
		return false, 0
	}

	val, err := instances[0].GetProperty("DomainRole")
	if err != nil {
		mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property")
		return false, 0
	}
	if val == nil {
		mainLog.Load().Debug().Msg("DomainRole property is nil")
		return false, 0
	}

	// Safely handle varied types: string or integer. The platform-sized int
	// and uint are accepted too, so a value delivered in those types is not
	// rejected as an unexpected type.
	var roleInt int
	switch v := val.(type) {
	case string:
		// "4", "5", etc.
		parsed, parseErr := strconv.Atoi(v)
		if parseErr != nil {
			mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v)
			return false, 0
		}
		roleInt = parsed
	case int, int8, int16, int32, int64:
		roleInt = int(reflect.ValueOf(v).Int())
	case uint, uint8, uint16, uint32, uint64:
		roleInt = int(reflect.ValueOf(v).Uint())
	default:
		mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v)
		return false, 0
	}

	// Check if role indicates a domain controller
	isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController
	return isDC, roleInt
}

View File

@@ -1,7 +1,13 @@
package main
import "github.com/Control-D-Inc/ctrld/cmd/cli"
import (
"os"
"github.com/Control-D-Inc/ctrld/cmd/cli"
)
func main() {
cli.Main()
// make sure we exit with 0 if there are no errors
os.Exit(0)
}

View File

@@ -529,7 +529,7 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
conn, err := pd.DialContext(ctx, network, dialAddrs)
conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load())
if err != nil {
return nil, err
}

View File

@@ -166,7 +166,6 @@ before serving the query.
### max_concurrent_requests
The number of concurrent requests that will be handled, must be a non-negative integer.
Tweaking this value depends on the capacity of your system.
- Type: number
- Required: no
@@ -253,9 +252,7 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus
- Default: ""
### dns_watchdog_enabled
Checking DNS changes to network interfaces and reverting to ctrld's own settings.
The DNS watchdog process only runs on Windows and MacOS, while in `--cd` mode.
Watches all physical interfaces for DNS changes and reverts them to ctrld's settings. The DNS watchdog process only runs on Windows and MacOS.
- Type: boolean
- Required: no
@@ -274,7 +271,7 @@ If the time duration is non-positive, default value will be used.
- Default: 20s
### refetch_time
Time in seconds between each iteration that reloads custom config if changed.
Time in seconds between each iteration that reloads custom config from the API.
The value must be a positive number, any invalid value will be ignored and default value will be used.
- Type: number
@@ -282,7 +279,7 @@ The value must be a positive number, any invalid value will be ignored and defau
- Default: 3600
### leak_on_upstream_failure
Once ctrld is "offline", mean ctrld could not connect to any upstream, next queries will be leaked to OS resolver.
If a remote upstream fails to resolve a query or is unreachable, `ctrld` will forward the queries to the default DNS resolver on the network. If failures persist, `ctrld` will remove itself from all networking interfaces until connectivity is restored.
- Type: boolean
- Required: no
@@ -531,6 +528,15 @@ rules = [
]
```
If there are no explicitly defined rules, LAN queries will be handled solely by the OS resolver.
The following queries are considered LAN queries:
- Queries that do not have a dot `.` in the domain name, like `machine1`, `example`, ... (1)
- Queries whose domain ends with: `.domain`, `.lan`, `.local`. (2)
- All `SRV` queries for LAN hostnames (1) + (2).
- `PTR` queries with private IPs.
---
Note that the order of matching preference:

View File

@@ -207,11 +207,10 @@ func (t *Table) init() {
}
for platform, discover := range discovers {
if err := discover.refresh(); err != nil {
ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", platform)
} else {
t.hostnameResolvers = append(t.hostnameResolvers, discover)
t.refreshers = append(t.refreshers, discover)
ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform)
}
t.hostnameResolvers = append(t.hostnameResolvers, discover)
t.refreshers = append(t.refreshers, discover)
}
}
// Hosts file mapping.
@@ -423,17 +422,27 @@ func (t *Table) ListClients() []*Client {
t.Refresh()
ipMap := make(map[string]*Client)
il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni}
for _, ir := range il {
if ir == nil {
continue
}
for _, ip := range ir.List() {
c, ok := ipMap[ip]
if !ok {
c = &Client{
IP: netip.MustParseAddr(ip),
Source: map[string]struct{}{ir.String(): {}},
// Validate IP before using MustParseAddr
if addr, err := netip.ParseAddr(ip); err == nil {
c, ok := ipMap[ip]
if !ok {
c = &Client{
IP: addr,
Source: map[string]struct{}{},
}
ipMap[ip] = c
}
// Safely get source name
if src := ir.String(); src != "" {
c.Source[src] = struct{}{}
}
ipMap[ip] = c
} else {
c.Source[ir.String()] = struct{}{}
}
}
}

View File

@@ -92,6 +92,11 @@ func (m *mdns) init(quitCh chan struct{}) error {
return err
}
// Check if IPv6 is available once and use the result for the rest of the function.
ctrld.ProxyLogger.Load().Debug().Msgf("checking for IPv6 availability in mdns init")
ipv6 := ctrldnet.IPv6Available(context.Background())
ctrld.ProxyLogger.Load().Debug().Msgf("IPv6 is %v in mdns init", ipv6)
v4ConnList := make([]*net.UDPConn, 0, len(ifaces))
v6ConnList := make([]*net.UDPConn, 0, len(ifaces))
for _, iface := range ifaces {
@@ -102,7 +107,8 @@ func (m *mdns) init(quitCh chan struct{}) error {
v4ConnList = append(v4ConnList, conn)
go m.readLoop(conn)
}
if ctrldnet.IPv6Available(context.Background()) {
if ipv6 {
if conn, err := net.ListenMulticastUDP("udp6", &iface, mdnsV6Addr); err == nil {
v6ConnList = append(v6ConnList, conn)
go m.readLoop(conn)

View File

@@ -67,4 +67,16 @@ var services = [...]string{
// Merlin
"_alexa._tcp",
// Newer Android TV devices
"_androidtvremote2._tcp.local.",
// https://esphome.io/
"_esphomelib._tcp.local.",
// https://www.home-assistant.io/
"_home-assistant._tcp.local.",
// https://kno.wled.ge/
"_wled._tcp.local.",
}

View File

@@ -3,6 +3,7 @@ package clientinfo
import (
"bytes"
"encoding/json"
"fmt"
"io"
"os/exec"
"strings"
@@ -44,9 +45,9 @@ func (u *ubiosDiscover) refreshDevices() error {
cmd := exec.Command("/usr/bin/mongo", "localhost:27117/ace", "--quiet", "--eval", `
DBQuery.shellBatchSize = 256;
db.user.find({name: {$exists: true, $ne: ""}}, {_id:0, mac:1, name:1});`)
b, err := cmd.Output()
b, err := cmd.CombinedOutput()
if err != nil {
return err
return fmt.Errorf("out: %s, err: %w", string(b), err)
}
return u.storeDevices(bytes.NewReader(b))
}

View File

@@ -32,6 +32,8 @@ const (
logURLCom = apiURLCom + "/logs"
logURLDev = apiURLDev + "/logs"
InvalidConfigCode = 40402
defaultTimeout = 20 * time.Second
sendLogTimeout = 300 * time.Second
)
// ResolverConfig represents Control D resolver data.
@@ -135,7 +137,7 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
req.Header.Add("Content-Type", "application/json")
transport := apiTransport(cdDev)
client := http.Client{
Timeout: 10 * time.Second,
Timeout: defaultTimeout,
Transport: transport,
}
resp, err := client.Do(req)
@@ -176,7 +178,7 @@ func SendLogs(lr *LogsRequest, cdDev bool) error {
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
transport := apiTransport(cdDev)
client := http.Client{
Timeout: 300 * time.Second,
Timeout: sendLogTimeout,
Transport: transport,
}
resp, err := client.Do(req)
@@ -214,19 +216,55 @@ func apiTransport(cdDev bool) *http.Transport {
if cdDev {
apiDomain = apiDomainDev
}
// First try IPv4
dialer := &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
ips := ctrld.LookupIP(apiDomain)
if len(ips) == 0 {
ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr)
return ctrldnet.Dialer.DialContext(ctx, network, addr)
ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, falling back to direct connection to %s", apiDomain, addr)
return dialer.DialContext(ctx, network, addr)
}
ctrld.ProxyLogger.Load().Debug().Msgf("API IPs: %v", ips)
// Separate IPv4 and IPv6 addresses
var ipv4s, ipv6s []string
for _, ip := range ips {
if strings.Contains(ip, ":") {
ipv6s = append(ipv6s, ip)
} else {
ipv4s = append(ipv4s, ip)
}
}
_, port, _ := net.SplitHostPort(addr)
addrs := make([]string, len(ips))
for i := range ips {
addrs[i] = net.JoinHostPort(ips[i], port)
// Try IPv4 first
if len(ipv4s) > 0 {
addrs := make([]string, len(ipv4s))
for i, ip := range ipv4s {
addrs[i] = net.JoinHostPort(ip, port)
}
d := &ctrldnet.ParallelDialer{}
if conn, err := d.DialContext(ctx, "tcp4", addrs, ctrld.ProxyLogger.Load()); err == nil {
return conn, nil
}
}
d := &ctrldnet.ParallelDialer{}
return d.DialContext(ctx, network, addrs)
// Fall back to IPv6 if available
if len(ipv6s) > 0 {
addrs := make([]string, len(ipv6s))
for i, ip := range ipv6s {
addrs[i] = net.JoinHostPort(ip, port)
}
d := &ctrldnet.ParallelDialer{}
return d.DialContext(ctx, "tcp6", addrs, ctrld.ProxyLogger.Load())
}
// Final fallback to direct connection
return dialer.DialContext(ctx, network, addr)
}
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}

View File

@@ -3,6 +3,7 @@ package net
import (
"context"
"errors"
"io"
"net"
"os"
"os/signal"
@@ -11,6 +12,7 @@ import (
"syscall"
"time"
"github.com/rs/zerolog"
"tailscale.com/logtail/backoff"
)
@@ -26,7 +28,8 @@ var Dialer = &net.Dialer{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := ParallelDialer{}
d.Timeout = 10 * time.Second
return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS})
l := zerolog.New(io.Discard)
return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS}, &l)
},
},
}
@@ -49,8 +52,12 @@ func init() {
}
func supportIPv6(ctx context.Context) bool {
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "443"))
return err == nil
c, err := probeStackDialer.DialContext(ctx, "tcp6", v6BootstrapDNS)
if err != nil {
return false
}
c.Close()
return true
}
func supportListenIPv6Local() bool {
@@ -133,7 +140,7 @@ type ParallelDialer struct {
net.Dialer
}
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) {
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string, logger *zerolog.Logger) (net.Conn, error) {
if len(addrs) == 0 {
return nil, errors.New("empty addresses")
}
@@ -153,11 +160,16 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs
for _, addr := range addrs {
go func(addr string) {
defer wg.Done()
logger.Debug().Msgf("dialing to %s", addr)
conn, err := d.Dialer.DialContext(ctx, network, addr)
if err != nil {
logger.Debug().Msgf("failed to dial %s: %v", addr, err)
}
select {
case ch <- &parallelDialerResult{conn: conn, err: err}:
case <-done:
if conn != nil {
logger.Debug().Msgf("connection closed: %s", conn.RemoteAddr())
conn.Close()
}
}
@@ -168,6 +180,7 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs
for res := range ch {
if res.err == nil {
cancel()
logger.Debug().Msgf("connected to %s", res.conn.RemoteAddr())
return res.conn, res.err
}
errs = append(errs, res.err)

View File

@@ -2,10 +2,13 @@ package openwrt
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/kardianos/service"
@@ -15,10 +18,13 @@ import (
)
const (
Name = "openwrt"
openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
Name = "openwrt"
openwrtDNSMasqConfigName = "ctrld.conf"
openwrtDNSMasqDefaultConfigDir = "/tmp/dnsmasq.d"
)
var openwrtDnsmasqDefaultConfigPath = filepath.Join(openwrtDNSMasqDefaultConfigDir, openwrtDNSMasqConfigName)
type Openwrt struct {
cfg *ctrld.Config
dnsmasqCacheSize string
@@ -67,7 +73,7 @@ func (o *Openwrt) Setup() error {
if err != nil {
return err
}
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil {
if err := os.WriteFile(dnsmasqConfPathFromUbus(), []byte(data), 0600); err != nil {
return err
}
// Restart dnsmasq service.
@@ -82,7 +88,7 @@ func (o *Openwrt) Cleanup() error {
return nil
}
// Remove the custom dnsmasq config
if err := os.Remove(openwrtDNSMasqConfigPath); err != nil {
if err := os.Remove(dnsmasqConfPathFromUbus()); err != nil {
return err
}
@@ -126,3 +132,60 @@ func uci(args ...string) (string, error) {
}
return strings.TrimSpace(stdout.String()), nil
}
// openwrtServiceList represents openwrt services config.
type openwrtServiceList struct {
Dnsmasq dnsmasqConf `json:"dnsmasq"`
}
// dnsmasqConf represents dnsmasq config.
type dnsmasqConf struct {
Instances map[string]confInstances `json:"instances"`
}
// confInstances represents an instance config of a service.
type confInstances struct {
Mount map[string]string `json:"mount"`
}
// dnsmasqConfPath returns the dnsmasq config path.
//
// Since version 24.10, openwrt makes some changes to dnsmasq to support
// multiple instances of dnsmasq. This change causes breaking changes to
// software which depends on the default dnsmasq path.
//
// There are some discussion/PRs in openwrt repo to address this:
//
//   - https://github.com/openwrt/openwrt/pull/16806
//   - https://github.com/openwrt/openwrt/pull/16890
//
// In the meantime, workaround this problem by querying the actual config path
// by querying ubus service list.
//
// r must yield the JSON of the ubus service list. A mount point qualifies
// when its base name has the form "dnsmasq[.<instance>].d". If several mount
// points qualify, the lexicographically smallest one is used, so repeated
// calls over the same data always resolve the same file. The default path is
// returned when the input cannot be decoded or nothing qualifies.
func dnsmasqConfPath(r io.Reader) string {
	var svc openwrtServiceList
	if err := json.NewDecoder(r).Decode(&svc); err != nil {
		return openwrtDnsmasqDefaultConfigPath
	}

	// Go map iteration order is random, so collect the best candidate instead
	// of returning the first hit; otherwise two calls could disagree.
	confDir := ""
	for _, inst := range svc.Dnsmasq.Instances {
		for mount := range inst.Mount {
			parts := strings.Split(filepath.Base(mount), ".")
			if len(parts) < 2 || parts[0] != "dnsmasq" || parts[len(parts)-1] != "d" {
				continue
			}
			if confDir == "" || mount < confDir {
				confDir = mount
			}
		}
	}
	if confDir == "" {
		return openwrtDnsmasqDefaultConfigPath
	}
	return filepath.Join(confDir, openwrtDNSMasqConfigName)
}
// dnsmasqConfPathFromUbus get dnsmasq config path from ubus service list.
func dnsmasqConfPathFromUbus() string {
output, err := exec.Command("ubus", "call", "service", "list").Output()
if err != nil {
return openwrtDnsmasqDefaultConfigPath
}
return dnsmasqConfPath(bytes.NewReader(output))
}

View File

@@ -0,0 +1,58 @@
package openwrt
import (
"io"
"path/filepath"
"strings"
"testing"
)
// Sample output from https://github.com/openwrt/openwrt/pull/16806#issuecomment-2448255734
const ubusDnsmasqBefore2410 = `{
"dnsmasq": {
"instances": {
"guest_dns": {
"mount": {
"/tmp/dnsmasq.d": "0",
"/var/run/dnsmasq/": "1"
}
}
}
}
}`
const ubusDnsmasq2410 = `{
"dnsmasq": {
"instances": {
"guest_dns": {
"mount": {
"/tmp/dnsmasq.guest_dns.d": "0",
"/var/run/dnsmasq/": "1"
}
}
}
}
}`
// Test_dnsmasqConfPath checks that dnsmasqConfPath falls back to the default
// path for empty, invalid and pre-24.10 input, and picks the per-instance
// directory for 24.10 style input.
func Test_dnsmasqConfPath(t *testing.T) {
	want2410 := filepath.Join("/tmp/dnsmasq.guest_dns.d", openwrtDNSMasqConfigName)
	cases := []struct {
		name     string
		in       io.Reader
		expected string
	}{
		{name: "empty", in: strings.NewReader(""), expected: openwrtDnsmasqDefaultConfigPath},
		{name: "invalid", in: strings.NewReader("}}"), expected: openwrtDnsmasqDefaultConfigPath},
		{name: "before 24.10", in: strings.NewReader(ubusDnsmasqBefore2410), expected: openwrtDnsmasqDefaultConfigPath},
		{name: "24.10", in: strings.NewReader(ubusDnsmasq2410), expected: want2410},
	}
	for i := range cases {
		// Take a per-iteration copy so each parallel subtest sees its own case.
		tt := cases[i]
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := dnsmasqConfPath(tt.in)
			if got != tt.expected {
				t.Errorf("dnsmasqConfPath() = %v, want %v", got, tt.expected)
			}
		})
	}
}

View File

@@ -215,6 +215,20 @@ func LeaseFilesDir() string {
return ""
}
// ServiceDependencies returns the list of dependencies that the ctrld service
// needs on this router, or nil when there are none.
// See https://pkg.go.dev/github.com/kardianos/service#Config for list format.
func ServiceDependencies() []string {
	switch Name() {
	case ubios.Name:
		// On Ubios, ctrld needs to start after unifi-mongodb,
		// so it can query custom client info mapping.
		deps := []string{
			"Wants=unifi-mongodb.service",
			"After=unifi-mongodb.service",
		}
		return deps
	default:
		return nil
	}
}
func distroName() string {
switch {
case bytes.HasPrefix(unameO(), []byte("DD-WRT")):

View File

@@ -45,11 +45,15 @@ func (s *tomatoSvc) Platform() string {
}
func (s *tomatoSvc) configPath() string {
path, err := os.Executable()
if err != nil {
return ""
bin := s.Config.Executable
if bin == "" {
path, err := os.Executable()
if err != nil {
return ""
}
bin = path
}
return path + ".startup"
return bin + ".startup"
}
func (s *tomatoSvc) template() *template.Template {

View File

@@ -219,6 +219,8 @@ const ubiosBootSystemdService = `[Unit]
Description=Run ctrld On Startup UDM
Wants=network-online.target
After=network-online.target
Wants=unifi-mongodb
After=unifi-mongodb
StartLimitIntervalSec=500
StartLimitBurst=5

5
log.go
View File

@@ -9,11 +9,6 @@ import (
"github.com/rs/zerolog"
)
func init() {
l := zerolog.New(io.Discard)
ProxyLogger.Store(&l)
}
// ProxyLog emits the log record for proxy operations.
// The caller should set it only once.
// DEPRECATED: use ProxyLogger instead.

View File

@@ -7,6 +7,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"net"
"os/exec"
"regexp"
@@ -155,6 +156,8 @@ func getDHCPNameservers(iface string) ([]string, error) {
}
func getAllDHCPNameservers() []string {
logger := *ProxyLogger.Load()
interfaces, err := net.Interfaces()
if err != nil {
return nil
@@ -213,5 +216,67 @@ func getAllDHCPNameservers() []string {
}
}
// if we have static DNS servers saved for the current default route, we should add them to the list
drIfaceName, err := netmon.DefaultRouteInterface()
Log(context.Background(), logger.Debug(), "checking for static DNS servers for default route interface: %s", drIfaceName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get default route interface: %v", err)
} else {
drIface, err := net.InterfaceByName(drIfaceName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get interface by name %s: %v", drIfaceName, err)
} else if drIface != nil {
if _, err := patchNetIfaceName(drIface); err != nil {
Log(context.Background(), logger.Debug(),
"Failed to patch interface name %s: %v", drIfaceName, err)
}
staticNs, file := SavedStaticNameservers(drIface)
Log(context.Background(), logger.Debug(),
"static dns servers from %s: %v", file, staticNs)
if len(staticNs) > 0 {
Log(context.Background(), logger.Debug(),
"Adding static DNS servers from %s: %v", drIface.Name, staticNs)
allNameservers = append(allNameservers, staticNs...)
}
}
}
return allNameservers
}
// patchNetIfaceName replaces iface.Name with the network service name that
// "networksetup -listnetworkserviceorder" associates with that device.
//
// It reports whether the name was replaced. An error is returned only when
// the networksetup command itself fails; in that case iface is left untouched.
func patchNetIfaceName(iface *net.Interface) (bool, error) {
	out, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
	if err != nil {
		return false, err
	}

	svcName := networkServiceName(iface.Name, bytes.NewReader(out))
	if svcName == "" {
		// No service is bound to this device; keep the original name.
		return false, nil
	}
	iface.Name = svcName
	return true, nil
}
// networkServiceName scans r, expected to hold the output of
// "networksetup -listnetworkserviceorder", and returns the name of the
// network service bound to the device ifaceName.
//
// The service name is taken from the line preceding the matching
// "Device: <ifaceName>" line, with the leading "(N)" order marker removed.
// Lines containing "*" are ignored. An empty string is returned when no
// service is found.
//
// The device name must match exactly: "en1" does not match a line for "en10".
func networkServiceName(ifaceName string, r io.Reader) string {
	needle := "Device: " + ifaceName
	// isDeviceLine reports whether line names exactly ifaceName, i.e. the
	// device name is followed by a delimiter and not by more name characters.
	isDeviceLine := func(line string) bool {
		_, rest, found := strings.Cut(line, needle)
		if !found {
			return false
		}
		return rest == "" || strings.ContainsRune("), \t", rune(rest[0]))
	}

	scanner := bufio.NewScanner(r)
	prevLine := ""
	for scanner.Scan() {
		line := scanner.Text()
		if strings.Contains(line, "*") {
			// Network services is disabled.
			continue
		}
		if !isDeviceLine(line) {
			prevLine = line
			continue
		}
		// prevLine looks like "(1) Wi-Fi"; drop the order marker.
		parts := strings.SplitN(prevLine, " ", 2)
		if len(parts) == 2 {
			return strings.TrimSpace(parts[1])
		}
	}
	return ""
}

View File

@@ -17,9 +17,9 @@ import (
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
"github.com/rs/zerolog"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/net/netmon"
)
const (
@@ -62,10 +62,7 @@ func dnsFromAdapter() []string {
var ns []string
var err error
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
logger := *ProxyLogger.Load()
for i := 0; i < maxDNSAdapterRetries; i++ {
if ctx.Err() != nil {
@@ -111,10 +108,8 @@ func dnsFromAdapter() []string {
}
func getDNSServers(ctx context.Context) ([]string, error) {
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
logger := *ProxyLogger.Load()
// Check context before making the call
if ctx.Err() != nil {
return nil, ctx.Err()
@@ -303,6 +298,28 @@ func getDNSServers(ctx context.Context) ([]string, error) {
}
}
// if we have static DNS servers saved for the current default route, we should add them to the list
drIfaceName, err := netmon.DefaultRouteInterface()
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get default route interface: %v", err)
} else {
drIface, err := net.InterfaceByName(drIfaceName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get interface by name %s: %v", drIfaceName, err)
} else {
staticNs, file := SavedStaticNameservers(drIface)
Log(context.Background(), logger.Debug(),
"static dns servers from %s: %v", file, staticNs)
if len(staticNs) > 0 {
Log(context.Background(), logger.Debug(),
"Adding static DNS servers from %s: %v", drIfaceName, staticNs)
ns = append(ns, staticNs...)
}
}
}
if len(ns) == 0 {
return nil, fmt.Errorf("no valid DNS servers found")
}
@@ -320,10 +337,8 @@ func nameserversFromResolvconf() []string {
// checkDomainJoined checks if the machine is joined to an Active Directory domain
// Returns whether it's domain joined and the domain name if available
func checkDomainJoined() bool {
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
logger := *ProxyLogger.Load()
var domain *uint16
var status uint32
@@ -400,10 +415,7 @@ func validInterfaces() map[string]struct{} {
defer log.SetOutput(os.Stderr)
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
logger := *ProxyLogger.Load()
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("MSFT_NetAdapter")

2
net.go
View File

@@ -18,6 +18,7 @@ const ipv6ProbingInterval = 10 * time.Second
func hasIPv6() bool {
hasIPv6Once.Do(func() {
Log(context.Background(), ProxyLogger.Load().Debug(), "checking for IPv6 availability once")
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
val := ctrldnet.IPv6Available(ctx)
@@ -43,6 +44,7 @@ func probingIPv6(ctx context.Context, old bool) {
if ipv6Available.CompareAndSwap(old, cur) {
old = cur
}
Log(ctx, ProxyLogger.Load().Debug(), "IPv6 availability: %v", cur)
}()
}
}

View File

@@ -48,7 +48,15 @@ const (
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
var localResolver = newLocalResolver()
var localResolver Resolver
// init seeds ProxyLogger with a logger that discards its output and then
// assigns the package-level localResolver.
func init() {
	// Store a discarding logger up front so that code loading ProxyLogger
	// elsewhere does not have to nil-check the result.
	discard := zerolog.New(io.Discard)
	ProxyLogger.Store(&discard)

	// Assigned after the logger is stored, keeping the original order.
	localResolver = newLocalResolver()
}
var (
resolverMutex sync.Mutex
@@ -91,10 +99,8 @@ func availableNameservers() []string {
machineIPsMap := make(map[string]struct{}, len(regularIPs))
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
logger := *ProxyLogger.Load()
Log(context.Background(), logger.Debug(),
"Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs)
@@ -193,9 +199,12 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
case ResolverTypeDOQ:
return &doqResolver{uc: uc}, nil
case ResolverTypeOS:
resolverMutex.Lock()
if or == nil {
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver")
or = newResolverWithNameserver(defaultNameservers())
}
resolverMutex.Unlock()
return or, nil
case ResolverTypeLegacy:
return &legacyResolver{uc: uc}, nil
@@ -277,14 +286,29 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
nss = append(nss, (*p)...)
}
numServers := len(nss) + len(publicServers)
// If this is a LAN query, skip public DNS.
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
// remove controldPublicDnsWithPort from publicServers for LAN queries
// this is to prevent DoS for high frequency local requests
if ok && lan {
numServers -= len(publicServers)
if index := slices.Index(publicServers, controldPublicDnsWithPort); index != -1 {
publicServers = slices.Delete(publicServers, index, index+1)
numServers--
}
}
question := ""
if msg != nil && len(msg.Question) > 0 {
question = msg.Question[0].Name
}
Log(ctx, ProxyLogger.Load().Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers)
// New check: If no resolvers are available, return an error.
if numServers == 0 {
return nil, errors.New("no nameservers available")
return nil, errors.New("no nameservers available for query")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -320,10 +344,6 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
}(server)
}
}
do(nss, true)
if !lan {
do(publicServers, false)
}
logAnswer := func(server string) {
host, _, err := net.SplitHostPort(server)
@@ -333,6 +353,18 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
}
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host)
}
// try local nameservers
if len(nss) > 0 {
do(nss, true)
}
// we must always try the public servers too, since DHCP may have only public servers
// this is okay to do since we always prefer LAN nameserver responses
if len(publicServers) > 0 {
do(publicServers, false)
}
var (
nonSuccessAnswer *dns.Msg
nonSuccessServer string
@@ -353,33 +385,49 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
case res.server == controldPublicDnsWithPort:
controldSuccessAnswer = res.answer
case !res.lan:
// if there are no LAN nameservers, we should not wait
// just use the first response
if len(nss) == 0 {
Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server)
cancel()
logAnswer(res.server)
return res.answer, nil
}
publicResponses = append(publicResponses, publicResponse{
answer: res.answer,
server: res.server,
})
}
case res.answer != nil:
nonSuccessAnswer = res.answer
nonSuccessServer = res.server
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d",
res.server, res.answer.Rcode)
// When there are no LAN nameservers, we should not wait
// for other nameservers to respond.
if len(nss) == 0 {
Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer")
cancel()
logAnswer(res.server)
return res.answer, nil
}
nonSuccessAnswer = res.answer
nonSuccessServer = res.server
}
errs = append(errs, res.err)
}
if len(publicResponses) > 0 {
resp := publicResponses[0]
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server)
Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server)
logAnswer(resp.server)
return resp.answer, nil
}
if controldSuccessAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort)
Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort)
logAnswer(controldPublicDnsWithPort)
return controldSuccessAnswer, nil
}
if nonSuccessAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer)
Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer)
logAnswer(nonSuccessServer)
return nonSuccessAnswer, nil
}
@@ -428,9 +476,13 @@ func LookupIP(domain string) []string {
}
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
resolverMutex.Lock()
if or == nil {
ProxyLogger.Load().Debug().Msgf("Initialize OS resolver in lookupIP")
or = newResolverWithNameserver(defaultNameservers())
}
resolverMutex.Unlock()
nss := *or.lanServers.Load()
nss = append(nss, *or.publicServers.Load()...)
if withBootstrapDNS {
@@ -510,6 +562,9 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
// - Gateway IP address (depends on OS).
// - Input servers.
func NewBootstrapResolver(servers ...string) Resolver {
logger := *ProxyLogger.Load()
Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers)
nss := defaultNameservers()
nss = append([]string{controldPublicDnsWithPort}, nss...)
for _, ns := range servers {
@@ -526,6 +581,11 @@ func NewBootstrapResolver(servers ...string) Resolver {
//
// This is useful for doing PTR lookup in LAN network.
func NewPrivateResolver() Resolver {
logger := *ProxyLogger.Load()
Log(context.Background(), logger.Debug(), "NewPrivateResolver called")
nss := defaultNameservers()
resolveConfNss := nameserversFromResolvconf()
localRfc1918Addrs := Rfc1918Addresses()
@@ -570,6 +630,9 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
// newResolverWithNameserver returns an OS resolver from given nameservers list.
// The caller must ensure each server in list is formed "ip:53".
func newResolverWithNameserver(nameservers []string) *osResolver {
logger := *ProxyLogger.Load()
Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers)
r := &osResolver{}
var publicNss []string
var lanNss []string

View File

@@ -70,41 +70,59 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) {
}
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
ns := make([]string, 0, 2)
servers := make([]*dns.Server, 0, 2)
handlers := []dns.Handler{
// Set up a LAN nameserver that returns a success response.
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") // 127.0.0.1 is considered LAN (loopback)
if err != nil {
t.Fatalf("failed to listen on LAN address: %v", err)
}
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, successHandler())
if err != nil {
t.Fatalf("failed to run LAN test server: %v", err)
}
defer lanServer.Shutdown()
// Set up two public nameservers that return non-success responses.
publicHandlers := []dns.Handler{
nonSuccessHandlerWithRcode(dns.RcodeRefused),
nonSuccessHandlerWithRcode(dns.RcodeNameError),
successHandler(),
}
for i := range handlers {
var publicNS []string
var publicServers []*dns.Server
for _, handler := range publicHandlers {
pc, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatalf("unexpected error: %v", err)
t.Fatalf("failed to listen on public address: %v", err)
}
s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i])
s, addr, err := runLocalPacketConnTestServer(t, pc, handler)
if err != nil {
t.Fatalf("unexpected error: %v", err)
t.Fatalf("failed to run public test server: %v", err)
}
ns = append(ns, addr)
servers = append(servers, s)
publicNS = append(publicNS, addr)
publicServers = append(publicServers, s)
}
defer func() {
for _, server := range servers {
server.Shutdown()
for _, s := range publicServers {
s.Shutdown()
}
}()
// We now create an osResolver which has both a LAN and public nameserver.
resolver := &osResolver{}
resolver.publicServers.Store(&ns)
// Explicitly store the LAN nameserver.
resolver.lanServers.Store(&[]string{lanAddr})
// And store the public nameservers.
resolver.publicServers.Store(&publicNS)
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
answer, err := resolver.Resolve(context.Background(), msg)
if err != nil {
t.Fatal(err)
}
// Since a LAN nameserver is available and returns a success answer, we expect RcodeSuccess.
if answer.Rcode != dns.RcodeSuccess {
t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode])
t.Errorf("expected a success answer from LAN nameserver (RcodeSuccess) but got: %s", dns.RcodeToString[answer.Rcode])
}
}

79
staticdns.go Normal file
View File

@@ -0,0 +1,79 @@
package ctrld
import (
"net"
"os"
"path/filepath"
"runtime"
"strings"
)
// homedir, when non-empty, is the root directory absHomeDir resolves against;
// it takes precedence over the directory reported by userHomeDir.
var homedir string

// absHomeDir joins filename onto the home directory and returns the result.
//
// The root is homedir when it is set, otherwise the directory reported by
// userHomeDir. If userHomeDir fails, filename is returned unchanged.
func absHomeDir(filename string) string {
	root := homedir
	if root == "" {
		resolved, err := userHomeDir()
		if err != nil {
			// No usable root: hand the name back untouched.
			return filename
		}
		root = resolved
	}
	return filepath.Join(root, filename)
}
// dirWritable probes whether a file can be created inside dir.
//
// It creates a temporary file in dir, closes it and removes it again. When the
// file cannot be created it returns false and the creation error. Otherwise it
// returns true together with the error (if any) from closing the probe file.
func dirWritable(dir string) (bool, error) {
	probe, err := os.CreateTemp(dir, "")
	if err != nil {
		return false, err
	}
	probePath := probe.Name()
	closeErr := probe.Close()
	// Best-effort cleanup after the close; a removal failure is not reported.
	_ = os.Remove(probePath)
	return true, closeErr
}
// userHomeDir picks the directory that absHomeDir uses as its default root.
//
// On Windows it is the directory containing the running executable; an error
// from os.Executable is returned as is. On other systems it is /etc/controld,
// which is created with mode 0750 if needed. When that directory cannot be
// created, or a probe file cannot be created in it, the result of
// os.UserHomeDir is returned instead.
func userHomeDir() (string, error) {
	const etcDir = "/etc/controld"

	if runtime.GOOS == "windows" {
		// If we're on windows, use the install path for this.
		exe, err := os.Executable()
		if err != nil {
			return "", err
		}
		return filepath.Dir(exe), nil
	}

	if err := os.MkdirAll(etcDir, 0750); err == nil {
		// Only the boolean matters here; the probe error is deliberately ignored.
		if writable, _ := dirWritable(etcDir); writable {
			return etcDir, nil
		}
	}
	// Fall back to the user home directory.
	return os.UserHomeDir()
}
// SavedStaticDnsSettingsFilePath reports where the static DNS settings of the
// given interface are saved: a hidden file named ".dns_" followed by the
// interface name, resolved through absHomeDir.
func SavedStaticDnsSettingsFilePath(iface *net.Interface) string {
	hidden := ".dns_" + iface.Name
	return absHomeDir(hidden)
}
// SavedStaticNameservers returns the stored static nameservers for the given
// interface, together with the path of the file they were read from.
//
// The file content is treated as a comma-separated list. Each entry is trimmed
// of surrounding whitespace, so a trailing newline or stray space neither
// corrupts an entry nor lets a loopback address slip past the filter. Empty
// entries are dropped and loopback IP addresses are skipped. Entries that do
// not parse as an IP address are kept as they are.
//
// The returned slice is nil when the file cannot be read, is empty, or yields
// no usable entry. The file path is returned in every case.
func SavedStaticNameservers(iface *net.Interface) ([]string, string) {
	file := SavedStaticDnsSettingsFilePath(iface)
	data, err := os.ReadFile(file)
	if err != nil || len(data) == 0 {
		return nil, file
	}

	var ns []string
	for _, v := range strings.Split(string(data), ",") {
		// net.ParseIP rejects input with surrounding whitespace, so trim
		// before both the loopback check and the append.
		v = strings.TrimSpace(v)
		if v == "" {
			continue
		}
		// Skip any IP that is loopback.
		if ip := net.ParseIP(v); ip != nil && ip.IsLoopback() {
			continue
		}
		ns = append(ns, v)
	}
	return ns, file
}