mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Merge pull request #218 from Control-D-Inc/release-branch-v1.4.1
Release branch v1.4.1
This commit is contained in:
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: "go test -race ./..."
|
||||
- uses: dominikh/staticcheck-action@v1.2.0
|
||||
- uses: dominikh/staticcheck-action@v1.3.1
|
||||
with:
|
||||
version: "2024.1.1"
|
||||
install-go: false
|
||||
|
||||
5
cmd/cli/cgo.go
Normal file
5
cmd/cli/cgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = true
|
||||
196
cmd/cli/cli.go
196
cmd/cli/cli.go
@@ -61,7 +61,7 @@ var (
|
||||
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
defaultConfigFile = "ctrld.toml"
|
||||
rootCertPool *x509.CertPool
|
||||
errSelfCheckNoAnswer = errors.New("no answer from ctrld listener")
|
||||
errSelfCheckNoAnswer = errors.New("no response from ctrld listener. You can try to re-launch with flag --skip_self_checks")
|
||||
)
|
||||
|
||||
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"}
|
||||
@@ -222,10 +222,16 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
lc := &logConn{conn: conn}
|
||||
consoleWriter.Out = io.MultiWriter(os.Stdout, lc)
|
||||
p.logConn = lc
|
||||
} else {
|
||||
mainLog.Load().Warn().Err(err).Msgf("unable to create log ipc connection")
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Warn().Err(err).Msgf("unable to resolve socket address: %s", sockPath)
|
||||
}
|
||||
notifyExitToLogServer := func() {
|
||||
_, _ = p.logConn.Write([]byte(msgExit))
|
||||
if p.logConn != nil {
|
||||
_, _ = p.logConn.Write([]byte(msgExit))
|
||||
}
|
||||
}
|
||||
|
||||
if daemon && runtime.GOOS == "windows" {
|
||||
@@ -266,10 +272,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
|
||||
// Log config do not have thing to validate, so it's safe to init log here,
|
||||
// so it's able to log information in processCDFlags.
|
||||
logWriters := initLogging()
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logWriters)
|
||||
p.initLogging(true)
|
||||
|
||||
mainLog.Load().Info().Msgf("starting ctrld %s", curVersion())
|
||||
mainLog.Load().Info().Msgf("os: %s", osVersion())
|
||||
@@ -322,7 +325,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
}
|
||||
|
||||
updated := updateListenerConfig(&cfg)
|
||||
updated := updateListenerConfig(&cfg, notifyExitToLogServer)
|
||||
|
||||
if cdUID != "" {
|
||||
processLogAndCacheFlags()
|
||||
@@ -418,7 +421,8 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
if err := p.router.Cleanup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
p.resetDNS()
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -484,7 +488,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
nop := zerolog.Nop()
|
||||
_, _ = tryUpdateListenerConfig(&cfg, &nop, true)
|
||||
_, _ = tryUpdateListenerConfig(&cfg, &nop, func() {}, true)
|
||||
addExtraSplitDnsRule(&cfg)
|
||||
if err := writeConfigFile(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
|
||||
@@ -645,11 +649,15 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) {
|
||||
// Fetch config, unmarshal to cfg.
|
||||
if resolverConfig.Ctrld.CustomConfig != "" {
|
||||
logger.Info().Msg("using defined custom config of Control-D resolver")
|
||||
if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil {
|
||||
var cfgErr error
|
||||
if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil {
|
||||
setListenerDefaultValue(cfg)
|
||||
return resolverConfig, nil
|
||||
setNetworkDefaultValue(cfg)
|
||||
if cfgErr = validateConfig(cfg); cfgErr == nil {
|
||||
return resolverConfig, nil
|
||||
}
|
||||
}
|
||||
mainLog.Load().Err(err).Msg("disregarding invalid custom config")
|
||||
mainLog.Load().Warn().Err(err).Msg("disregarding invalid custom config")
|
||||
}
|
||||
|
||||
bootstrapIP := func(endpoint string) string {
|
||||
@@ -666,11 +674,7 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
cfg.Network = make(map[string]*ctrld.NetworkConfig)
|
||||
cfg.Network["0"] = &ctrld.NetworkConfig{
|
||||
Name: "Network 0",
|
||||
Cidrs: []string{"0.0.0.0/0"},
|
||||
}
|
||||
|
||||
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
|
||||
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
|
||||
BootstrapIP: bootstrapIP(resolverConfig.DOH),
|
||||
@@ -693,6 +697,7 @@ func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) {
|
||||
|
||||
// Set default value.
|
||||
setListenerDefaultValue(cfg)
|
||||
setNetworkDefaultValue(cfg)
|
||||
|
||||
return resolverConfig, nil
|
||||
}
|
||||
@@ -706,7 +711,21 @@ func setListenerDefaultValue(cfg *ctrld.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
// setListenerDefaultValue sets the default value for cfg.Listener if none existed.
|
||||
func setNetworkDefaultValue(cfg *ctrld.Config) {
|
||||
if len(cfg.Network) == 0 {
|
||||
cfg.Network = map[string]*ctrld.NetworkConfig{
|
||||
"0": {
|
||||
Name: "Network 0",
|
||||
Cidrs: []string{"0.0.0.0/0"},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateCdRemoteConfig validates the custom config from ControlD if defined.
|
||||
// This only validate the config syntax. To validate the config rules, calling
|
||||
// validateConfig with the cfg after calling this function.
|
||||
func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error {
|
||||
if rc.Ctrld.CustomConfig == "" {
|
||||
return nil
|
||||
@@ -783,7 +802,13 @@ func defaultIfaceName() string {
|
||||
if oi := osinfo.New(); strings.Contains(oi.String(), "Microsoft") {
|
||||
return "lo"
|
||||
}
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to get default route interface")
|
||||
// On linux, it could be either resolvconf or systemd which is managing DNS settings,
|
||||
// so the interface name does not matter if there's no default route interface.
|
||||
if runtime.GOOS == "linux" {
|
||||
return "lo"
|
||||
}
|
||||
mainLog.Load().Debug().Err(err).Msg("no default route interface found")
|
||||
return ""
|
||||
}
|
||||
return dri
|
||||
}
|
||||
@@ -843,10 +868,12 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msg("ctrld listener is ready")
|
||||
mainLog.Load().Debug().Msg("performing self-check")
|
||||
|
||||
lc := cfg.FirstListener()
|
||||
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
|
||||
|
||||
mainLog.Load().Debug().Msgf("performing listener test, sending queries to %s", addr)
|
||||
|
||||
if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil {
|
||||
return false, status, err
|
||||
}
|
||||
@@ -860,7 +887,7 @@ func selfCheckStatus(ctx context.Context, s service.Service, sockDir string) (bo
|
||||
func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain string) error {
|
||||
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
|
||||
bo.LogLongerThan = 500 * time.Millisecond
|
||||
maxAttempts := 20
|
||||
maxAttempts := 10
|
||||
c := new(dns.Client)
|
||||
|
||||
var (
|
||||
@@ -876,7 +903,7 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(domain+".", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, addr)
|
||||
r, _, exErr := exchangeContextWithTimeout(c, 5*time.Second, m, addr)
|
||||
if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 {
|
||||
mainLog.Load().Debug().Msgf("%s self-check against %q succeeded", scope, domain)
|
||||
return nil
|
||||
@@ -1030,16 +1057,25 @@ func uninstall(p *prog, s service.Service) {
|
||||
mainLog.Load().Warn().Err(err).Msg("post uninstallation failed, please check system/service log for details error")
|
||||
return
|
||||
}
|
||||
p.resetDNS()
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
|
||||
// if present restore the original DNS settings
|
||||
if netIface, err := netInterface(p.runningIface); err == nil {
|
||||
if err := restoreDNS(netIface); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Restored DNS on interface successfully")
|
||||
// Iterate over all physical interfaces and restore DNS if a saved static config exists.
|
||||
withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error {
|
||||
file := savedStaticDnsSettingsFilePath(i)
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
if err := restoreDNS(i); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name)
|
||||
err = os.Remove(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Could not remove saved static DNS file for interface %s", i.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if router.Name() != "" {
|
||||
mainLog.Load().Debug().Msg("Router cleanup")
|
||||
@@ -1146,8 +1182,8 @@ func mobileListenerIp() string {
|
||||
// updateListenerConfig updates the config for listeners if not defined,
|
||||
// or defined but invalid to be used, e.g: using loopback address other
|
||||
// than 127.0.0.1 with systemd-resolved.
|
||||
func updateListenerConfig(cfg *ctrld.Config) bool {
|
||||
updated, _ := tryUpdateListenerConfig(cfg, nil, true)
|
||||
func updateListenerConfig(cfg *ctrld.Config, notifyToLogServerFunc func()) bool {
|
||||
updated, _ := tryUpdateListenerConfig(cfg, nil, notifyToLogServerFunc, true)
|
||||
if addExtraSplitDnsRule(cfg) {
|
||||
updated = true
|
||||
}
|
||||
@@ -1157,13 +1193,14 @@ func updateListenerConfig(cfg *ctrld.Config) bool {
|
||||
// tryUpdateListenerConfig tries updating listener config with a working one.
|
||||
// If fatal is true, and there's listen address conflicted, the function do
|
||||
// fatal error.
|
||||
func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fatal bool) (updated, ok bool) {
|
||||
func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, notifyFunc func(), fatal bool) (updated, ok bool) {
|
||||
ok = true
|
||||
lcc := make(map[string]*listenerConfigCheck)
|
||||
cdMode := cdUID != ""
|
||||
nextdnsMode := nextdns != ""
|
||||
// For Windows server with local Dns server running, we can only try on random local IP.
|
||||
hasLocalDnsServer := hasLocalDnsServerRunning()
|
||||
notRouter := router.Name() == ""
|
||||
for n, listener := range cfg.Listener {
|
||||
lcc[n] = &listenerConfigCheck{}
|
||||
if listener.IP == "" {
|
||||
@@ -1193,6 +1230,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
|
||||
}
|
||||
updated = updated || lcc[n].IP || lcc[n].Port
|
||||
}
|
||||
|
||||
il := mainLog.Load()
|
||||
if infoLogger != nil {
|
||||
il = infoLogger
|
||||
@@ -1277,10 +1315,17 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
|
||||
tryOldIPPort5354 = false
|
||||
tryPort5354 = false
|
||||
}
|
||||
// if not running on a router, we should not try to listen on any port other than 53
|
||||
// if we do, this will break the dns resolution for the system.
|
||||
if notRouter {
|
||||
tryOldIPPort5354 = false
|
||||
tryPort5354 = false
|
||||
}
|
||||
attempts := 0
|
||||
maxAttempts := 10
|
||||
for {
|
||||
if attempts == maxAttempts {
|
||||
notifyFunc()
|
||||
logMsg(mainLog.Load().Fatal(), n, "could not find available listen ip and port")
|
||||
}
|
||||
addr := net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))
|
||||
@@ -1288,8 +1333,12 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
logMsg(il.Info(), n, "error listening on address: %s, error: %v", addr, err)
|
||||
|
||||
if !check.IP && !check.Port {
|
||||
if fatal {
|
||||
notifyFunc()
|
||||
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
|
||||
}
|
||||
ok = false
|
||||
@@ -1348,14 +1397,17 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
|
||||
} else {
|
||||
listener.IP = oldIP
|
||||
}
|
||||
if check.Port {
|
||||
// if we are not running on a router, we should not try to listen on any port other than 53
|
||||
// if we do, this will break the dns resolution for the system.
|
||||
if check.Port && !notRouter {
|
||||
listener.Port = randomPort()
|
||||
} else {
|
||||
listener.Port = oldPort
|
||||
}
|
||||
if listener.IP == oldIP && listener.Port == oldPort {
|
||||
if fatal {
|
||||
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
|
||||
notifyFunc()
|
||||
logMsg(mainLog.Load().Fatal(), n, "could not listen on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
@@ -1393,6 +1445,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
notifyFunc()
|
||||
logMsg(mainLog.Load().Fatal(), n, "could not use %q as DNS nameserver with systemd resolved", listener.IP)
|
||||
}
|
||||
}
|
||||
@@ -1597,7 +1650,7 @@ func doGenerateNextDNSConfig(uid string) error {
|
||||
}
|
||||
mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile)
|
||||
generateNextDNSConfig(uid)
|
||||
updateListenerConfig(&cfg)
|
||||
updateListenerConfig(&cfg, func() {})
|
||||
return writeConfigFile(&cfg)
|
||||
}
|
||||
|
||||
@@ -1739,9 +1792,14 @@ func goArm() string {
|
||||
// upgradeUrl returns the url for downloading new ctrld binary.
|
||||
func upgradeUrl(baseUrl string) string {
|
||||
dlPath := fmt.Sprintf("%s-%s/ctrld", runtime.GOOS, runtime.GOARCH)
|
||||
// Use arm version set during build time, v5 binary can be run on higher arm version system.
|
||||
if armVersion := goArm(); armVersion != "" {
|
||||
dlPath = fmt.Sprintf("%s-%sv%s/ctrld", runtime.GOOS, runtime.GOARCH, armVersion)
|
||||
}
|
||||
// linux/amd64 has nocgo version, to support systems that missing some libc (like openwrt).
|
||||
if !cgoEnabled && runtime.GOOS == "linux" && runtime.GOARCH == "amd64" {
|
||||
dlPath = fmt.Sprintf("%s-%s-nocgo/ctrld", runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
dlUrl := fmt.Sprintf("%s/%s", baseUrl, dlPath)
|
||||
if runtime.GOOS == "windows" {
|
||||
dlUrl += ".exe"
|
||||
@@ -1768,50 +1826,6 @@ func runningIface(s service.Service) *ifaceResponse {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsNoLog performs resetting DNS with logging disable.
|
||||
func resetDnsNoLog(p *prog) {
|
||||
// Normally, disable log to prevent annoying users.
|
||||
if verbose < 3 {
|
||||
lvl := zerolog.GlobalLevel()
|
||||
zerolog.SetGlobalLevel(zerolog.Disabled)
|
||||
p.resetDNS()
|
||||
zerolog.SetGlobalLevel(lvl)
|
||||
return
|
||||
}
|
||||
// For debugging purpose, still emit log.
|
||||
p.resetDNS()
|
||||
}
|
||||
|
||||
// resetDnsTask returns a task which perform reset DNS operation.
|
||||
func resetDnsTask(p *prog, s service.Service, isCtrldInstalled bool, ir *ifaceResponse) task {
|
||||
return task{func() error {
|
||||
if iface == "" {
|
||||
mainLog.Load().Debug().Msg("no iface, skipping resetDnsTask")
|
||||
return nil
|
||||
}
|
||||
// Always reset DNS first, ensuring DNS setting is in a good state.
|
||||
// resetDNS must use the "iface" value of current running ctrld
|
||||
// process to reset what setDNS has done properly.
|
||||
oldIface := iface
|
||||
iface = "auto"
|
||||
p.requiredMultiNICsConfig = requiredMultiNICsConfig()
|
||||
if ir != nil {
|
||||
iface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
p.runningIface = iface
|
||||
if isCtrldInstalled {
|
||||
mainLog.Load().Debug().Msg("restore system DNS settings")
|
||||
if status, _ := s.Status(); status == service.StatusRunning {
|
||||
mainLog.Load().Fatal().Msg("reset DNS while ctrld still running is not safe")
|
||||
}
|
||||
resetDnsNoLog(p)
|
||||
}
|
||||
iface = oldIface
|
||||
return nil
|
||||
}, false, "Reset DNS"}
|
||||
}
|
||||
|
||||
// doValidateCdRemoteConfig fetches and validates custom config for cdUID.
|
||||
func doValidateCdRemoteConfig(cdUID string, fatal bool) error {
|
||||
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
@@ -1825,10 +1839,22 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// return earlier if there's no custom config.
|
||||
if rc.Ctrld.CustomConfig == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCdRemoteConfig clobbers v, saving it here to restore later.
|
||||
oldV := v
|
||||
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
|
||||
if errors.As(err, &viper.ConfigParseError{}) {
|
||||
var cfgErr error
|
||||
remoteCfg := &ctrld.Config{}
|
||||
if cfgErr = validateCdRemoteConfig(rc, remoteCfg); cfgErr == nil {
|
||||
setListenerDefaultValue(remoteCfg)
|
||||
setNetworkDefaultValue(remoteCfg)
|
||||
cfgErr = validateConfig(remoteCfg)
|
||||
} else {
|
||||
if errors.As(cfgErr, &viper.ConfigParseError{}) {
|
||||
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
|
||||
tmpDir := os.TempDir()
|
||||
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
|
||||
@@ -1844,12 +1870,14 @@ func doValidateCdRemoteConfig(cdUID string, fatal bool) error {
|
||||
}
|
||||
// If we could not log details error, emit what we have already got.
|
||||
if !errorLogged {
|
||||
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
|
||||
mainLog.Load().Error().Msgf("failed to parse custom config: %v", cfgErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
|
||||
}
|
||||
}
|
||||
if cfgErr != nil {
|
||||
mainLog.Load().Warn().Msg("disregarding invalid custom config")
|
||||
}
|
||||
v = oldV
|
||||
@@ -1863,8 +1891,8 @@ func uninstallInvalidCdUID(p *prog, logger zerolog.Logger, doStop bool) bool {
|
||||
logger.Warn().Err(err).Msg("failed to create new service")
|
||||
return false
|
||||
}
|
||||
|
||||
p.resetDNS()
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
|
||||
tasks := []task{{s.Uninstall, true, "Uninstall"}}
|
||||
if doTasks(tasks) {
|
||||
|
||||
@@ -205,7 +205,13 @@ func initStartCmd() *cobra.Command {
|
||||
Long: `Install and start the ctrld service
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: cobra.NoArgs,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
|
||||
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
checkStrFlagEmpty(cmd, cdUidFlagName)
|
||||
checkStrFlagEmpty(cmd, cdOrgFlagName)
|
||||
@@ -242,6 +248,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
currentIface = runningIface(s)
|
||||
mainLog.Load().Debug().Msgf("current interface on start: %v", currentIface)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -339,13 +346,17 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
// if already running, dont restart
|
||||
if isCtrldRunning {
|
||||
mainLog.Load().Notice().Msg("service is already running")
|
||||
return
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
tasks := []task{
|
||||
{s.Stop, false, "Stop"},
|
||||
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -355,7 +366,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
}, false, "Save current DNS"},
|
||||
{func() error {
|
||||
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||
}, false, "Configure Windows service failure actions"},
|
||||
}, false, "Configure service failure actions"},
|
||||
{s.Start, true, "Start"},
|
||||
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||
}
|
||||
@@ -424,10 +435,10 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
{s.Stop, false, "Stop"},
|
||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
|
||||
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
|
||||
resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
//resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -451,6 +462,9 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
return
|
||||
}
|
||||
|
||||
// add a small delay to ensure the service is started and did not crash
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
ok, status, err := selfCheckStatus(ctx, s, sockDir)
|
||||
switch {
|
||||
case ok && status == service.StatusRunning:
|
||||
@@ -550,6 +564,13 @@ NOTE: running "ctrld start" without any arguments will start already installed c
|
||||
Long: `Quick start service and configure DNS on interface
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
|
||||
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(os.Args) == 2 {
|
||||
startOnly = true
|
||||
@@ -608,16 +629,21 @@ func initStopCmd() *cobra.Command {
|
||||
}
|
||||
if doTasks([]task{{s.Stop, true, "Stop"}}) {
|
||||
p.router.Cleanup()
|
||||
p.resetDNS()
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
|
||||
// restore DNS settings
|
||||
if netIface, err := netInterface(p.runningIface); err == nil {
|
||||
if err := restoreDNS(netIface); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not restore DNS on interface")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Restored DNS on interface successfully")
|
||||
// Iterate over all physical interfaces and restore static DNS if a saved static config exists.
|
||||
withEachPhysicalInterfaces("", "restore static DNS", func(i *net.Interface) error {
|
||||
file := savedStaticDnsSettingsFilePath(i)
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
if err := restoreDNS(i); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Could not restore static DNS on interface %s", i.Name)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restored static DNS on interface %s successfully", i.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if router.WaitProcessExited() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
@@ -714,7 +740,8 @@ func initRestartCmd() *cobra.Command {
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
p.router.Cleanup()
|
||||
p.resetDNS()
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
@@ -994,13 +1021,13 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to remove file")
|
||||
mainLog.Load().Warn().Err(err).Msgf("failed to remove file: %s", file)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("file removed: %s", file)
|
||||
}
|
||||
}
|
||||
if err := selfDeleteExe(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to remove file")
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to delete ctrld binary")
|
||||
} else {
|
||||
if !supportedSelfDelete {
|
||||
mainLog.Load().Debug().Msgf("file removed: %s", bin)
|
||||
@@ -1044,9 +1071,16 @@ func initInterfacesCmd() *cobra.Command {
|
||||
Short: "List network interfaces of the host",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error {
|
||||
fmt.Printf("Index : %d\n", i.Index)
|
||||
fmt.Printf("Name : %s\n", i.Name)
|
||||
var status string
|
||||
if i.Flags&net.FlagUp != 0 {
|
||||
status = "Up"
|
||||
} else {
|
||||
status = "Down"
|
||||
}
|
||||
fmt.Printf("Status: %s\n", status)
|
||||
addrs, _ := i.Addrs()
|
||||
for i, ipaddr := range addrs {
|
||||
if i == 0 {
|
||||
@@ -1242,7 +1276,8 @@ func initUpgradeCmd() *cobra.Command {
|
||||
}
|
||||
dlUrl := upgradeUrl(baseUrl)
|
||||
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl)
|
||||
resp, err := http.Get(dlUrl)
|
||||
|
||||
resp, err := getWithRetry(dlUrl)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to download binary")
|
||||
}
|
||||
@@ -1266,7 +1301,8 @@ func initUpgradeCmd() *cobra.Command {
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
p.router.Cleanup()
|
||||
p.resetDNS()
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
|
||||
@@ -79,33 +79,81 @@ func (s *controlServer) register(pattern string, handler http.Handler) {
|
||||
|
||||
func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
mainLog.Load().Debug().Msg("handling list clients request")
|
||||
|
||||
clients := p.ciTable.ListClients()
|
||||
mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
|
||||
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
mainLog.Load().Debug().Msg("sorted clients by IP address")
|
||||
|
||||
if p.metricsQueryStats.Load() {
|
||||
for _, client := range clients {
|
||||
mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts")
|
||||
|
||||
for idx, client := range clients {
|
||||
mainLog.Load().Debug().
|
||||
Int("index", idx).
|
||||
Str("ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("processing client metrics")
|
||||
|
||||
client.IncludeQueryCount = true
|
||||
dm := &dto.Metric{}
|
||||
|
||||
if statsClientQueriesCount.MetricVec == nil {
|
||||
mainLog.Load().Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("skipping metrics collection: MetricVec is nil")
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
|
||||
client.IP.String(),
|
||||
client.Mac,
|
||||
client.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("could not get metrics for client: %v", client)
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("failed to get metrics for client")
|
||||
continue
|
||||
}
|
||||
if err := m.Write(dm); err == nil {
|
||||
|
||||
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
||||
client.QueryCount = int64(dm.Counter.GetValue())
|
||||
mainLog.Load().Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Int64("query_count", client.QueryCount).
|
||||
Msg("successfully collected query count")
|
||||
} else if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("failed to write metric")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts")
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
||||
mainLog.Load().Error().
|
||||
Err(err).
|
||||
Int("client_count", len(clients)).
|
||||
Msg("failed to encode clients response")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Int("client_count", len(clients)).
|
||||
Msg("successfully sent clients list response")
|
||||
}))
|
||||
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
select {
|
||||
@@ -250,7 +298,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
}
|
||||
}))
|
||||
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if time.Since(p.internalLogSent) < logSentInterval {
|
||||
if time.Since(p.internalLogSent) < logWriterSentInterval {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -20,12 +20,12 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/logger"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -435,14 +435,17 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{upstreamOS}
|
||||
}
|
||||
|
||||
if p.isAdDomainQuery(req.msg) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(),
|
||||
"AD domain query detected for %s in domain %s",
|
||||
req.msg.Question[0].Name, p.adDomain)
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
|
||||
upstreams = []string{upstreamOS}
|
||||
// For OS resolver, local addresses are ignored to prevent possible looping.
|
||||
// However, on Active Directory Domain Controller, where it has local DNS server
|
||||
// running and listening on local addresses, these local addresses must be used
|
||||
// as nameservers, so queries for ADDC could be resolved as expected.
|
||||
if p.isAdDomainQuery(req.msg) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(),
|
||||
"AD domain query detected for %s in domain %s",
|
||||
req.msg.Question[0].Name, p.adDomain)
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
|
||||
upstreams = []string{upstreamOSLocal}
|
||||
}
|
||||
}
|
||||
|
||||
res := &proxyResponse{}
|
||||
@@ -458,7 +461,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
} else {
|
||||
switch {
|
||||
case isSrvLookup(req.msg):
|
||||
case isSrvLanLookup(req.msg):
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
@@ -620,7 +623,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
|
||||
|
||||
// if we have no healthy upstreams, trigger recovery flow
|
||||
if p.recoverOnUpstreamFailure() {
|
||||
if p.leakOnUpstreamFailure() {
|
||||
if p.um.countHealthy(upstreams) == 0 {
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel == nil {
|
||||
@@ -639,19 +642,20 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
|
||||
}
|
||||
}
|
||||
|
||||
// attempt query to OS resolver while as a retry catch all
|
||||
if upstreams[0] != upstreamOS {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
|
||||
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
|
||||
if answer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
|
||||
res.answer = answer
|
||||
res.upstream = osUpstreamConfig.Endpoint
|
||||
return res
|
||||
// attempt query to OS resolver while as a retry catch all
|
||||
// we dont want this to happen if leakOnUpstreamFailure is false
|
||||
if upstreams[0] != upstreamOS {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
|
||||
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
|
||||
if answer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
|
||||
res.answer = answer
|
||||
res.upstream = osUpstreamConfig.Endpoint
|
||||
return res
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
|
||||
}
|
||||
|
||||
answer := new(dns.Msg)
|
||||
@@ -1108,21 +1112,27 @@ func isLanHostnameQuery(m *dns.Msg) bool {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
return isLanHostname(q.Name)
|
||||
}
|
||||
|
||||
// isSrvLanLookup reports whether DNS message is an SRV query of a LAN hostname.
|
||||
func isSrvLanLookup(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
q := m.Question[0]
|
||||
return q.Qtype == dns.TypeSRV && isLanHostname(q.Name)
|
||||
}
|
||||
|
||||
// isLanHostname reports whether name is a LAN hostname.
|
||||
func isLanHostname(name string) bool {
|
||||
name = strings.TrimSuffix(name, ".")
|
||||
return !strings.Contains(name, ".") ||
|
||||
strings.HasSuffix(name, ".domain") ||
|
||||
strings.HasSuffix(name, ".lan") ||
|
||||
strings.HasSuffix(name, ".local")
|
||||
}
|
||||
|
||||
// isSrvLookup reports whether DNS message is a SRV query.
|
||||
func isSrvLookup(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
return m.Question[0].Qtype == dns.TypeSRV
|
||||
}
|
||||
|
||||
// isWanClient reports whether the input is a WAN address.
|
||||
func isWanClient(na net.Addr) bool {
|
||||
var ip netip.Addr
|
||||
@@ -1177,7 +1187,10 @@ func FlushDNSCache() error {
|
||||
|
||||
// monitorNetworkChanges starts monitoring for network interface changes
|
||||
func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
|
||||
mon, err := netmon.New(func(format string, args ...any) {
|
||||
// Always fetch the latest logger (and inject the prefix)
|
||||
mainLog.Load().Printf("netmon: "+format, args...)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating network monitor: %w", err)
|
||||
}
|
||||
@@ -1248,8 +1261,16 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// if the default route changed, set changed to true
|
||||
if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface {
|
||||
changed = true
|
||||
mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface)
|
||||
}
|
||||
|
||||
if !changed {
|
||||
mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected")
|
||||
// check if the default IPs are still on an interface that is up
|
||||
ValidateDefaultLocalIPsFromDelta(delta.New)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1260,6 +1281,13 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
|
||||
// Get IPs from default route interface in new state
|
||||
selfIP := defaultRouteIP()
|
||||
|
||||
// Ensure that selfIP is an IPv4 address.
|
||||
// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it
|
||||
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil {
|
||||
mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP)
|
||||
selfIP = ""
|
||||
}
|
||||
var ipv6 string
|
||||
|
||||
if delta.New.DefaultRouteInterface != "" {
|
||||
@@ -1295,7 +1323,8 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(selfIP); ip != nil {
|
||||
// Only set the IPv4 default if selfIP is a valid IPv4 address.
|
||||
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil {
|
||||
ctrld.SetDefaultLocalIPv4(ip)
|
||||
if !isMobile() && p.ciTable != nil {
|
||||
p.ciTable.SetSelfIP(selfIP)
|
||||
@@ -1306,7 +1335,8 @@ func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
|
||||
if p.recoverOnUpstreamFailure() {
|
||||
// we only trigger recovery flow for network changes on non router devices
|
||||
if router.Name() == "" {
|
||||
p.handleRecovery(RecoveryReasonNetworkChange)
|
||||
}
|
||||
})
|
||||
@@ -1438,7 +1468,10 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
// Immediately remove our DNS settings from the interface.
|
||||
// set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
|
||||
p.recoveryRunning.Store(true)
|
||||
p.resetDNS()
|
||||
// we do not want to restore any static DNS settings
|
||||
// we must try to get the DHCP values, any static DNS settings
|
||||
// will be appended to nameservers from the saved interface values
|
||||
p.resetDNS(false, false)
|
||||
|
||||
// For an OS failure, reinitialize OS resolver nameservers immediately.
|
||||
if reason == RecoveryReasonOSFailure {
|
||||
@@ -1504,12 +1537,14 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string
|
||||
go func(name string, uc *ctrld.UpstreamConfig) {
|
||||
defer wg.Done()
|
||||
mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
|
||||
attempts := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
|
||||
return
|
||||
default:
|
||||
attempts++
|
||||
// checkUpstreamOnce will reset any failure counters on success.
|
||||
if err := p.checkUpstreamOnce(name, uc); err == nil {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
|
||||
@@ -1523,6 +1558,18 @@ func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
|
||||
// if this is the upstreamOS and it's the 3rd attempt (or multiple of 3),
|
||||
// we should try to reinit the OS resolver to ensure we can recover
|
||||
if name == upstreamOS && attempts%3 == 0 {
|
||||
mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}(name, uc)
|
||||
@@ -1556,3 +1603,32 @@ func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.U
|
||||
}
|
||||
return upstreams
|
||||
}
|
||||
|
||||
// ValidateDefaultLocalIPsFromDelta checks if the default local IPv4 and IPv6 stored
|
||||
// are still present in the new network state (provided by delta.New).
|
||||
// If a stored default IP is no longer active, it resets that default (sets it to nil)
|
||||
// so that it won't be used in subsequent custom dialer contexts.
|
||||
func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) {
|
||||
currentIPv4 := ctrld.GetDefaultLocalIPv4()
|
||||
currentIPv6 := ctrld.GetDefaultLocalIPv6()
|
||||
|
||||
// Build a map of active IP addresses from the new state.
|
||||
activeIPs := make(map[string]bool)
|
||||
for _, prefixes := range newState.InterfaceIPs {
|
||||
for _, prefix := range prefixes {
|
||||
activeIPs[prefix.Addr().String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the default IPv4 is still active.
|
||||
if currentIPv4 != nil && !activeIPs[currentIPv4.String()] {
|
||||
mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4)
|
||||
ctrld.SetDefaultLocalIPv4(nil)
|
||||
}
|
||||
|
||||
// Check if the default IPv6 is still active.
|
||||
if currentIPv6 != nil && !activeIPs[currentIPv6.String()] {
|
||||
mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6)
|
||||
ctrld.SetDefaultLocalIPv6(nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -418,20 +418,21 @@ func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSrvLookup(t *testing.T) {
|
||||
func Test_isSrvLanLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isSrvLookup bool
|
||||
}{
|
||||
{"SRV", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
||||
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isSrvLookup(tc.msg); tc.isSrvLookup != got {
|
||||
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
type AppCallback struct {
|
||||
@@ -17,3 +24,55 @@ type AppConfig struct {
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHTTPTimeout = 30 * time.Second
|
||||
defaultMaxRetries = 3
|
||||
)
|
||||
|
||||
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
|
||||
func httpClientWithFallback(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
// Prefer IPv4 over IPv6
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
FallbackDelay: 1 * time.Millisecond, // Very small delay to prefer IPv4
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry performs an HTTP request with retries
|
||||
func doWithRetry(req *http.Request, maxRetries int) (*http.Response, error) {
|
||||
var lastErr error
|
||||
client := httpClientWithFallback(defaultHTTPTimeout)
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = err
|
||||
mainLog.Load().Debug().Err(err).
|
||||
Str("method", req.Method).
|
||||
Str("url", req.URL.String()).
|
||||
Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries)
|
||||
}
|
||||
return nil, fmt.Errorf("failed after %d attempts to %s %s: %v", maxRetries, req.Method, req.URL, lastErr)
|
||||
}
|
||||
|
||||
// Helper for making GET requests with retries
|
||||
func getWithRetry(url string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doWithRetry(req, defaultMaxRetries)
|
||||
}
|
||||
|
||||
@@ -16,13 +16,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
logSentInterval = time.Minute
|
||||
logStartEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
logLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
logWarnEndMarker = "\n\n=== WARN_END ===\n\n"
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
logWriterSentInterval = time.Minute
|
||||
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
)
|
||||
|
||||
type logViewResponse struct {
|
||||
@@ -69,13 +68,23 @@ func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
// If writing p causes overflows, discard old data.
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
haveEndMarker := false
|
||||
// If there's init end marker already, preserve the data til the marker.
|
||||
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
|
||||
buf = buf[:idx+len(logWriterInitEndMarker)]
|
||||
haveEndMarker = true
|
||||
} else {
|
||||
// Otherwise, preserve the initial size data.
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
}
|
||||
}
|
||||
lw.buf.Reset()
|
||||
lw.buf.Write(buf)
|
||||
lw.buf.WriteString(logStartEndMarker) // indicate that the log was truncated.
|
||||
if !haveEndMarker {
|
||||
lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated.
|
||||
}
|
||||
}
|
||||
// If p is bigger than buffer size, truncate p by half until its size is smaller.
|
||||
for len(p)+lw.buf.Len() > lw.size {
|
||||
@@ -84,6 +93,15 @@ func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
return lw.buf.Write(p)
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func (p *prog) initLogging(backup bool) {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
logWriters := initLoggingWithBackup(backup)
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logWriters)
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
func (p *prog) initInternalLogging(writers []io.Writer) {
|
||||
if !p.needInternalLogging() {
|
||||
@@ -92,7 +110,7 @@ func (p *prog) initInternalLogging(writers []io.Writer) {
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
mainLog.Load().Notice().Msg("internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logSentInterval)
|
||||
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
})
|
||||
p.mu.Lock()
|
||||
@@ -158,7 +176,7 @@ func (p *prog) logReader() (*logReader, error) {
|
||||
wlwReader := bytes.NewReader(wlw.buf.Bytes())
|
||||
wlwSize := wlw.buf.Len()
|
||||
wlw.mu.Unlock()
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logLogEndMarker)), wlwReader)
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
|
||||
lr := &logReader{r: io.NopCloser(reader)}
|
||||
lr.size = int64(lwSize + wlwSize)
|
||||
if lr.size == 0 {
|
||||
|
||||
@@ -16,7 +16,7 @@ func Test_logWriter_Write(t *testing.T) {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
newData := "B"
|
||||
halfData := strings.Repeat("A", len(data)/2) + logStartEndMarker
|
||||
halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker
|
||||
lw.Write([]byte(newData))
|
||||
if lw.buf.String() != halfData+newData {
|
||||
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
|
||||
@@ -47,3 +47,39 @@ func Test_logWriter_ConcurrentWrite(t *testing.T) {
|
||||
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_MarkerInitEnd(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
|
||||
paddingSize := 10
|
||||
// Writing half of the size, minus len(end marker) and padding size.
|
||||
dataSize := size/2 - len(logWriterInitEndMarker) - paddingSize
|
||||
data := strings.Repeat("A", dataSize)
|
||||
// Inserting newline for making partial init data
|
||||
data += "\n"
|
||||
// Filling left over buffer to make the log full.
|
||||
// The data length: len(end marker) + padding size - 1 (for newline above) + size/2
|
||||
data += strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2))
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
lw.Write([]byte("B"))
|
||||
lw.Write([]byte(strings.Repeat("B", 256*1024)))
|
||||
firstIdx := strings.Index(lw.buf.String(), logWriterInitEndMarker)
|
||||
lastIdx := strings.LastIndex(lw.buf.String(), logWriterInitEndMarker)
|
||||
// Check if init end marker present.
|
||||
if firstIdx == -1 || lastIdx == -1 {
|
||||
t.Fatalf("missing init end marker: %s", lw.buf.String())
|
||||
}
|
||||
// Check if init end marker appears only once.
|
||||
if firstIdx != lastIdx {
|
||||
t.Fatalf("log init end marker appears more than once: %s", lw.buf.String())
|
||||
}
|
||||
// Ensure that we have the correct init log data.
|
||||
if !strings.Contains(lw.buf.String(), strings.Repeat("A", dataSize)+logWriterInitEndMarker) {
|
||||
t.Fatalf("unexpected log content: %s", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,24 +88,21 @@ func initConsoleLogging() {
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
mainLog.Store(&l)
|
||||
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
case verbose == 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
case verbose > 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func initLogging() []io.Writer {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
return initLoggingWithBackup(true)
|
||||
}
|
||||
|
||||
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
|
||||
// to be used for all interactive commands.
|
||||
//
|
||||
|
||||
5
cmd/cli/nocgo.go
Normal file
5
cmd/cli/nocgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = false
|
||||
@@ -72,34 +72,25 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
trySystemdResolve := false
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
break
|
||||
}
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
goto systemdResolve
|
||||
}
|
||||
if useSystemdResolved {
|
||||
if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out))
|
||||
}
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
systemdResolve:
|
||||
if trySystemdResolve {
|
||||
// Stop systemd-networkd and retry setting DNS.
|
||||
if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil {
|
||||
@@ -119,8 +110,8 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -169,6 +160,7 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
// TODO(cuonglm): handle DHCPv6 properly.
|
||||
mainLog.Load().Debug().Msg("checking for IPv6 availability")
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
@@ -188,6 +180,8 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("IPv6 is not available")
|
||||
}
|
||||
|
||||
return ignoringEINTR(func() error {
|
||||
|
||||
@@ -43,21 +43,42 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if hasLocalDnsServerRunning() {
|
||||
mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders")
|
||||
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
oldForwardersContent, _ := os.ReadFile(file)
|
||||
mainLog.Load().Debug().Msgf("Using forwarders file: %s", file)
|
||||
|
||||
oldForwardersContent, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent))
|
||||
}
|
||||
|
||||
hasLocalIPv6Listener := needLocalIPv6Listener()
|
||||
mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status")
|
||||
|
||||
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
||||
if !hasLocalIPv6Listener {
|
||||
return false
|
||||
}
|
||||
return s == "::1"
|
||||
})
|
||||
mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list")
|
||||
|
||||
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully wrote new forwarders file")
|
||||
}
|
||||
|
||||
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
||||
mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders")
|
||||
|
||||
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -147,15 +168,32 @@ func restoreDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, ns := range [][]string{v4ns, v6ns} {
|
||||
if len(ns) == 0 {
|
||||
continue
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
|
||||
err = setDNS(iface, ns)
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoreDNS: %w", err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
if len(v4ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
|
||||
if err := setDNS(iface, v4ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(v6ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
|
||||
if err := setDNS(iface, v6ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -180,43 +218,69 @@ func currentDNS(iface *net.Interface) []string {
|
||||
return ns
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
|
||||
// like "NameServer" and "ProfileNameServer".
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err)
|
||||
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("luid.GUID: %w", err)
|
||||
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
|
||||
}
|
||||
|
||||
var ns []string
|
||||
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
|
||||
found := false
|
||||
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
|
||||
for _, path := range keyPaths {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err)
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
|
||||
continue
|
||||
}
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
value, _, err := k.GetStringValue(key)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
return nil, fmt.Errorf("%s: %w", key, err)
|
||||
}
|
||||
if len(value) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(value, ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
func() {
|
||||
defer k.Close()
|
||||
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
|
||||
value, _, err := k.GetStringValue(keyName)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName)
|
||||
continue
|
||||
}
|
||||
if len(value) > 0 {
|
||||
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
|
||||
parsed := parseDNSServers(value)
|
||||
for _, pns := range parsed {
|
||||
if !slices.Contains(ns, pns) {
|
||||
ns = append(ns, pns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name)
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// parseDNSServers splits a DNS server string that may be comma- or space-separated,
|
||||
// and trims any extraneous whitespace or null characters.
|
||||
func parseDNSServers(val string) []string {
|
||||
fields := strings.FieldsFunc(val, func(r rune) bool {
|
||||
return r == ' ' || r == ','
|
||||
})
|
||||
var servers []string
|
||||
for _, f := range fields {
|
||||
trimmed := strings.TrimSpace(f)
|
||||
if len(trimmed) > 0 {
|
||||
servers = append(servers, trimmed)
|
||||
}
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
||||
// and also removing old forwarders if provided.
|
||||
func addDnsServerForwarders(nameservers, old []string) error {
|
||||
|
||||
340
cmd/cli/prog.go
340
cmd/cli/prog.go
@@ -43,7 +43,7 @@ const (
|
||||
ctrldControlUnixSockMobile = "cd.sock"
|
||||
upstreamPrefix = "upstream."
|
||||
upstreamOS = upstreamPrefix + "os"
|
||||
upstreamPrivate = upstreamPrefix + "private"
|
||||
upstreamOSLocal = upstreamOS + ".local"
|
||||
dnsWatchdogDefaultInterval = 20 * time.Second
|
||||
ctrldServiceName = "ctrld"
|
||||
)
|
||||
@@ -120,6 +120,7 @@ type prog struct {
|
||||
runningIface string
|
||||
requiredMultiNICsConfig bool
|
||||
adDomain string
|
||||
runningOnDomainController bool
|
||||
|
||||
selfUninstallMu sync.Mutex
|
||||
refusedQueryCount int
|
||||
@@ -268,7 +269,7 @@ func (p *prog) preRun() {
|
||||
if runtime.GOOS == "darwin" {
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
if !service.Interactive() {
|
||||
p.resetDNS()
|
||||
p.resetDNS(false, true)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -276,7 +277,12 @@ func (p *prog) preRun() {
|
||||
|
||||
func (p *prog) postRun() {
|
||||
if !service.Interactive() {
|
||||
p.resetDNS()
|
||||
if runtime.GOOS == "windows" {
|
||||
isDC, roleInt := isRunningOnDomainController()
|
||||
p.runningOnDomainController = isDC
|
||||
mainLog.Load().Debug().Msgf("running on domain controller: %t, role: %d", p.runningOnDomainController, roleInt)
|
||||
}
|
||||
p.resetDNS(false, false)
|
||||
ns := ctrld.InitializeOsResolver(false)
|
||||
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
|
||||
p.setDNS()
|
||||
@@ -345,14 +351,19 @@ func (p *prog) apiConfigReload() {
|
||||
if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated || forced {
|
||||
lastUpdated = time.Now().Unix()
|
||||
cfg := &ctrld.Config{}
|
||||
if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil {
|
||||
var cfgErr error
|
||||
if cfgErr = validateCdRemoteConfig(resolverConfig, cfg); cfgErr == nil {
|
||||
setListenerDefaultValue(cfg)
|
||||
setNetworkDefaultValue(cfg)
|
||||
cfgErr = validateConfig(cfg)
|
||||
}
|
||||
if cfgErr != nil {
|
||||
logger.Warn().Err(err).Msg("skipping invalid custom config")
|
||||
if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil {
|
||||
logger.Error().Err(err).Msg("could not mark custom last update failed")
|
||||
}
|
||||
return
|
||||
}
|
||||
setListenerDefaultValue(cfg)
|
||||
logger.Debug().Msg("custom config changes detected, reloading...")
|
||||
p.apiReloadCh <- cfg
|
||||
} else {
|
||||
@@ -560,13 +571,12 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
if !reload {
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
logWriters := initLoggingWithBackup(false)
|
||||
p.initLogging(false)
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
go p.apiConfigReload()
|
||||
p.postRun()
|
||||
p.initInternalLogging(logWriters)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -647,16 +657,74 @@ func (p *prog) setDNS() {
|
||||
if cfg.Listener == nil {
|
||||
return
|
||||
}
|
||||
if p.runningIface == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// allIfaces tracks whether we should set DNS for all physical interfaces.
|
||||
allIfaces := p.requiredMultiNICsConfig
|
||||
lc := cfg.FirstListener()
|
||||
if lc == nil {
|
||||
return
|
||||
}
|
||||
ns := lc.IP
|
||||
switch {
|
||||
case lc.IsDirectDnsListener():
|
||||
// If ctrld is direct listener, use 127.0.0.1 as nameserver.
|
||||
ns = "127.0.0.1"
|
||||
case lc.Port != 53:
|
||||
ns = "127.0.0.1"
|
||||
if resolver := router.LocalResolverIP(); resolver != "" {
|
||||
ns = resolver
|
||||
}
|
||||
default:
|
||||
// If we ever reach here, it means ctrld is running on lc.IP port 53,
|
||||
// so we could just use lc.IP as nameserver.
|
||||
}
|
||||
|
||||
nameservers := []string{ns}
|
||||
if needRFC1918Listeners(lc) {
|
||||
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
|
||||
}
|
||||
if needLocalIPv6Listener() {
|
||||
nameservers = append(nameservers, "::1")
|
||||
}
|
||||
|
||||
slices.Sort(nameservers)
|
||||
|
||||
netIfaceName := ""
|
||||
netIface := p.setDnsForRunningIface(nameservers)
|
||||
if netIface != nil {
|
||||
netIfaceName = netIface.Name
|
||||
}
|
||||
setDnsOK = true
|
||||
|
||||
if p.requiredMultiNICsConfig {
|
||||
withEachPhysicalInterfaces(netIfaceName, "set DNS", func(i *net.Interface) error {
|
||||
return setDnsIgnoreUnusableInterface(i, nameservers)
|
||||
})
|
||||
}
|
||||
// resolvconf file is only useful when we have default route interface,
|
||||
// then set DNS on this interface will push change to /etc/resolv.conf file.
|
||||
if netIface != nil && shouldWatchResolvconf() {
|
||||
servers := make([]netip.Addr, len(nameservers))
|
||||
for i := range nameservers {
|
||||
servers[i] = netip.MustParseAddr(nameservers[i])
|
||||
}
|
||||
p.dnsWg.Add(1)
|
||||
go func() {
|
||||
defer p.dnsWg.Done()
|
||||
p.watchResolvConf(netIface, servers, setResolvConf)
|
||||
}()
|
||||
}
|
||||
if p.dnsWatchdogEnabled() {
|
||||
p.dnsWg.Add(1)
|
||||
go func() {
|
||||
defer p.dnsWg.Done()
|
||||
p.dnsWatchdog(netIface, nameservers)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) setDnsForRunningIface(nameservers []string) (runningIface *net.Interface) {
|
||||
if p.runningIface == "" {
|
||||
return
|
||||
}
|
||||
|
||||
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
|
||||
const maxDNSRetryAttempts = 3
|
||||
@@ -690,59 +758,14 @@ func (p *prog) setDNS() {
|
||||
return
|
||||
}
|
||||
|
||||
runningIface = netIface
|
||||
logger.Debug().Msg("setting DNS for interface")
|
||||
ns := lc.IP
|
||||
switch {
|
||||
case lc.IsDirectDnsListener():
|
||||
// If ctrld is direct listener, use 127.0.0.1 as nameserver.
|
||||
ns = "127.0.0.1"
|
||||
case lc.Port != 53:
|
||||
ns = "127.0.0.1"
|
||||
if resolver := router.LocalResolverIP(); resolver != "" {
|
||||
ns = resolver
|
||||
}
|
||||
default:
|
||||
// If we ever reach here, it means ctrld is running on lc.IP port 53,
|
||||
// so we could just use lc.IP as nameserver.
|
||||
}
|
||||
|
||||
nameservers := []string{ns}
|
||||
if needRFC1918Listeners(lc) {
|
||||
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
|
||||
}
|
||||
if needLocalIPv6Listener() {
|
||||
nameservers = append(nameservers, "::1")
|
||||
}
|
||||
slices.Sort(nameservers)
|
||||
if err := setDNS(netIface, nameservers); err != nil {
|
||||
logger.Error().Err(err).Msgf("could not set DNS for interface")
|
||||
return
|
||||
}
|
||||
setDnsOK = true
|
||||
logger.Debug().Msg("setting DNS successfully")
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
|
||||
return setDnsIgnoreUnusableInterface(i, nameservers)
|
||||
})
|
||||
}
|
||||
if shouldWatchResolvconf() {
|
||||
servers := make([]netip.Addr, len(nameservers))
|
||||
for i := range nameservers {
|
||||
servers[i] = netip.MustParseAddr(nameservers[i])
|
||||
}
|
||||
p.dnsWg.Add(1)
|
||||
go func() {
|
||||
defer p.dnsWg.Done()
|
||||
p.watchResolvConf(netIface, servers, setResolvConf)
|
||||
}()
|
||||
}
|
||||
if p.dnsWatchdogEnabled() {
|
||||
p.dnsWg.Add(1)
|
||||
go func() {
|
||||
defer p.dnsWg.Done()
|
||||
p.dnsWatchdog(netIface, nameservers, allIfaces)
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// dnsWatchdogEnabled reports whether DNS watchdog is enabled.
|
||||
@@ -765,12 +788,12 @@ func (p *prog) dnsWatchdogDuration() time.Duration {
|
||||
|
||||
// dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings.
|
||||
// This is only works when deactivation pin set.
|
||||
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces bool) {
|
||||
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string) {
|
||||
if !requiredMultiNICsConfig() {
|
||||
return
|
||||
}
|
||||
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
|
||||
logger.Debug().Msg("start DNS settings watchdog")
|
||||
|
||||
mainLog.Load().Debug().Msg("start DNS settings watchdog")
|
||||
|
||||
ns := nameservers
|
||||
slices.Sort(ns)
|
||||
@@ -788,14 +811,56 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
|
||||
return
|
||||
}
|
||||
if dnsChanged(iface, ns) {
|
||||
logger.Debug().Msg("DNS settings were changed, re-applying settings")
|
||||
mainLog.Load().Debug().Msg("DNS settings were changed, re-applying settings")
|
||||
// Check if the interface already has static DNS servers configured.
|
||||
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
|
||||
staticDNS, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", iface.Name)
|
||||
} else if len(staticDNS) > 0 {
|
||||
//filter out loopback addresses
|
||||
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
|
||||
return net.ParseIP(s).IsLoopback()
|
||||
})
|
||||
// if we have a static config and no saved IPs already, save them
|
||||
if len(staticDNS) > 0 && len(savedStaticNameservers(iface)) == 0 {
|
||||
// Save these static DNS values so that they can be restored later.
|
||||
if err := saveCurrentStaticDNS(iface); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", iface.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
|
||||
}
|
||||
}
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(iface.Name, "", func(i *net.Interface) error {
|
||||
if p.requiredMultiNICsConfig {
|
||||
ifaceName := ""
|
||||
if iface != nil {
|
||||
ifaceName = iface.Name
|
||||
}
|
||||
withEachPhysicalInterfaces(ifaceName, "", func(i *net.Interface) error {
|
||||
if dnsChanged(i, ns) {
|
||||
|
||||
// Check if the interface already has static DNS servers configured.
|
||||
// currentStaticDNS is an OS-dependent helper that returns the current static DNS.
|
||||
staticDNS, err := currentStaticDNS(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to get static DNS for interface %s", i.Name)
|
||||
} else if len(staticDNS) > 0 {
|
||||
//filter out loopback addresses
|
||||
staticDNS = slices.DeleteFunc(staticDNS, func(s string) bool {
|
||||
return net.ParseIP(s).IsLoopback()
|
||||
})
|
||||
// if we have a static config and no saved IPs already, save them
|
||||
if len(staticDNS) > 0 && len(savedStaticNameservers(i)) == 0 {
|
||||
// Save these static DNS values so that they can be restored later.
|
||||
if err := saveCurrentStaticDNS(i); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to save static DNS for interface %s", i.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
|
||||
} else {
|
||||
@@ -809,33 +874,78 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) resetDNS() {
|
||||
// resetDNS performs a DNS reset for all interfaces.
|
||||
func (p *prog) resetDNS(isStart bool, restoreStatic bool) {
|
||||
netIfaceName := ""
|
||||
if netIface := p.resetDNSForRunningIface(isStart, restoreStatic); netIface != nil {
|
||||
netIfaceName = netIface.Name
|
||||
}
|
||||
// See corresponding comments in (*prog).setDNS function.
|
||||
if p.requiredMultiNICsConfig {
|
||||
withEachPhysicalInterfaces(netIfaceName, "reset DNS", resetDnsIgnoreUnusableInterface)
|
||||
}
|
||||
}
|
||||
|
||||
// resetDNSForRunningIface performs a DNS reset on the running interface.
|
||||
// The parameter isStart indicates whether this is being called as part of a start (or restart)
|
||||
// command. When true, we check if the current static DNS configuration already differs from the
|
||||
// service listener (127.0.0.1). If so, we assume that an admin has manually changed the interface's
|
||||
// static DNS settings and we do not override them using the potentially out-of-date saved file.
|
||||
// Otherwise, we restore the saved configuration (if any) or reset to DHCP.
|
||||
func (p *prog) resetDNSForRunningIface(isStart bool, restoreStatic bool) (runningIface *net.Interface) {
|
||||
if p.runningIface == "" {
|
||||
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
|
||||
return
|
||||
}
|
||||
// See corresponding comments in (*prog).setDNS function.
|
||||
allIfaces := p.requiredMultiNICsConfig
|
||||
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
netIface, err := netInterface(p.runningIface)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("could not get interface")
|
||||
return
|
||||
}
|
||||
|
||||
runningIface = netIface
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
}
|
||||
logger.Debug().Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
logger.Error().Err(err).Msgf("could not reset DNS")
|
||||
return
|
||||
|
||||
// If starting, check the current static DNS configuration.
|
||||
if isStart {
|
||||
current, err := currentStaticDNS(netIface)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("unable to obtain current static DNS configuration; proceeding to restore saved config")
|
||||
} else if len(current) > 0 {
|
||||
// If any static DNS value is not our own listener, assume an admin override.
|
||||
hasManualConfig := false
|
||||
for _, ns := range current {
|
||||
if ns != "127.0.0.1" && ns != "::1" {
|
||||
hasManualConfig = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasManualConfig {
|
||||
logger.Debug().Msgf("Detected manual DNS configuration on interface %q: %v; not overriding with saved configuration", netIface.Name, current)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.Debug().Msg("Restoring DNS successfully")
|
||||
if allIfaces {
|
||||
withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDnsIgnoreUnusableInterface)
|
||||
|
||||
// Default logic: if there is a saved static DNS configuration, restore it.
|
||||
saved := savedStaticNameservers(netIface)
|
||||
if len(saved) > 0 && restoreStatic {
|
||||
logger.Debug().Msgf("Restoring interface %q from saved static config: %v", netIface.Name, saved)
|
||||
if err := setDNS(netIface, saved); err != nil {
|
||||
logger.Error().Err(err).Msgf("failed to restore static DNS config on interface %q", netIface.Name)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
logger.Debug().Msgf("No saved static DNS config for interface %q; resetting to DHCP", netIface.Name)
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
logger.Error().Err(err).Msgf("failed to reset DNS to DHCP on interface %q", netIface.Name)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *prog) logInterfacesState() {
|
||||
@@ -985,12 +1095,6 @@ func findWorkingInterface(currentIface string) string {
|
||||
return currentIface
|
||||
}
|
||||
|
||||
// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure.
|
||||
func (p *prog) recoverOnUpstreamFailure() bool {
|
||||
// Default is false on routers, since this recovery flow is only useful for devices that move between networks.
|
||||
return router.Name() == ""
|
||||
}
|
||||
|
||||
func randomLocalIP() string {
|
||||
n := rand.Intn(254-2) + 2
|
||||
return fmt.Sprintf("127.0.0.%d", n)
|
||||
@@ -1192,7 +1296,7 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.
|
||||
// TODO: investigate whether we should report this error?
|
||||
if err := f(netIface); err == nil {
|
||||
if context != "" {
|
||||
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
|
||||
mainLog.Load().Debug().Msgf("Ran %s for interface %q successfully", context, i.Name)
|
||||
}
|
||||
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
|
||||
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
|
||||
@@ -1215,19 +1319,38 @@ var errSaveCurrentStaticDNSNotSupported = errors.New("saving current DNS is not
|
||||
// saveCurrentStaticDNS saves the current static DNS settings for restoring later.
|
||||
// Only works on Windows and Mac.
|
||||
func saveCurrentStaticDNS(iface *net.Interface) error {
|
||||
if iface == nil {
|
||||
mainLog.Load().Debug().Msg("could not save current static DNS settings for nil interface")
|
||||
return nil
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "windows", "darwin":
|
||||
default:
|
||||
return errSaveCurrentStaticDNSNotSupported
|
||||
}
|
||||
file := savedStaticDnsSettingsFilePath(iface)
|
||||
ns, _ := currentStaticDNS(iface)
|
||||
ns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not get current static DNS settings for %q", iface.Name)
|
||||
return err
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("no static DNS settings for %q, removing old static DNS settings file", iface.Name)
|
||||
_ = os.Remove(file) // removing old static DNS settings
|
||||
return nil
|
||||
}
|
||||
//filter out loopback addresses
|
||||
ns = slices.DeleteFunc(ns, func(s string) bool {
|
||||
return net.ParseIP(s).IsLoopback()
|
||||
})
|
||||
//if we now have no static DNS settings and the file already exists
|
||||
// return and do not save the file
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("loopback on %q, skipping saving static DNS settings", iface.Name)
|
||||
return nil
|
||||
}
|
||||
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not remove old static DNS settings file")
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not remove old static DNS settings file: %s", file)
|
||||
}
|
||||
nss := strings.Join(ns, ",")
|
||||
mainLog.Load().Debug().Msgf("DNS settings for %q is static: %v, saving ...", iface.Name, nss)
|
||||
@@ -1241,6 +1364,9 @@ func saveCurrentStaticDNS(iface *net.Interface) error {
|
||||
|
||||
// savedStaticDnsSettingsFilePath returns the path to saved DNS settings of the given interface.
|
||||
func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
|
||||
if iface == nil {
|
||||
return ""
|
||||
}
|
||||
return absHomeDir(".dns_" + iface.Name)
|
||||
}
|
||||
|
||||
@@ -1248,6 +1374,10 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
|
||||
//
|
||||
//lint:ignore U1000 use in os_windows.go and os_darwin.go
|
||||
func savedStaticNameservers(iface *net.Interface) []string {
|
||||
if iface == nil {
|
||||
mainLog.Load().Debug().Msg("could not get saved static DNS settings for nil interface")
|
||||
return nil
|
||||
}
|
||||
file := savedStaticDnsSettingsFilePath(iface)
|
||||
if data, _ := os.ReadFile(file); len(data) > 0 {
|
||||
saveValues := strings.Split(string(data), ",")
|
||||
@@ -1265,8 +1395,13 @@ func savedStaticNameservers(iface *net.Interface) []string {
|
||||
}
|
||||
|
||||
// dnsChanged reports whether DNS settings for given interface was changed.
|
||||
// It returns false for a nil iface.
|
||||
//
|
||||
// The caller must sort the nameservers before calling this function.
|
||||
func dnsChanged(iface *net.Interface, nameservers []string) bool {
|
||||
if iface == nil {
|
||||
return false
|
||||
}
|
||||
curNameservers, _ := currentStaticDNS(iface)
|
||||
slices.Sort(curNameservers)
|
||||
if !slices.Equal(curNameservers, nameservers) {
|
||||
@@ -1286,3 +1421,36 @@ func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
|
||||
selfUninstall(p, logger)
|
||||
}
|
||||
}
|
||||
|
||||
// leakOnUpstreamFailure reports whether ctrld should initiate a recovery flow
|
||||
// when upstream failures occur.
|
||||
func (p *prog) leakOnUpstreamFailure() bool {
|
||||
if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
|
||||
return *ptr
|
||||
}
|
||||
// Default is false on routers, since this leaking is only useful for devices that move between networks.
|
||||
if router.Name() != "" {
|
||||
return false
|
||||
}
|
||||
// if we are running on ADDC, we should not leak on upstream failure
|
||||
if p.runningOnDomainController {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Domain controller role values from Win32_ComputerSystem
|
||||
// https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/win32-computersystem
|
||||
const (
|
||||
BackupDomainController = 4
|
||||
PrimaryDomainController = 5
|
||||
)
|
||||
|
||||
// isRunningOnDomainController checks if the current machine is a domain controller
|
||||
// by querying the DomainRole property from Win32_ComputerSystem via WMI.
|
||||
func isRunningOnDomainController() (bool, int) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return false, 0
|
||||
}
|
||||
return isRunningOnDomainControllerWindows()
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"tailscale.com/health"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -39,6 +40,9 @@ func setDependencies(svc *service.Config) {
|
||||
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
|
||||
}
|
||||
}
|
||||
if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 {
|
||||
svc.Dependencies = append(svc.Dependencies, routerDeps...)
|
||||
}
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !freebsd && !darwin
|
||||
//go:build !linux && !freebsd && !darwin && !windows
|
||||
|
||||
package cli
|
||||
|
||||
|
||||
14
cmd/cli/prog_windows.go
Normal file
14
cmd/cli/prog_windows.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
func setDependencies(svc *service.Config) {
|
||||
if hasLocalDnsServerRunning() {
|
||||
svc.Dependencies = []string{"DNS"}
|
||||
}
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
@@ -6,10 +6,12 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/openwrt"
|
||||
)
|
||||
|
||||
// newService wraps service.New call to return service.Service
|
||||
@@ -167,7 +169,11 @@ func doTasks(tasks []task) bool {
|
||||
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
|
||||
return false
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
|
||||
// if this is darwin stop command, dont print debug
|
||||
// since launchctl complains on every start
|
||||
if runtime.GOOS != "darwin" || task.Name != "Stop" {
|
||||
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -188,6 +194,13 @@ func checkHasElevatedPrivilege() {
|
||||
func unixSystemVServiceStatus() (service.Status, error) {
|
||||
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
|
||||
if err != nil {
|
||||
// Specific case for openwrt >= 24.10, it returns non-success code
|
||||
// for above status command, which may not right.
|
||||
if router.Name() == openwrt.Name {
|
||||
if string(bytes.ToLower(bytes.TrimSpace(out))) == "inactive" {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -18,3 +18,5 @@ func openLogFile(path string, flags int) (*os.File, error) {
|
||||
func hasLocalDnsServerRunning() bool { return false }
|
||||
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 }
|
||||
|
||||
@@ -2,12 +2,18 @@ package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
@@ -165,3 +171,57 @@ func hasLocalDnsServerRunning() bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) {
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("WMI query failed")
|
||||
return false, 0
|
||||
}
|
||||
if instances == nil {
|
||||
mainLog.Load().Debug().Msg("WMI query returned nil instances")
|
||||
return false, 0
|
||||
}
|
||||
defer instances.Close()
|
||||
|
||||
if len(instances) == 0 {
|
||||
mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
val, err := instances[0].GetProperty("DomainRole")
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property")
|
||||
return false, 0
|
||||
}
|
||||
if val == nil {
|
||||
mainLog.Load().Debug().Msg("DomainRole property is nil")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Safely handle varied types: string or integer
|
||||
var roleInt int
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
// "4", "5", etc.
|
||||
parsed, parseErr := strconv.Atoi(v)
|
||||
if parseErr != nil {
|
||||
mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v)
|
||||
return false, 0
|
||||
}
|
||||
roleInt = parsed
|
||||
case int8, int16, int32, int64:
|
||||
roleInt = int(reflect.ValueOf(v).Int())
|
||||
case uint8, uint16, uint32, uint64:
|
||||
roleInt = int(reflect.ValueOf(v).Uint())
|
||||
default:
|
||||
mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v)
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Check if role indicates a domain controller
|
||||
isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController
|
||||
return isDC, roleInt
|
||||
}
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package main
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cli.Main()
|
||||
// make sure we exit with 0 if there are no errors
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -529,7 +529,7 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs)
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -166,7 +166,6 @@ before serving the query.
|
||||
|
||||
### max_concurrent_requests
|
||||
The number of concurrent requests that will be handled, must be a non-negative integer.
|
||||
Tweaking this value depends on the capacity of your system.
|
||||
|
||||
- Type: number
|
||||
- Required: no
|
||||
@@ -253,9 +252,7 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus
|
||||
- Default: ""
|
||||
|
||||
### dns_watchdog_enabled
|
||||
Checking DNS changes to network interfaces and reverting to ctrld's own settings.
|
||||
|
||||
The DNS watchdog process only runs on Windows and MacOS, while in `--cd` mode.
|
||||
Watches all physical interfaces for DNS changes and reverts them to ctrld's settings.The DNS watchdog process only runs on Windows and MacOS.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
@@ -274,7 +271,7 @@ If the time duration is non-positive, default value will be used.
|
||||
- Default: 20s
|
||||
|
||||
### refetch_time
|
||||
Time in seconds between each iteration that reloads custom config if changed.
|
||||
Time in seconds between each iteration that reloads custom config from the API.
|
||||
|
||||
The value must be a positive number, any invalid value will be ignored and default value will be used.
|
||||
- Type: number
|
||||
@@ -282,7 +279,7 @@ The value must be a positive number, any invalid value will be ignored and defau
|
||||
- Default: 3600
|
||||
|
||||
### leak_on_upstream_failure
|
||||
Once ctrld is "offline", mean ctrld could not connect to any upstream, next queries will be leaked to OS resolver.
|
||||
If a remote upstream fails to resolve a query or is unreachable, `ctrld` will forward the queries to the default DNS resolver on the network. If failures persist, `ctrld` will remove itself from all networking interfaces until connectivity is restored.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
@@ -531,6 +528,15 @@ rules = [
|
||||
]
|
||||
```
|
||||
|
||||
If there is no explicitly defined rules, LAN queries will be handled solely by the OS resolver.
|
||||
|
||||
These following domains are considered LAN queries:
|
||||
|
||||
- Queries does not have dot `.` in domain name, like `machine1`, `example`, ... (1)
|
||||
- Queries have domain ends with: `.domain`, `.lan`, `.local`. (2)
|
||||
- All `SRV` queries of LAN hostname (1) + (2).
|
||||
- `PTR` queries with private IPs.
|
||||
|
||||
---
|
||||
|
||||
Note that the order of matching preference:
|
||||
|
||||
@@ -207,11 +207,10 @@ func (t *Table) init() {
|
||||
}
|
||||
for platform, discover := range discovers {
|
||||
if err := discover.refresh(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", platform)
|
||||
} else {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, discover)
|
||||
t.refreshers = append(t.refreshers, discover)
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform)
|
||||
}
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, discover)
|
||||
t.refreshers = append(t.refreshers, discover)
|
||||
}
|
||||
}
|
||||
// Hosts file mapping.
|
||||
@@ -423,17 +422,27 @@ func (t *Table) ListClients() []*Client {
|
||||
t.Refresh()
|
||||
ipMap := make(map[string]*Client)
|
||||
il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni}
|
||||
|
||||
for _, ir := range il {
|
||||
if ir == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ip := range ir.List() {
|
||||
c, ok := ipMap[ip]
|
||||
if !ok {
|
||||
c = &Client{
|
||||
IP: netip.MustParseAddr(ip),
|
||||
Source: map[string]struct{}{ir.String(): {}},
|
||||
// Validate IP before using MustParseAddr
|
||||
if addr, err := netip.ParseAddr(ip); err == nil {
|
||||
c, ok := ipMap[ip]
|
||||
if !ok {
|
||||
c = &Client{
|
||||
IP: addr,
|
||||
Source: map[string]struct{}{},
|
||||
}
|
||||
ipMap[ip] = c
|
||||
}
|
||||
// Safely get source name
|
||||
if src := ir.String(); src != "" {
|
||||
c.Source[src] = struct{}{}
|
||||
}
|
||||
ipMap[ip] = c
|
||||
} else {
|
||||
c.Source[ir.String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,6 +92,11 @@ func (m *mdns) init(quitCh chan struct{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if IPv6 is available once and use the result for the rest of the function.
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("checking for IPv6 availability in mdns init")
|
||||
ipv6 := ctrldnet.IPv6Available(context.Background())
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("IPv6 is %v in mdns init", ipv6)
|
||||
|
||||
v4ConnList := make([]*net.UDPConn, 0, len(ifaces))
|
||||
v6ConnList := make([]*net.UDPConn, 0, len(ifaces))
|
||||
for _, iface := range ifaces {
|
||||
@@ -102,7 +107,8 @@ func (m *mdns) init(quitCh chan struct{}) error {
|
||||
v4ConnList = append(v4ConnList, conn)
|
||||
go m.readLoop(conn)
|
||||
}
|
||||
if ctrldnet.IPv6Available(context.Background()) {
|
||||
|
||||
if ipv6 {
|
||||
if conn, err := net.ListenMulticastUDP("udp6", &iface, mdnsV6Addr); err == nil {
|
||||
v6ConnList = append(v6ConnList, conn)
|
||||
go m.readLoop(conn)
|
||||
|
||||
@@ -67,4 +67,16 @@ var services = [...]string{
|
||||
|
||||
// Merlin
|
||||
"_alexa._tcp",
|
||||
|
||||
// Newer Android TV devices
|
||||
"_androidtvremote2._tcp.local.",
|
||||
|
||||
// https://esphome.io/
|
||||
"_esphomelib._tcp.local.",
|
||||
|
||||
// https://www.home-assistant.io/
|
||||
"_home-assistant._tcp.local.",
|
||||
|
||||
// https://kno.wled.ge/
|
||||
"_wled._tcp.local.",
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package clientinfo
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
@@ -44,9 +45,9 @@ func (u *ubiosDiscover) refreshDevices() error {
|
||||
cmd := exec.Command("/usr/bin/mongo", "localhost:27117/ace", "--quiet", "--eval", `
|
||||
DBQuery.shellBatchSize = 256;
|
||||
db.user.find({name: {$exists: true, $ne: ""}}, {_id:0, mac:1, name:1});`)
|
||||
b, err := cmd.Output()
|
||||
b, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("out: %s, err: %w", string(b), err)
|
||||
}
|
||||
return u.storeDevices(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
@@ -32,6 +32,8 @@ const (
|
||||
logURLCom = apiURLCom + "/logs"
|
||||
logURLDev = apiURLDev + "/logs"
|
||||
InvalidConfigCode = 40402
|
||||
defaultTimeout = 20 * time.Second
|
||||
sendLogTimeout = 300 * time.Second
|
||||
)
|
||||
|
||||
// ResolverConfig represents Control D resolver data.
|
||||
@@ -135,7 +137,7 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
transport := apiTransport(cdDev)
|
||||
client := http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: defaultTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
@@ -176,7 +178,7 @@ func SendLogs(lr *LogsRequest, cdDev bool) error {
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
transport := apiTransport(cdDev)
|
||||
client := http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
Timeout: sendLogTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
@@ -214,19 +216,55 @@ func apiTransport(cdDev bool) *http.Transport {
|
||||
if cdDev {
|
||||
apiDomain = apiDomainDev
|
||||
}
|
||||
|
||||
// First try IPv4
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
ips := ctrld.LookupIP(apiDomain)
|
||||
if len(ips) == 0 {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr)
|
||||
return ctrldnet.Dialer.DialContext(ctx, network, addr)
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, falling back to direct connection to %s", apiDomain, addr)
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("API IPs: %v", ips)
|
||||
|
||||
// Separate IPv4 and IPv6 addresses
|
||||
var ipv4s, ipv6s []string
|
||||
for _, ip := range ips {
|
||||
if strings.Contains(ip, ":") {
|
||||
ipv6s = append(ipv6s, ip)
|
||||
} else {
|
||||
ipv4s = append(ipv4s, ip)
|
||||
}
|
||||
}
|
||||
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
addrs := make([]string, len(ips))
|
||||
for i := range ips {
|
||||
addrs[i] = net.JoinHostPort(ips[i], port)
|
||||
|
||||
// Try IPv4 first
|
||||
if len(ipv4s) > 0 {
|
||||
addrs := make([]string, len(ipv4s))
|
||||
for i, ip := range ipv4s {
|
||||
addrs[i] = net.JoinHostPort(ip, port)
|
||||
}
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
if conn, err := d.DialContext(ctx, "tcp4", addrs, ctrld.ProxyLogger.Load()); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
return d.DialContext(ctx, network, addrs)
|
||||
|
||||
// Fall back to IPv6 if available
|
||||
if len(ipv6s) > 0 {
|
||||
addrs := make([]string, len(ipv6s))
|
||||
for i, ip := range ipv6s {
|
||||
addrs[i] = net.JoinHostPort(ip, port)
|
||||
}
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
return d.DialContext(ctx, "tcp6", addrs, ctrld.ProxyLogger.Load())
|
||||
}
|
||||
|
||||
// Final fallback to direct connection
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
|
||||
|
||||
@@ -3,6 +3,7 @@ package net
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"tailscale.com/logtail/backoff"
|
||||
)
|
||||
|
||||
@@ -26,7 +28,8 @@ var Dialer = &net.Dialer{
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := ParallelDialer{}
|
||||
d.Timeout = 10 * time.Second
|
||||
return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS})
|
||||
l := zerolog.New(io.Discard)
|
||||
return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS}, &l)
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -49,8 +52,12 @@ func init() {
|
||||
}
|
||||
|
||||
func supportIPv6(ctx context.Context) bool {
|
||||
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "443"))
|
||||
return err == nil
|
||||
c, err := probeStackDialer.DialContext(ctx, "tcp6", v6BootstrapDNS)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
c.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func supportListenIPv6Local() bool {
|
||||
@@ -133,7 +140,7 @@ type ParallelDialer struct {
|
||||
net.Dialer
|
||||
}
|
||||
|
||||
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) {
|
||||
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string, logger *zerolog.Logger) (net.Conn, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
@@ -153,11 +160,16 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs
|
||||
for _, addr := range addrs {
|
||||
go func(addr string) {
|
||||
defer wg.Done()
|
||||
logger.Debug().Msgf("dialing to %s", addr)
|
||||
conn, err := d.Dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
logger.Debug().Msgf("failed to dial %s: %v", addr, err)
|
||||
}
|
||||
select {
|
||||
case ch <- ¶llelDialerResult{conn: conn, err: err}:
|
||||
case <-done:
|
||||
if conn != nil {
|
||||
logger.Debug().Msgf("connection closed: %s", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
@@ -168,6 +180,7 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs
|
||||
for res := range ch {
|
||||
if res.err == nil {
|
||||
cancel()
|
||||
logger.Debug().Msgf("connected to %s", res.conn.RemoteAddr())
|
||||
return res.conn, res.err
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
|
||||
@@ -2,10 +2,13 @@ package openwrt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -15,10 +18,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "openwrt"
|
||||
openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
|
||||
Name = "openwrt"
|
||||
openwrtDNSMasqConfigName = "ctrld.conf"
|
||||
openwrtDNSMasqDefaultConfigDir = "/tmp/dnsmasq.d"
|
||||
)
|
||||
|
||||
var openwrtDnsmasqDefaultConfigPath = filepath.Join(openwrtDNSMasqDefaultConfigDir, openwrtDNSMasqConfigName)
|
||||
|
||||
type Openwrt struct {
|
||||
cfg *ctrld.Config
|
||||
dnsmasqCacheSize string
|
||||
@@ -67,7 +73,7 @@ func (o *Openwrt) Setup() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil {
|
||||
if err := os.WriteFile(dnsmasqConfPathFromUbus(), []byte(data), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
@@ -82,7 +88,7 @@ func (o *Openwrt) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
// Remove the custom dnsmasq config
|
||||
if err := os.Remove(openwrtDNSMasqConfigPath); err != nil {
|
||||
if err := os.Remove(dnsmasqConfPathFromUbus()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -126,3 +132,60 @@ func uci(args ...string) (string, error) {
|
||||
}
|
||||
return strings.TrimSpace(stdout.String()), nil
|
||||
}
|
||||
|
||||
// openwrtServiceList represents openwrt services config.
|
||||
type openwrtServiceList struct {
|
||||
Dnsmasq dnsmasqConf `json:"dnsmasq"`
|
||||
}
|
||||
|
||||
// dnsmasqConf represents dnsmasq config.
|
||||
type dnsmasqConf struct {
|
||||
Instances map[string]confInstances `json:"instances"`
|
||||
}
|
||||
|
||||
// confInstances represents an instance config of a service.
|
||||
type confInstances struct {
|
||||
Mount map[string]string `json:"mount"`
|
||||
}
|
||||
|
||||
// dnsmasqConfPath returns the dnsmasq config path.
|
||||
//
|
||||
// Since version 24.10, openwrt makes some changes to dnsmasq to support
|
||||
// multiple instances of dnsmasq. This change causes breaking changes to
|
||||
// software which depends on the default dnsmasq path.
|
||||
//
|
||||
// There are some discussion/PRs in openwrt repo to address this:
|
||||
//
|
||||
// - https://github.com/openwrt/openwrt/pull/16806
|
||||
// - https://github.com/openwrt/openwrt/pull/16890
|
||||
//
|
||||
// In the meantime, workaround this problem by querying the actual config path
|
||||
// by querying ubus service list.
|
||||
func dnsmasqConfPath(r io.Reader) string {
|
||||
var svc openwrtServiceList
|
||||
if err := json.NewDecoder(r).Decode(&svc); err != nil {
|
||||
return openwrtDnsmasqDefaultConfigPath
|
||||
}
|
||||
for _, inst := range svc.Dnsmasq.Instances {
|
||||
for mount := range inst.Mount {
|
||||
dirName := filepath.Base(mount)
|
||||
parts := strings.Split(dirName, ".")
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
if parts[0] == "dnsmasq" && parts[len(parts)-1] == "d" {
|
||||
return filepath.Join(mount, openwrtDNSMasqConfigName)
|
||||
}
|
||||
}
|
||||
}
|
||||
return openwrtDnsmasqDefaultConfigPath
|
||||
}
|
||||
|
||||
// dnsmasqConfPathFromUbus get dnsmasq config path from ubus service list.
|
||||
func dnsmasqConfPathFromUbus() string {
|
||||
output, err := exec.Command("ubus", "call", "service", "list").Output()
|
||||
if err != nil {
|
||||
return openwrtDnsmasqDefaultConfigPath
|
||||
}
|
||||
return dnsmasqConfPath(bytes.NewReader(output))
|
||||
}
|
||||
|
||||
58
internal/router/openwrt/openwrt_test.go
Normal file
58
internal/router/openwrt/openwrt_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package openwrt
|
||||
|
||||
import (
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Sample output from https://github.com/openwrt/openwrt/pull/16806#issuecomment-2448255734
|
||||
const ubusDnsmasqBefore2410 = `{
|
||||
"dnsmasq": {
|
||||
"instances": {
|
||||
"guest_dns": {
|
||||
"mount": {
|
||||
"/tmp/dnsmasq.d": "0",
|
||||
"/var/run/dnsmasq/": "1"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
const ubusDnsmasq2410 = `{
|
||||
"dnsmasq": {
|
||||
"instances": {
|
||||
"guest_dns": {
|
||||
"mount": {
|
||||
"/tmp/dnsmasq.guest_dns.d": "0",
|
||||
"/var/run/dnsmasq/": "1"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
func Test_dnsmasqConfPath(t *testing.T) {
|
||||
var dnsmasq2410expected = filepath.Join("/tmp/dnsmasq.guest_dns.d", openwrtDNSMasqConfigName)
|
||||
tests := []struct {
|
||||
name string
|
||||
in io.Reader
|
||||
expected string
|
||||
}{
|
||||
{"empty", strings.NewReader(""), openwrtDnsmasqDefaultConfigPath},
|
||||
{"invalid", strings.NewReader("}}"), openwrtDnsmasqDefaultConfigPath},
|
||||
{"before 24.10", strings.NewReader(ubusDnsmasqBefore2410), openwrtDnsmasqDefaultConfigPath},
|
||||
{"24.10", strings.NewReader(ubusDnsmasq2410), dnsmasq2410expected},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := dnsmasqConfPath(tc.in); got != tc.expected {
|
||||
t.Errorf("dnsmasqConfPath() = %v, want %v", got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -215,6 +215,20 @@ func LeaseFilesDir() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// ServiceDependencies returns list of dependencies that ctrld services needs on this router.
|
||||
// See https://pkg.go.dev/github.com/kardianos/service#Config for list format.
|
||||
func ServiceDependencies() []string {
|
||||
if Name() == ubios.Name {
|
||||
// On Ubios, ctrld needs to start after unifi-mongodb,
|
||||
// so it can query custom client info mapping.
|
||||
return []string{
|
||||
"Wants=unifi-mongodb.service",
|
||||
"After=unifi-mongodb.service",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func distroName() string {
|
||||
switch {
|
||||
case bytes.HasPrefix(unameO(), []byte("DD-WRT")):
|
||||
|
||||
@@ -45,11 +45,15 @@ func (s *tomatoSvc) Platform() string {
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) configPath() string {
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
bin := s.Config.Executable
|
||||
if bin == "" {
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
bin = path
|
||||
}
|
||||
return path + ".startup"
|
||||
return bin + ".startup"
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) template() *template.Template {
|
||||
|
||||
@@ -219,6 +219,8 @@ const ubiosBootSystemdService = `[Unit]
|
||||
Description=Run ctrld On Startup UDM
|
||||
Wants=network-online.target
|
||||
After=network-online.target
|
||||
Wants=unifi-mongodb
|
||||
After=unifi-mongodb
|
||||
StartLimitIntervalSec=500
|
||||
StartLimitBurst=5
|
||||
|
||||
|
||||
5
log.go
5
log.go
@@ -9,11 +9,6 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func init() {
|
||||
l := zerolog.New(io.Discard)
|
||||
ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// ProxyLog emits the log record for proxy operations.
|
||||
// The caller should set it only once.
|
||||
// DEPRECATED: use ProxyLogger instead.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
@@ -155,6 +156,8 @@ func getDHCPNameservers(iface string) ([]string, error) {
|
||||
}
|
||||
|
||||
func getAllDHCPNameservers() []string {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -213,5 +216,67 @@ func getAllDHCPNameservers() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// if we have static DNS servers saved for the current default route, we should add them to the list
|
||||
drIfaceName, err := netmon.DefaultRouteInterface()
|
||||
Log(context.Background(), logger.Debug(), "checking for static DNS servers for default route interface: %s", drIfaceName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get default route interface: %v", err)
|
||||
} else {
|
||||
drIface, err := net.InterfaceByName(drIfaceName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get interface by name %s: %v", drIfaceName, err)
|
||||
} else if drIface != nil {
|
||||
if _, err := patchNetIfaceName(drIface); err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to patch interface name %s: %v", drIfaceName, err)
|
||||
}
|
||||
staticNs, file := SavedStaticNameservers(drIface)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"static dns servers from %s: %v", file, staticNs)
|
||||
if len(staticNs) > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Adding static DNS servers from %s: %v", drIface.Name, staticNs)
|
||||
allNameservers = append(allNameservers, staticNs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allNameservers
|
||||
}
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
}
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
scanner := bufio.NewScanner(r)
|
||||
prevLine := ""
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, "*") {
|
||||
// Network services is disabled.
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(line, "Device: "+ifaceName) {
|
||||
prevLine = line
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(prevLine, " ", 2)
|
||||
if len(parts) == 2 {
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -62,10 +62,7 @@ func dnsFromAdapter() []string {
|
||||
var ns []string
|
||||
var err error
|
||||
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
for i := 0; i < maxDNSAdapterRetries; i++ {
|
||||
if ctx.Err() != nil {
|
||||
@@ -111,10 +108,8 @@ func dnsFromAdapter() []string {
|
||||
}
|
||||
|
||||
func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
// Check context before making the call
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
@@ -303,6 +298,28 @@ func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// if we have static DNS servers saved for the current default route, we should add them to the list
|
||||
drIfaceName, err := netmon.DefaultRouteInterface()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get default route interface: %v", err)
|
||||
} else {
|
||||
drIface, err := net.InterfaceByName(drIfaceName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get interface by name %s: %v", drIfaceName, err)
|
||||
} else {
|
||||
staticNs, file := SavedStaticNameservers(drIface)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"static dns servers from %s: %v", file, staticNs)
|
||||
if len(staticNs) > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Adding static DNS servers from %s: %v", drIfaceName, staticNs)
|
||||
ns = append(ns, staticNs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(ns) == 0 {
|
||||
return nil, fmt.Errorf("no valid DNS servers found")
|
||||
}
|
||||
@@ -320,10 +337,8 @@ func nameserversFromResolvconf() []string {
|
||||
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
||||
// Returns whether it's domain joined and the domain name if available
|
||||
func checkDomainJoined() bool {
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
var domain *uint16
|
||||
var status uint32
|
||||
|
||||
@@ -400,10 +415,7 @@ func validInterfaces() map[string]struct{} {
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
//load the logger
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
|
||||
2
net.go
2
net.go
@@ -18,6 +18,7 @@ const ipv6ProbingInterval = 10 * time.Second
|
||||
|
||||
func hasIPv6() bool {
|
||||
hasIPv6Once.Do(func() {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "checking for IPv6 availability once")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
val := ctrldnet.IPv6Available(ctx)
|
||||
@@ -43,6 +44,7 @@ func probingIPv6(ctx context.Context, old bool) {
|
||||
if ipv6Available.CompareAndSwap(old, cur) {
|
||||
old = cur
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "IPv6 availability: %v", cur)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
95
resolver.go
95
resolver.go
@@ -48,7 +48,15 @@ const (
|
||||
|
||||
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
|
||||
var localResolver = newLocalResolver()
|
||||
var localResolver Resolver
|
||||
|
||||
func init() {
|
||||
// Initializing ProxyLogger here, so other places don't have to do nil check.
|
||||
l := zerolog.New(io.Discard)
|
||||
ProxyLogger.Store(&l)
|
||||
|
||||
localResolver = newLocalResolver()
|
||||
}
|
||||
|
||||
var (
|
||||
resolverMutex sync.Mutex
|
||||
@@ -91,10 +99,8 @@ func availableNameservers() []string {
|
||||
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
||||
|
||||
//load the logger
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs)
|
||||
|
||||
@@ -193,9 +199,12 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
case ResolverTypeDOQ:
|
||||
return &doqResolver{uc: uc}, nil
|
||||
case ResolverTypeOS:
|
||||
resolverMutex.Lock()
|
||||
if or == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver")
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
resolverMutex.Unlock()
|
||||
return or, nil
|
||||
case ResolverTypeLegacy:
|
||||
return &legacyResolver{uc: uc}, nil
|
||||
@@ -277,14 +286,29 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
nss = append(nss, (*p)...)
|
||||
}
|
||||
numServers := len(nss) + len(publicServers)
|
||||
|
||||
// If this is a LAN query, skip public DNS.
|
||||
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
|
||||
|
||||
// remove controldPublicDnsWithPort from publicServers for LAN queries
|
||||
// this is to prevent DoS for high frequency local requests
|
||||
if ok && lan {
|
||||
numServers -= len(publicServers)
|
||||
if index := slices.Index(publicServers, controldPublicDnsWithPort); index != -1 {
|
||||
publicServers = slices.Delete(publicServers, index, index+1)
|
||||
numServers--
|
||||
}
|
||||
}
|
||||
question := ""
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
question = msg.Question[0].Name
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers)
|
||||
|
||||
// New check: If no resolvers are available, return an error.
|
||||
if numServers == 0 {
|
||||
return nil, errors.New("no nameservers available")
|
||||
return nil, errors.New("no nameservers available for query")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -320,10 +344,6 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
}(server)
|
||||
}
|
||||
}
|
||||
do(nss, true)
|
||||
if !lan {
|
||||
do(publicServers, false)
|
||||
}
|
||||
|
||||
logAnswer := func(server string) {
|
||||
host, _, err := net.SplitHostPort(server)
|
||||
@@ -333,6 +353,18 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host)
|
||||
}
|
||||
|
||||
// try local nameservers
|
||||
if len(nss) > 0 {
|
||||
do(nss, true)
|
||||
}
|
||||
|
||||
// we must always try the public servers too, since DCHP may have only public servers
|
||||
// this is okay to do since we always prefer LAN nameserver responses
|
||||
if len(publicServers) > 0 {
|
||||
do(publicServers, false)
|
||||
}
|
||||
|
||||
var (
|
||||
nonSuccessAnswer *dns.Msg
|
||||
nonSuccessServer string
|
||||
@@ -353,33 +385,49 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
case res.server == controldPublicDnsWithPort:
|
||||
controldSuccessAnswer = res.answer
|
||||
case !res.lan:
|
||||
// if there are no LAN nameservers, we should not wait
|
||||
// just use the first response
|
||||
if len(nss) == 0 {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server)
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
}
|
||||
publicResponses = append(publicResponses, publicResponse{
|
||||
answer: res.answer,
|
||||
server: res.server,
|
||||
})
|
||||
}
|
||||
case res.answer != nil:
|
||||
nonSuccessAnswer = res.answer
|
||||
nonSuccessServer = res.server
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d",
|
||||
res.server, res.answer.Rcode)
|
||||
// When there are no LAN nameservers, we should not wait
|
||||
// for other nameservers to respond.
|
||||
if len(nss) == 0 {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer")
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
}
|
||||
nonSuccessAnswer = res.answer
|
||||
nonSuccessServer = res.server
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
|
||||
if len(publicResponses) > 0 {
|
||||
resp := publicResponses[0]
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server)
|
||||
logAnswer(resp.server)
|
||||
return resp.answer, nil
|
||||
}
|
||||
if controldSuccessAnswer != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort)
|
||||
logAnswer(controldPublicDnsWithPort)
|
||||
return controldSuccessAnswer, nil
|
||||
}
|
||||
if nonSuccessAnswer != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer)
|
||||
logAnswer(nonSuccessServer)
|
||||
return nonSuccessAnswer, nil
|
||||
}
|
||||
@@ -428,9 +476,13 @@ func LookupIP(domain string) []string {
|
||||
}
|
||||
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
resolverMutex.Lock()
|
||||
if or == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("Initialize OS resolver in lookupIP")
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
resolverMutex.Unlock()
|
||||
|
||||
nss := *or.lanServers.Load()
|
||||
nss = append(nss, *or.publicServers.Load()...)
|
||||
if withBootstrapDNS {
|
||||
@@ -510,6 +562,9 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
|
||||
// - Gateway IP address (depends on OS).
|
||||
// - Input servers.
|
||||
func NewBootstrapResolver(servers ...string) Resolver {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers)
|
||||
nss := defaultNameservers()
|
||||
nss = append([]string{controldPublicDnsWithPort}, nss...)
|
||||
for _, ns := range servers {
|
||||
@@ -526,6 +581,11 @@ func NewBootstrapResolver(servers ...string) Resolver {
|
||||
//
|
||||
// This is useful for doing PTR lookup in LAN network.
|
||||
func NewPrivateResolver() Resolver {
|
||||
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(), "NewPrivateResolver called")
|
||||
|
||||
nss := defaultNameservers()
|
||||
resolveConfNss := nameserversFromResolvconf()
|
||||
localRfc1918Addrs := Rfc1918Addresses()
|
||||
@@ -570,6 +630,9 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
// newResolverWithNameserver returns an OS resolver from given nameservers list.
|
||||
// The caller must ensure each server in list is formed "ip:53".
|
||||
func newResolverWithNameserver(nameservers []string) *osResolver {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(), "newResolverWithNameserver called with nameservers: %v", nameservers)
|
||||
r := &osResolver{}
|
||||
var publicNss []string
|
||||
var lanNss []string
|
||||
|
||||
@@ -70,41 +70,59 @@ func Test_osResolver_ResolveLanHostname(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
ns := make([]string, 0, 2)
|
||||
servers := make([]*dns.Server, 0, 2)
|
||||
handlers := []dns.Handler{
|
||||
// Set up a LAN nameserver that returns a success response.
|
||||
lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") // 127.0.0.1 is considered LAN (loopback)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on LAN address: %v", err)
|
||||
}
|
||||
lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, successHandler())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to run LAN test server: %v", err)
|
||||
}
|
||||
defer lanServer.Shutdown()
|
||||
|
||||
// Set up two public nameservers that return non-success responses.
|
||||
publicHandlers := []dns.Handler{
|
||||
nonSuccessHandlerWithRcode(dns.RcodeRefused),
|
||||
nonSuccessHandlerWithRcode(dns.RcodeNameError),
|
||||
successHandler(),
|
||||
}
|
||||
for i := range handlers {
|
||||
var publicNS []string
|
||||
var publicServers []*dns.Server
|
||||
for _, handler := range publicHandlers {
|
||||
pc, err := net.ListenPacket("udp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
t.Fatalf("failed to listen on public address: %v", err)
|
||||
}
|
||||
|
||||
s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i])
|
||||
s, addr, err := runLocalPacketConnTestServer(t, pc, handler)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
t.Fatalf("failed to run public test server: %v", err)
|
||||
}
|
||||
ns = append(ns, addr)
|
||||
servers = append(servers, s)
|
||||
publicNS = append(publicNS, addr)
|
||||
publicServers = append(publicServers, s)
|
||||
}
|
||||
defer func() {
|
||||
for _, server := range servers {
|
||||
server.Shutdown()
|
||||
for _, s := range publicServers {
|
||||
s.Shutdown()
|
||||
}
|
||||
}()
|
||||
|
||||
// We now create an osResolver which has both a LAN and public nameserver.
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServers.Store(&ns)
|
||||
// Explicitly store the LAN nameserver.
|
||||
resolver.lanServers.Store(&[]string{lanAddr})
|
||||
// And store the public nameservers.
|
||||
resolver.publicServers.Store(&publicNS)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
answer, err := resolver.Resolve(context.Background(), msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Since a LAN nameserver is available and returns a success answer, we expect RcodeSuccess.
|
||||
if answer.Rcode != dns.RcodeSuccess {
|
||||
t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode])
|
||||
t.Errorf("expected a success answer from LAN nameserver (RcodeSuccess) but got: %s", dns.RcodeToString[answer.Rcode])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
79
staticdns.go
Normal file
79
staticdns.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var homedir string
|
||||
|
||||
// absHomeDir returns the absolute path to given filename using home directory as root dir.
|
||||
func absHomeDir(filename string) string {
|
||||
if homedir != "" {
|
||||
return filepath.Join(homedir, filename)
|
||||
}
|
||||
dir, err := userHomeDir()
|
||||
if err != nil {
|
||||
return filename
|
||||
}
|
||||
return filepath.Join(dir, filename)
|
||||
}
|
||||
|
||||
func dirWritable(dir string) (bool, error) {
|
||||
f, err := os.CreateTemp(dir, "")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
return true, f.Close()
|
||||
}
|
||||
|
||||
func userHomeDir() (string, error) {
|
||||
// viper will expand for us.
|
||||
if runtime.GOOS == "windows" {
|
||||
// If we're on windows, use the install path for this.
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Dir(exePath), nil
|
||||
}
|
||||
dir := "/etc/controld"
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return os.UserHomeDir() // fallback to user home directory
|
||||
}
|
||||
if ok, _ := dirWritable(dir); !ok {
|
||||
return os.UserHomeDir()
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
// SavedStaticDnsSettingsFilePath returns the file path where the static DNS settings
|
||||
// for the provided interface are saved.
|
||||
func SavedStaticDnsSettingsFilePath(iface *net.Interface) string {
|
||||
// The file is stored in the user home directory under a hidden file.
|
||||
return absHomeDir(".dns_" + iface.Name)
|
||||
}
|
||||
|
||||
// SavedStaticNameservers returns the stored static nameservers for the given interface.
|
||||
func SavedStaticNameservers(iface *net.Interface) ([]string, string) {
|
||||
file := SavedStaticDnsSettingsFilePath(iface)
|
||||
data, err := os.ReadFile(file)
|
||||
if err != nil || len(data) == 0 {
|
||||
return nil, file
|
||||
}
|
||||
saveValues := strings.Split(string(data), ",")
|
||||
var ns []string
|
||||
for _, v := range saveValues {
|
||||
// Skip any IP that is loopback
|
||||
if ip := net.ParseIP(v); ip != nil && ip.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
ns = append(ns, v)
|
||||
}
|
||||
return ns, file
|
||||
}
|
||||
Reference in New Issue
Block a user