mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Compare commits
53 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
998b9a5c5d | ||
|
|
0084e9ef26 | ||
|
|
122600bff2 | ||
|
|
41846b6d4c | ||
|
|
dfbcb1489d | ||
|
|
684019c2e3 | ||
|
|
e92619620d | ||
|
|
cebfd12d5c | ||
|
|
874ff01ab8 | ||
|
|
0bb8703f78 | ||
|
|
0bb51aa71d | ||
|
|
af2c1c87e0 | ||
|
|
8939debbc0 | ||
|
|
7591a0ccc6 | ||
|
|
c3ff8182af | ||
|
|
5897c174d3 | ||
|
|
f9a3f4c045 | ||
|
|
a2cb895cdc | ||
|
|
2bebe93e47 | ||
|
|
28ec1869fc | ||
|
|
17f6d7a77b | ||
|
|
9e6e647ff8 | ||
|
|
a2116e5eb5 | ||
|
|
564c9ef712 | ||
|
|
856abb71b7 | ||
|
|
0a30fdea69 | ||
|
|
4f125cf107 | ||
|
|
494d8be777 | ||
|
|
cd9c750884 | ||
|
|
91d319804b | ||
|
|
180eae60f2 | ||
|
|
d01f5c2777 | ||
|
|
294a90a807 | ||
|
|
c3b4ae9c79 | ||
|
|
09188bedf7 | ||
|
|
4614b98e94 | ||
|
|
990bc620f7 | ||
|
|
efb5a92571 | ||
|
|
8e0a96a44c | ||
|
|
43ff2f648c | ||
|
|
4816a09e3a | ||
|
|
3fea92c8b1 | ||
|
|
63f959c951 | ||
|
|
44ba6aadd9 | ||
|
|
d88cf52b4e | ||
|
|
58a00ea24a | ||
|
|
712b23a4bb | ||
|
|
baf836557c | ||
|
|
904b23eeac | ||
|
|
6aafe445f5 | ||
|
|
ebd516855b | ||
|
|
df4e04719e | ||
|
|
2440d922c6 |
@@ -5,10 +5,11 @@ type ClientInfoCtxKey struct{}
|
||||
|
||||
// ClientInfo represents ctrld's clients information.
|
||||
type ClientInfo struct {
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Self bool
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Self bool
|
||||
ClientIDPref string
|
||||
}
|
||||
|
||||
// LeaseFileFormat specifies the format of DHCP lease file.
|
||||
|
||||
386
cmd/cli/cli.go
386
cmd/cli/cli.go
@@ -144,6 +144,7 @@ func initCLI() {
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
@@ -158,6 +159,7 @@ func initCLI() {
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
checkStrFlagEmpty(cmd, cdUidFlagName)
|
||||
checkStrFlagEmpty(cmd, cdOrgFlagName)
|
||||
validateCdAndNextDNSFlags()
|
||||
sc := &service.Config{}
|
||||
*sc = *svcConfig
|
||||
osArgs := os.Args[2:]
|
||||
@@ -176,6 +178,9 @@ func initCLI() {
|
||||
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
|
||||
sc.Arguments = append(sc.Arguments, "--cd="+cdUID)
|
||||
}
|
||||
if cdUID != "" {
|
||||
validateCdUpstreamProtocol()
|
||||
}
|
||||
|
||||
p := &prog{
|
||||
router: router.New(&cfg, cdUID != ""),
|
||||
@@ -223,7 +228,7 @@ func initCLI() {
|
||||
}()
|
||||
}
|
||||
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
tryReadingConfigWithNotice(writeDefaultConfig, true)
|
||||
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
@@ -231,6 +236,10 @@ func initCLI() {
|
||||
|
||||
initLogging()
|
||||
|
||||
if nextdns != "" {
|
||||
removeNextDNSFromArgs(sc)
|
||||
}
|
||||
|
||||
// Explicitly passing config, so on system where home directory could not be obtained,
|
||||
// or sub-process env is different with the parent, we still behave correctly and use
|
||||
// the expected config file.
|
||||
@@ -244,16 +253,20 @@ func initCLI() {
|
||||
return
|
||||
}
|
||||
|
||||
if router.Name() != "" {
|
||||
if router.Name() != "" && iface != "" {
|
||||
mainLog.Load().Debug().Msg("cleaning up router before installing")
|
||||
_ = p.router.Cleanup()
|
||||
}
|
||||
|
||||
tasks := []task{
|
||||
{s.Stop, false},
|
||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true},
|
||||
{s.Uninstall, false},
|
||||
{s.Install, false},
|
||||
{s.Start, true},
|
||||
// Note that startCmd do not actually write ControlD config, but the config file was
|
||||
// generated after s.Start, so we notice users here for consistent with nextdns mode.
|
||||
{noticeWritingControlDConfig, false},
|
||||
}
|
||||
mainLog.Load().Notice().Msg("Starting service")
|
||||
if doTasks(tasks) {
|
||||
@@ -281,7 +294,7 @@ func initCLI() {
|
||||
}
|
||||
},
|
||||
}
|
||||
// Keep these flags in sync with runCmd above, except for "-d".
|
||||
// Keep these flags in sync with runCmd above, except for "-d"/"--nextdns".
|
||||
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
@@ -295,6 +308,8 @@ func initCLI() {
|
||||
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = startCmd.Flags().MarkHidden("dev")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
|
||||
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
|
||||
routerCmd := &cobra.Command{
|
||||
Use: "setup",
|
||||
@@ -367,6 +382,10 @@ func initCLI() {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("service not installed")
|
||||
return
|
||||
}
|
||||
initLogging()
|
||||
|
||||
tasks := []task{
|
||||
@@ -388,6 +407,50 @@ func initCLI() {
|
||||
},
|
||||
}
|
||||
|
||||
reloadCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
dir, err := userHomeDir()
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir")
|
||||
}
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
resp, err := cc.post(reloadPath, nil)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
mainLog.Load().Notice().Msg("Service reloaded")
|
||||
case http.StatusCreated:
|
||||
s, err := newService(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.")
|
||||
mainLog.Load().Warn().Msg("Restarting service")
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return
|
||||
}
|
||||
restartCmd.Run(cmd, args)
|
||||
default:
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("could not read response from control server")
|
||||
}
|
||||
mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf))
|
||||
}
|
||||
},
|
||||
}
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
@@ -503,9 +566,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
Short: "Manage ctrld service",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
statusCmd.Use,
|
||||
startCmd.Use,
|
||||
stopCmd.Use,
|
||||
restartCmd.Use,
|
||||
reloadCmd.Use,
|
||||
statusCmd.Use,
|
||||
uninstallCmd.Use,
|
||||
interfacesCmd.Use,
|
||||
@@ -514,6 +578,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
serviceCmd.AddCommand(startCmd)
|
||||
serviceCmd.AddCommand(stopCmd)
|
||||
serviceCmd.AddCommand(restartCmd)
|
||||
serviceCmd.AddCommand(reloadCmd)
|
||||
serviceCmd.AddCommand(statusCmd)
|
||||
serviceCmd.AddCommand(uninstallCmd)
|
||||
serviceCmd.AddCommand(interfacesCmd)
|
||||
@@ -568,6 +633,19 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
}
|
||||
rootCmd.AddCommand(restartCmdAlias)
|
||||
|
||||
reloadCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
reloadCmd.Run(cmd, args)
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(reloadCmdAlias)
|
||||
|
||||
statusCmdAlias := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
@@ -688,6 +766,7 @@ func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struc
|
||||
homedir = appConfig.HomeDir
|
||||
verbose = appConfig.Verbose
|
||||
cdUID = appConfig.CdUID
|
||||
cdUpstreamProto = ctrld.ResolverTypeDOH
|
||||
logPath = appConfig.LogPath
|
||||
run(appCallback, stopCh)
|
||||
}
|
||||
@@ -699,10 +778,12 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
}
|
||||
waitCh := make(chan struct{})
|
||||
p := &prog{
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
cfg: &cfg,
|
||||
appCallback: appCallback,
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
reloadCh: make(chan struct{}),
|
||||
reloadDoneCh: make(chan struct{}),
|
||||
cfg: &cfg,
|
||||
appCallback: appCallback,
|
||||
}
|
||||
if homedir == "" {
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
@@ -740,9 +821,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
|
||||
readBase64Config(configBase64)
|
||||
processNoConfigFlags(noConfigStart)
|
||||
p.mu.Lock()
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
processLogAndCacheFlags()
|
||||
|
||||
@@ -777,14 +860,47 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
cdUID = uid
|
||||
}
|
||||
if cdUID != "" {
|
||||
err := processCDFlags()
|
||||
if err != nil {
|
||||
appCallback.Exit(err.Error())
|
||||
return
|
||||
validateCdUpstreamProtocol()
|
||||
if err := processCDFlags(&cfg); err != nil {
|
||||
if isMobile() {
|
||||
appCallback.Exit(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
uninstallIfInvalidCdUID := func() {
|
||||
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
|
||||
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
s, err := newService(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
cdLogger.Warn().Err(err).Msg("failed to create new service")
|
||||
return
|
||||
}
|
||||
if netIface, _ := netInterface(iface); netIface != nil {
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
cdLogger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
}
|
||||
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS")
|
||||
} else {
|
||||
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
|
||||
}
|
||||
}
|
||||
|
||||
tasks := []task{{s.Uninstall, true}}
|
||||
if doTasks(tasks) {
|
||||
cdLogger.Info().Msg("uninstalled service")
|
||||
}
|
||||
cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config")
|
||||
return
|
||||
}
|
||||
}
|
||||
uninstallIfInvalidCdUID()
|
||||
}
|
||||
}
|
||||
|
||||
updated := updateListenerConfig()
|
||||
updated := updateListenerConfig(&cfg)
|
||||
|
||||
if cdUID != "" {
|
||||
processLogAndCacheFlags()
|
||||
@@ -812,7 +928,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
initLoggingWithBackup(false)
|
||||
}
|
||||
|
||||
validateConfig(&cfg)
|
||||
if err := validateConfig(&cfg); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
initCache()
|
||||
|
||||
if daemon {
|
||||
@@ -859,19 +977,21 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
if cp := router.CertPool(); cp != nil {
|
||||
rootCertPool = cp
|
||||
}
|
||||
p.onStarted = append(p.onStarted, func() {
|
||||
mainLog.Load().Debug().Msg("router setup on start")
|
||||
if err := p.router.Setup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not configure router")
|
||||
}
|
||||
})
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
mainLog.Load().Debug().Msg("router cleanup on stop")
|
||||
if err := p.router.Cleanup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
p.resetDNS()
|
||||
})
|
||||
if iface != "" {
|
||||
p.onStarted = append(p.onStarted, func() {
|
||||
mainLog.Load().Debug().Msg("router setup on start")
|
||||
if err := p.router.Setup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not configure router")
|
||||
}
|
||||
})
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
mainLog.Load().Debug().Msg("router cleanup on stop")
|
||||
if err := p.router.Cleanup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
p.resetDNS()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
close(waitCh)
|
||||
@@ -907,10 +1027,17 @@ func writeConfigFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func readConfigFile(writeDefaultConfig bool) bool {
|
||||
// readConfigFile reads in config file.
|
||||
//
|
||||
// - It writes default config file if config file not found if writeDefaultConfig is true.
|
||||
// - It emits notice message to user if notice is true.
|
||||
func readConfigFile(writeDefaultConfig, notice bool) bool {
|
||||
// If err == nil, there's a config supplied via `--config`, no default config written.
|
||||
err := v.ReadInConfig()
|
||||
if err == nil {
|
||||
if notice {
|
||||
mainLog.Load().Notice().Msg("Reading config: " + v.ConfigFileUsed())
|
||||
}
|
||||
mainLog.Load().Info().Msg("loading config file from: " + v.ConfigFileUsed())
|
||||
defaultConfigFile = v.ConfigFileUsed()
|
||||
return true
|
||||
@@ -925,7 +1052,8 @@ func readConfigFile(writeDefaultConfig bool) bool {
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
_ = updateListenerConfig()
|
||||
nop := zerolog.Nop()
|
||||
_, _ = tryUpdateListenerConfig(&cfg, &nop, true)
|
||||
if err := writeConfigFile(); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
|
||||
} else {
|
||||
@@ -933,6 +1061,9 @@ func readConfigFile(writeDefaultConfig bool) bool {
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to get default config file path: %v", err)
|
||||
}
|
||||
if cdUID == "" && nextdns == "" {
|
||||
mainLog.Load().Notice().Msg("Generating controld default config: " + fp)
|
||||
}
|
||||
mainLog.Load().Info().Msg("writing default config file to: " + fp)
|
||||
}
|
||||
return false
|
||||
@@ -1015,7 +1146,7 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
v.Set("upstream", upstream)
|
||||
}
|
||||
|
||||
func processCDFlags() error {
|
||||
func processCDFlags(cfg *ctrld.Config) error {
|
||||
logger := mainLog.Load().With().Str("mode", "cd").Logger()
|
||||
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
|
||||
bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second)
|
||||
@@ -1031,44 +1162,17 @@ func processCDFlags() error {
|
||||
}
|
||||
break
|
||||
}
|
||||
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
s, err := newService(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed to create new service")
|
||||
return nil
|
||||
}
|
||||
if netIface, _ := netInterface(iface); netIface != nil {
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return nil
|
||||
}
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
logger.Warn().Err(err).Msg("something went wrong while restoring DNS")
|
||||
} else {
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
|
||||
}
|
||||
}
|
||||
|
||||
tasks := []task{{s.Uninstall, true}}
|
||||
if doTasks(tasks) {
|
||||
logger.Info().Msg("uninstalled service")
|
||||
}
|
||||
event := logger.Fatal()
|
||||
if isMobile() {
|
||||
event = logger.Warn()
|
||||
}
|
||||
event.Err(uer).Msg("failed to fetch resolver config")
|
||||
return uer
|
||||
}
|
||||
if err != nil {
|
||||
if isMobile() {
|
||||
return err
|
||||
}
|
||||
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info().Msg("generating ctrld config from Control-D configuration")
|
||||
cfg = ctrld.Config{}
|
||||
|
||||
*cfg = ctrld.Config{}
|
||||
// Fetch config, unmarshal to cfg.
|
||||
if resolverConfig.Ctrld.CustomConfig != "" {
|
||||
logger.Info().Msg("using defined custom config of Control-D resolver")
|
||||
@@ -1085,7 +1189,7 @@ func processCDFlags() error {
|
||||
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
|
||||
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: ctrld.ResolverTypeDOH,
|
||||
Type: cdUpstreamProto,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
@@ -1213,6 +1317,11 @@ func selfCheckStatus(s service.Service) service.Status {
|
||||
return service.StatusUnknown
|
||||
}
|
||||
|
||||
// Not a ctrld upstream, return status as-is.
|
||||
if cfg.FirstUpstream().VerifyDomain() == "" {
|
||||
return status
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msg("ctrld listener is ready")
|
||||
mainLog.Load().Debug().Msg("performing self-check")
|
||||
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
|
||||
@@ -1338,21 +1447,35 @@ func userHomeDir() (string, error) {
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
// tryReadingConfig is like tryReadingConfigWithNotice, with notice set to false.
|
||||
func tryReadingConfig(writeDefaultConfig bool) {
|
||||
tryReadingConfigWithNotice(writeDefaultConfig, false)
|
||||
}
|
||||
|
||||
// tryReadingConfigWithNotice tries reading in config files, either specified by user or from default
|
||||
// locations. If notice is true, emitting a notice message to user which config file was read.
|
||||
func tryReadingConfigWithNotice(writeDefaultConfig, notice bool) {
|
||||
// --config is specified.
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
readConfigFile(false)
|
||||
readConfigFile(false, notice)
|
||||
return
|
||||
}
|
||||
// no config start or base64 config mode.
|
||||
if !writeDefaultConfig {
|
||||
return
|
||||
}
|
||||
readConfig(writeDefaultConfig)
|
||||
readConfigWithNotice(writeDefaultConfig, notice)
|
||||
}
|
||||
|
||||
// readConfig calls readConfigWithNotice with notice set to false.
|
||||
func readConfig(writeDefaultConfig bool) {
|
||||
readConfigWithNotice(writeDefaultConfig, false)
|
||||
}
|
||||
|
||||
// readConfigWithNotice calls readConfigFile with config file set to ctrld.toml
|
||||
// or config.toml for compatible with earlier versions of ctrld.
|
||||
func readConfigWithNotice(writeDefaultConfig, notice bool) {
|
||||
configs := []struct {
|
||||
name string
|
||||
written bool
|
||||
@@ -1369,7 +1492,7 @@ func readConfig(writeDefaultConfig bool) {
|
||||
for _, config := range configs {
|
||||
ctrld.SetConfigNameWithPath(v, config.name, dir)
|
||||
v.SetConfigFile(configPath)
|
||||
if readConfigFile(config.written) {
|
||||
if readConfigFile(config.written, notice) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -1401,18 +1524,17 @@ func uninstall(p *prog, s service.Service) {
|
||||
}
|
||||
}
|
||||
|
||||
func validateConfig(cfg *ctrld.Config) {
|
||||
err := ctrld.ValidateConfig(validator.New(), cfg)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
var ve validator.ValidationErrors
|
||||
if errors.As(err, &ve) {
|
||||
for _, fe := range ve {
|
||||
mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe))
|
||||
func validateConfig(cfg *ctrld.Config) error {
|
||||
if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil {
|
||||
var ve validator.ValidationErrors
|
||||
if errors.As(err, &ve) {
|
||||
for _, fe := range ve {
|
||||
mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe))
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NOTE: Add more case here once new validation tag is used in ctrld.Config struct.
|
||||
@@ -1483,9 +1605,19 @@ func mobileListenerPort() int {
|
||||
// updateListenerConfig updates the config for listeners if not defined,
|
||||
// or defined but invalid to be used, e.g: using loopback address other
|
||||
// than 127.0.0.1 with systemd-resolved.
|
||||
func updateListenerConfig() (updated bool) {
|
||||
func updateListenerConfig(cfg *ctrld.Config) bool {
|
||||
updated, _ := tryUpdateListenerConfig(cfg, nil, true)
|
||||
return updated
|
||||
}
|
||||
|
||||
// tryUpdateListenerConfig tries updating listener config with a working one.
|
||||
// If fatal is true, and there's listen address conflicted, the function do
|
||||
// fatal error.
|
||||
func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fatal bool) (updated, ok bool) {
|
||||
ok = true
|
||||
lcc := make(map[string]*listenerConfigCheck)
|
||||
cdMode := cdUID != ""
|
||||
nextdnsMode := nextdns != ""
|
||||
for n, listener := range cfg.Listener {
|
||||
lcc[n] = &listenerConfigCheck{}
|
||||
if listener.IP == "" {
|
||||
@@ -1497,12 +1629,17 @@ func updateListenerConfig() (updated bool) {
|
||||
lcc[n].Port = true
|
||||
}
|
||||
// In cd mode, we always try to pick an ip:port pair to work.
|
||||
if cdMode {
|
||||
// Same if nextdns resolver is used.
|
||||
if cdMode || nextdnsMode {
|
||||
lcc[n].IP = true
|
||||
lcc[n].Port = true
|
||||
}
|
||||
updated = updated || lcc[n].IP || lcc[n].Port
|
||||
}
|
||||
il := mainLog.Load()
|
||||
if infoLogger != nil {
|
||||
il = infoLogger
|
||||
}
|
||||
if isMobile() {
|
||||
// On Mobile, only use first listener, ignore others.
|
||||
firstLn := cfg.FirstListener()
|
||||
@@ -1594,7 +1731,11 @@ func updateListenerConfig() (updated bool) {
|
||||
break
|
||||
}
|
||||
if !check.IP && !check.Port {
|
||||
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
|
||||
if fatal {
|
||||
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
if tryAllPort53 {
|
||||
tryAllPort53 = false
|
||||
@@ -1605,7 +1746,7 @@ func updateListenerConfig() (updated bool) {
|
||||
listener.Port = 53
|
||||
}
|
||||
if check.IP {
|
||||
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)))
|
||||
logMsg(il.Info(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)))
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -1618,7 +1759,7 @@ func updateListenerConfig() (updated bool) {
|
||||
listener.Port = 53
|
||||
}
|
||||
if check.IP {
|
||||
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)))
|
||||
logMsg(il.Info(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)))
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -1630,7 +1771,7 @@ func updateListenerConfig() (updated bool) {
|
||||
if check.Port {
|
||||
listener.Port = 5354
|
||||
}
|
||||
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying current ip with port 5354", addr)
|
||||
logMsg(il.Info(), n, "could not listen on address: %s, trying current ip with port 5354", addr)
|
||||
continue
|
||||
}
|
||||
if tryPort5354 {
|
||||
@@ -1641,7 +1782,7 @@ func updateListenerConfig() (updated bool) {
|
||||
if check.Port {
|
||||
listener.Port = 5354
|
||||
}
|
||||
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr)
|
||||
logMsg(il.Info(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr)
|
||||
continue
|
||||
}
|
||||
if check.IP && !isZeroIP { // for "0.0.0.0" or "::", we only need to try new port.
|
||||
@@ -1655,12 +1796,19 @@ func updateListenerConfig() (updated bool) {
|
||||
listener.Port = oldPort
|
||||
}
|
||||
if listener.IP == oldIP && listener.Port == oldPort {
|
||||
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
|
||||
if fatal {
|
||||
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, pick a random ip+port", addr)
|
||||
logMsg(il.Info(), n, "could not listen on address: %s, pick a random ip+port", addr)
|
||||
attempts++
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Specific case for systemd-resolved.
|
||||
if useSystemdResolved {
|
||||
@@ -1670,7 +1818,7 @@ func updateListenerConfig() (updated bool) {
|
||||
// ip address, other than "127.0.0.1", so trying to listen on default route interface
|
||||
// address instead.
|
||||
if ip := net.ParseIP(listener.IP); ip != nil && ip.IsLoopback() && ip.String() != "127.0.0.1" {
|
||||
logMsg(mainLog.Load().Warn(), n, "using loopback interface do not work with systemd-resolved")
|
||||
logMsg(il.Info(), n, "using loopback interface do not work with systemd-resolved")
|
||||
found := false
|
||||
if netIface, _ := net.InterfaceByName(defaultIfaceName()); netIface != nil {
|
||||
addrs, _ := netIface.Addrs()
|
||||
@@ -1680,7 +1828,7 @@ func updateListenerConfig() (updated bool) {
|
||||
if err := tryListen(addr); err == nil {
|
||||
found = true
|
||||
listener.IP = netIP.IP.String()
|
||||
logMsg(mainLog.Load().Warn(), n, "use %s as listener address", listener.IP)
|
||||
logMsg(il.Info(), n, "use %s as listener address", listener.IP)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -1742,12 +1890,12 @@ func removeProvTokenFromArgs(sc *service.Config) {
|
||||
continue
|
||||
}
|
||||
// For "--cd-org XXX", skip it and mark next arg skipped.
|
||||
if x == cdOrgFlagName {
|
||||
if x == "--"+cdOrgFlagName {
|
||||
skip = true
|
||||
continue
|
||||
}
|
||||
// For "--cd-org=XXX", just skip it.
|
||||
if strings.HasPrefix(x, cdOrgFlagName+"=") {
|
||||
if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") {
|
||||
continue
|
||||
}
|
||||
a = append(a, x)
|
||||
@@ -1795,6 +1943,64 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) {
|
||||
return
|
||||
}
|
||||
if fl.Value.String() == "" {
|
||||
mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name)
|
||||
mainLog.Load().Fatal().Msgf(`flag "--%s" value must be non-empty`, fl.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func validateCdUpstreamProtocol() {
|
||||
if cdUID == "" {
|
||||
return
|
||||
}
|
||||
switch cdUpstreamProto {
|
||||
case ctrld.ResolverTypeDOH, ctrld.ResolverTypeDOH3:
|
||||
default:
|
||||
mainLog.Load().Fatal().Msg(`flag "--protocol" must be "doh" or "doh3"`)
|
||||
}
|
||||
}
|
||||
|
||||
func validateCdAndNextDNSFlags() {
|
||||
if (cdUID != "" || cdOrg != "") && nextdns != "" {
|
||||
mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName)
|
||||
}
|
||||
}
|
||||
|
||||
// removeNextDNSFromArgs removes the --nextdns from command line arguments.
|
||||
func removeNextDNSFromArgs(sc *service.Config) {
|
||||
a := sc.Arguments[:0]
|
||||
skip := false
|
||||
for _, x := range sc.Arguments {
|
||||
if skip {
|
||||
skip = false
|
||||
continue
|
||||
}
|
||||
// For "--nextdns XXX", skip it and mark next arg skipped.
|
||||
if x == "--"+nextdnsFlagName {
|
||||
skip = true
|
||||
continue
|
||||
}
|
||||
// For "--nextdns=XXX", just skip it.
|
||||
if strings.HasPrefix(x, "--"+nextdnsFlagName+"=") {
|
||||
continue
|
||||
}
|
||||
a = append(a, x)
|
||||
}
|
||||
sc.Arguments = a
|
||||
}
|
||||
|
||||
// doGenerateNextDNSConfig generates a working config with nextdns resolver.
|
||||
func doGenerateNextDNSConfig(uid string) error {
|
||||
if uid == "" {
|
||||
return nil
|
||||
}
|
||||
mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile)
|
||||
generateNextDNSConfig(uid)
|
||||
updateListenerConfig(&cfg)
|
||||
return writeConfigFile()
|
||||
}
|
||||
|
||||
func noticeWritingControlDConfig() error {
|
||||
if cdUID != "" {
|
||||
mainLog.Load().Notice().Msgf("Generating controld config: %s", defaultConfigFile)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,14 +6,18 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
contentTypeJson = "application/json"
|
||||
listClientsPath = "/clients"
|
||||
startedPath = "/started"
|
||||
reloadPath = "/reload"
|
||||
)
|
||||
|
||||
type controlServer struct {
|
||||
@@ -75,6 +79,52 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusRequestTimeout)
|
||||
}
|
||||
}))
|
||||
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
listeners := make(map[string]*ctrld.ListenerConfig)
|
||||
p.mu.Lock()
|
||||
for k, v := range p.cfg.Listener {
|
||||
listeners[k] = &ctrld.ListenerConfig{
|
||||
IP: v.IP,
|
||||
Port: v.Port,
|
||||
}
|
||||
}
|
||||
oldSvc := p.cfg.Service
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.reloadDoneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Checking for cases that we could not do a reload.
|
||||
|
||||
// 1. Listener config ip or port changes.
|
||||
for k, v := range p.cfg.Listener {
|
||||
l := listeners[k]
|
||||
if l == nil || l.IP != v.IP || l.Port != v.Port {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Service config changes.
|
||||
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, reload is done.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
}
|
||||
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/netaddr"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
@@ -25,6 +27,7 @@ import (
|
||||
|
||||
const (
|
||||
staleTTL = 60 * time.Second
|
||||
localTTL = 3600 * time.Second
|
||||
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
|
||||
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
||||
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
||||
@@ -37,6 +40,29 @@ var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
var privateUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "Private resolver",
|
||||
Type: ctrld.ResolverTypePrivate,
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
// proxyRequest contains data for proxying a DNS query to upstream.
|
||||
type proxyRequest struct {
|
||||
msg *dns.Msg
|
||||
ci *ctrld.ClientInfo
|
||||
failoverRcodes []int
|
||||
ufr *upstreamForResult
|
||||
}
|
||||
|
||||
// upstreamForResult represents the result of processing rules for a request.
|
||||
type upstreamForResult struct {
|
||||
upstreams []string
|
||||
matchedPolicy string
|
||||
matchedNetwork string
|
||||
matchedRule string
|
||||
matched bool
|
||||
}
|
||||
|
||||
func (p *prog) serveDNS(listenerNum string) error {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
// make sure ip is allocated
|
||||
@@ -44,36 +70,58 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
|
||||
return allocErr
|
||||
}
|
||||
var failoverRcodes []int
|
||||
if listenerConfig.Policy != nil {
|
||||
failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
p.sema.acquire()
|
||||
defer p.sema.release()
|
||||
if len(m.Question) == 0 {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeFormatError)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
reqId := requestID()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
go p.detectLoop(m)
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
reqId := requestID()
|
||||
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
||||
ci := p.getClientInfo(remoteIP, m)
|
||||
ci.ClientIDPref = p.cfg.Service.ClientIDPref
|
||||
stripClientSubnet(m)
|
||||
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci)
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String())
|
||||
t := time.Now()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
||||
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, domain)
|
||||
res := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
|
||||
var answer *dns.Msg
|
||||
if !matched && listenerConfig.Restricted {
|
||||
if !res.matched && listenerConfig.Restricted {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String())
|
||||
answer = new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
} else {
|
||||
answer = p.proxy(ctx, upstreams, failoverRcodes, m, ci)
|
||||
var failoverRcode []int
|
||||
if listenerConfig.Policy != nil {
|
||||
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
answer = p.proxy(ctx, &proxyRequest{
|
||||
msg: m,
|
||||
ci: ci,
|
||||
failoverRcodes: failoverRcode,
|
||||
ufr: res,
|
||||
})
|
||||
rtt := time.Since(t)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
}
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveUDP: failed to send DNS response to client")
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -99,7 +147,7 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
// addresses of the machine. So ctrld could receive queries from LAN clients.
|
||||
if needRFC1918Listeners(listenerConfig) {
|
||||
g.Go(func() error {
|
||||
for _, addr := range rfc1918Addresses() {
|
||||
for _, addr := range ctrld.Rfc1918Addresses() {
|
||||
func() {
|
||||
listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port))
|
||||
s, errCh := runDNSServer(listenAddr, proto, handler)
|
||||
@@ -146,27 +194,24 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
// Though domain policy has higher priority than network policy, it is still
|
||||
// processed later, because policy logging want to know whether a network rule
|
||||
// is disregarded in favor of the domain level rule.
|
||||
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
|
||||
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) {
|
||||
upstreams := []string{upstreamPrefix + defaultUpstreamNum}
|
||||
matchedPolicy := "no policy"
|
||||
matchedNetwork := "no network"
|
||||
matchedRule := "no rule"
|
||||
matched := false
|
||||
res = &upstreamForResult{}
|
||||
|
||||
defer func() {
|
||||
if !matched && lc.Restricted {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", addr.String())
|
||||
return
|
||||
}
|
||||
if matched {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams)
|
||||
} else {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
}
|
||||
res.upstreams = upstreams
|
||||
res.matched = matched
|
||||
res.matchedPolicy = matchedPolicy
|
||||
res.matchedNetwork = matchedNetwork
|
||||
res.matchedRule = matchedRule
|
||||
}()
|
||||
|
||||
if lc.Policy == nil {
|
||||
return upstreams, false
|
||||
return
|
||||
}
|
||||
|
||||
do := func(policyUpstreams []string) {
|
||||
@@ -202,6 +247,19 @@ networkRules:
|
||||
}
|
||||
}
|
||||
|
||||
macRules:
|
||||
for _, rule := range lc.Policy.Macs {
|
||||
for source, targets := range rule {
|
||||
if source != "" && strings.EqualFold(source, srcMac) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
matchedNetwork = source
|
||||
networkTargets = targets
|
||||
matched = true
|
||||
break macRules
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
// There's only one entry per rule, config validation ensures this.
|
||||
for source, targets := range rule {
|
||||
@@ -213,7 +271,7 @@ networkRules:
|
||||
matchedRule = source
|
||||
do(targets)
|
||||
matched = true
|
||||
return upstreams, matched
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -222,26 +280,134 @@ networkRules:
|
||||
do(networkTargets)
|
||||
}
|
||||
|
||||
return upstreams, matched
|
||||
return
|
||||
}
|
||||
|
||||
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo) *dns.Msg {
|
||||
func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg {
|
||||
cDomainName := msg.Question[0].Name
|
||||
locked := p.ptrLoopGuard.TryLock(cDomainName)
|
||||
defer p.ptrLoopGuard.Unlock(cDomainName)
|
||||
if !locked {
|
||||
return nil
|
||||
}
|
||||
ip := ipFromARPA(cDomainName)
|
||||
if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetReply(msg)
|
||||
answer.Compress = true
|
||||
answer.Answer = []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Ptr: dns.Fqdn(name),
|
||||
}}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
Mac: p.ciTable.LookupMac(ip.String()),
|
||||
IP: ip.String(),
|
||||
Hostname: name,
|
||||
})
|
||||
return answer
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg {
|
||||
q := msg.Question[0]
|
||||
hostname := strings.TrimSuffix(q.Name, ".")
|
||||
locked := p.lanLoopGuard.TryLock(hostname)
|
||||
defer p.lanLoopGuard.Unlock(hostname)
|
||||
if !locked {
|
||||
return nil
|
||||
}
|
||||
if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetReply(msg)
|
||||
answer.Compress = true
|
||||
switch {
|
||||
case ip.Is4():
|
||||
answer.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(localTTL.Seconds()),
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}}
|
||||
case ip.Is6():
|
||||
answer.Answer = []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(localTTL.Seconds()),
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}}
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
Mac: p.ciTable.LookupMac(ip.String()),
|
||||
IP: ip.String(),
|
||||
Hostname: hostname,
|
||||
})
|
||||
return answer
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *dns.Msg {
|
||||
var staleAnswer *dns.Msg
|
||||
upstreams := req.ufr.upstreams
|
||||
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{upstreamOS}
|
||||
}
|
||||
|
||||
// LAN/PTR lookup flow:
|
||||
//
|
||||
// 1. If there's matching rule, follow it.
|
||||
// 2. Try from client info table.
|
||||
// 3. Try private resolver.
|
||||
// 4. Try remote upstream.
|
||||
isLanOrPtrQuery := false
|
||||
if req.ufr.matched {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
} else {
|
||||
switch {
|
||||
case isPrivatePtrLookup(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
|
||||
return answer
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams)
|
||||
case isLanHostnameQuery(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil {
|
||||
return answer
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams)
|
||||
default:
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
}
|
||||
}
|
||||
|
||||
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
|
||||
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
|
||||
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
|
||||
for _, upstream := range upstreams {
|
||||
cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream))
|
||||
cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
|
||||
if cachedValue == nil {
|
||||
continue
|
||||
}
|
||||
answer := cachedValue.Msg.Copy()
|
||||
answer.SetRcode(msg, answer.Rcode)
|
||||
answer.SetRcode(req.msg, answer.Rcode)
|
||||
now := time.Now()
|
||||
if cachedValue.Expire.After(now) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
|
||||
@@ -268,9 +434,9 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
return dnsResolver.Resolve(resolveCtx, msg)
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
if upstreamConfig.UpstreamSendClientInfo() && ci != nil {
|
||||
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci)
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
|
||||
}
|
||||
answer, err := resolve1(n, upstreamConfig, msg)
|
||||
if err != nil {
|
||||
@@ -281,6 +447,11 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
go p.um.checkUpstream(upstreams[n], upstreamConfig)
|
||||
}
|
||||
}
|
||||
// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
|
||||
var e net.Error
|
||||
if errors.As(err, &e) && e.Timeout() {
|
||||
upstreamConfig.ReBootstrap()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return answer
|
||||
@@ -297,7 +468,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, msg)
|
||||
answer := resolve(n, upstreamConfig, req.msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
|
||||
@@ -307,7 +478,13 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
}
|
||||
continue
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
|
||||
// We are doing LAN/PTR lookup using private resolver, so always process next one.
|
||||
// Except for the last, we want to send response instead of saying all upstream failed.
|
||||
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n])
|
||||
continue
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream")
|
||||
continue
|
||||
}
|
||||
@@ -315,7 +492,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
// set compression, as it is not set by default when unpacking
|
||||
answer.Compress = true
|
||||
|
||||
if p.cache != nil {
|
||||
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
|
||||
ttl := ttlFromMsg(answer)
|
||||
now := time.Now()
|
||||
expired := now.Add(time.Duration(ttl) * time.Second)
|
||||
@@ -323,17 +500,27 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
expired = now.Add(time.Duration(cachedTTL) * time.Second)
|
||||
}
|
||||
setCachedAnswerTTL(answer, now, expired)
|
||||
p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired))
|
||||
p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired))
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response")
|
||||
}
|
||||
return answer
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(msg, dns.RcodeServerFailure)
|
||||
answer.SetRcode(req.msg, dns.RcodeServerFailure)
|
||||
return answer
|
||||
}
|
||||
|
||||
func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
|
||||
if len(p.localUpstreams) > 0 {
|
||||
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
|
||||
tmp = append(tmp, p.localUpstreams...)
|
||||
tmp = append(tmp, upstreams...)
|
||||
return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp)
|
||||
}
|
||||
return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
|
||||
}
|
||||
|
||||
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
|
||||
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
|
||||
for _, upstream := range upstreams {
|
||||
@@ -454,6 +641,23 @@ func ipAndMacFromMsg(msg *dns.Msg) (string, string) {
|
||||
return ip, mac
|
||||
}
|
||||
|
||||
// stripClientSubnet removes EDNS0_SUBNET from DNS message if the IP is RFC1918 or loopback address,
|
||||
// passing them to upstream is pointless, these cannot be used by anything on the WAN.
|
||||
func stripClientSubnet(msg *dns.Msg) {
|
||||
if opt := msg.IsEdns0(); opt != nil {
|
||||
opts := make([]dns.EDNS0, 0, len(opt.Option))
|
||||
for _, s := range opt.Option {
|
||||
if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) {
|
||||
continue
|
||||
}
|
||||
opts = append(opts, s)
|
||||
}
|
||||
if len(opts) != len(opt.Option) {
|
||||
opt.Option = opts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
|
||||
if ci != nil && ci.IP != "" {
|
||||
switch addr := addr.(type) {
|
||||
@@ -531,23 +735,44 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
|
||||
}
|
||||
|
||||
// If MAC is still empty here, that mean the requests are made from virtual interface,
|
||||
// like VPN/Wireguard clients, so we use whatever MAC address associated with remoteIP
|
||||
// (most likely 127.0.0.1), and ci.IP as hostname, so we can distinguish those clients.
|
||||
// like VPN/Wireguard clients, so we use ci.IP as hostname to distinguish those clients.
|
||||
if ci.Mac == "" {
|
||||
ci.Mac = p.ciTable.LookupMac(remoteIP)
|
||||
if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" {
|
||||
ci.Hostname = hostname
|
||||
} else {
|
||||
ci.Hostname = ci.IP
|
||||
p.ciTable.StoreVPNClient(ci)
|
||||
// Only use IP as hostname for IPv4 clients.
|
||||
// For Android devices, when it joins the network, it uses ctrld to resolve
|
||||
// its private DNS once and never reaches ctrld again. For each time, it uses
|
||||
// a different IPv6 address, which causes hundreds/thousands different client
|
||||
// IDs created for the same device, which is pointless.
|
||||
//
|
||||
// TODO(cuonglm): investigate whether this can be a false positive for other clients?
|
||||
if !ctrldnet.IsIPv6(ci.IP) {
|
||||
ci.Hostname = ci.IP
|
||||
p.ciTable.StoreVPNClient(ci)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
|
||||
}
|
||||
ci.Self = queryFromSelf(ci.IP)
|
||||
p.spoofLoopbackIpInClientInfo(ci)
|
||||
return ci
|
||||
}
|
||||
|
||||
// spoofLoopbackIpInClientInfo replaces loopback IPs in client info.
|
||||
//
|
||||
// - Preference IPv4.
|
||||
// - Preference RFC1918.
|
||||
func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
|
||||
if ip := net.ParseIP(ci.IP); ip == nil || !ip.IsLoopback() {
|
||||
return
|
||||
}
|
||||
if ip := p.ciTable.LookupRFC1918IPv4(ci.Mac); ip != "" {
|
||||
ci.IP = ip
|
||||
}
|
||||
}
|
||||
|
||||
// queryFromSelf reports whether the input IP is from device running ctrld.
|
||||
func queryFromSelf(ip string) bool {
|
||||
netIP := netip.MustParseAddr(ip)
|
||||
@@ -578,17 +803,86 @@ func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
|
||||
return lc.IP == "127.0.0.1" && lc.Port == 53
|
||||
}
|
||||
|
||||
func rfc1918Addresses() []string {
|
||||
var res []string
|
||||
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
addrs, _ := i.Addrs()
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok || !ipNet.IP.IsPrivate() {
|
||||
continue
|
||||
}
|
||||
res = append(res, ipNet.IP.String())
|
||||
// ipFromARPA parses a FQDN arpa domain and return the IP address if valid.
|
||||
func ipFromARPA(arpa string) net.IP {
|
||||
if arpa, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok {
|
||||
if ptrIP := net.ParseIP(arpa); ptrIP != nil {
|
||||
return net.IP{ptrIP[15], ptrIP[14], ptrIP[13], ptrIP[12]}
|
||||
}
|
||||
})
|
||||
return res
|
||||
}
|
||||
if arpa, ok := strings.CutSuffix(arpa, ".ip6.arpa."); ok {
|
||||
l := net.IPv6len * 2
|
||||
base := 16
|
||||
ip := make(net.IP, net.IPv6len)
|
||||
for i := 0; i < l && arpa != ""; i++ {
|
||||
idx := strings.LastIndexByte(arpa, '.')
|
||||
off := idx + 1
|
||||
if idx == -1 {
|
||||
idx = 0
|
||||
off = 0
|
||||
} else if idx == len(arpa)-1 {
|
||||
return nil
|
||||
}
|
||||
n, err := strconv.ParseUint(arpa[off:], base, 8)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
b := byte(n)
|
||||
ii := i / 2
|
||||
if i&1 == 1 {
|
||||
b |= ip[ii] << 4
|
||||
}
|
||||
ip[ii] = b
|
||||
arpa = arpa[:idx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network.
|
||||
func isPrivatePtrLookup(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
q := m.Question[0]
|
||||
if ip := ipFromARPA(q.Name); ip != nil {
|
||||
if addr, ok := netip.AddrFromSlice(ip); ok {
|
||||
return addr.IsPrivate() ||
|
||||
addr.IsLoopback() ||
|
||||
addr.IsLinkLocalUnicast() ||
|
||||
tsaddr.CGNATRange().Contains(addr)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isLanHostnameQuery reports whether DNS message is an A/AAAA query with LAN hostname.
|
||||
func isLanHostnameQuery(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
q := m.Question[0]
|
||||
switch q.Qtype {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
return !strings.Contains(name, ".") ||
|
||||
strings.HasSuffix(name, ".domain") ||
|
||||
strings.HasSuffix(name, ".lan")
|
||||
}
|
||||
|
||||
// isWanClient reports whether the input is a WAN address.
|
||||
func isWanClient(na net.Addr) bool {
|
||||
var ip netip.Addr
|
||||
if ap, err := netip.ParseAddrPort(na.String()); err == nil {
|
||||
ip = ap.Addr()
|
||||
}
|
||||
return !ip.IsLoopback() &&
|
||||
!ip.IsPrivate() &&
|
||||
!ip.IsLinkLocalUnicast() &&
|
||||
!ip.IsLinkLocalMulticast() &&
|
||||
!tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
|
||||
@@ -67,8 +67,11 @@ func Test_canonicalName(t *testing.T) {
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
for _, nc := range p.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
@@ -81,6 +84,7 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
mac string
|
||||
defaultUpstreamNum string
|
||||
lc *ctrld.ListenerConfig
|
||||
domain string
|
||||
@@ -88,11 +92,14 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
|
||||
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -111,9 +118,13 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
|
||||
assert.Equal(t, tc.matched, matched)
|
||||
assert.Equal(t, tc.upstreams, upstreams)
|
||||
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
|
||||
p.proxy(ctx, &proxyRequest{
|
||||
msg: newDnsMsgWithHostname("foo", dns.TypeA),
|
||||
ufr: ufr,
|
||||
})
|
||||
assert.Equal(t, tc.matched, ufr.matched)
|
||||
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
@@ -149,8 +160,32 @@ func TestCache(t *testing.T) {
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil)
|
||||
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil)
|
||||
req1 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.1"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
req2 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.0"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
got1 := prog.proxy(context.Background(), req1)
|
||||
got2 := prog.proxy(context.Background(), req2)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
@@ -234,3 +269,165 @@ func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipFromARPA(t *testing.T) {
|
||||
tests := []struct {
|
||||
IP string
|
||||
ARPA string
|
||||
}{
|
||||
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
|
||||
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
|
||||
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
|
||||
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
|
||||
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
|
||||
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"", "asd.in-addr.arpa."},
|
||||
{"", "asd.ip6.arpa."},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.IP, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
|
||||
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithClientIP(ip string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_stripClientSubnet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
wantSubnet bool
|
||||
}{
|
||||
{"no edns0", new(dns.Msg), false},
|
||||
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
|
||||
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
|
||||
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
|
||||
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
|
||||
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
|
||||
{"invalid IP", newDnsMsgWithClientIP(""), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
stripClientSubnet(tc.msg)
|
||||
hasSubnet := false
|
||||
if opt := tc.msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
|
||||
hasSubnet = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if tc.wantSubnet != hasSubnet {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(hostname, typ)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isLanHostnameQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isLanHostnameQuery bool
|
||||
}{
|
||||
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
|
||||
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
|
||||
t.Helper()
|
||||
m := new(dns.Msg)
|
||||
ptr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.SetQuestion(ptr, dns.TypePTR)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isPrivatePtrLookup bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
|
||||
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
|
||||
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
|
||||
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
|
||||
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
|
||||
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
|
||||
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
isWanClient bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
|
||||
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
|
||||
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
|
||||
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
|
||||
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
|
||||
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
|
||||
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isWanClient(tc.addr); tc.isWanClient != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -15,6 +16,36 @@ const (
|
||||
loopTestQtype = dns.TypeTXT
|
||||
)
|
||||
|
||||
// newLoopGuard returns new loopGuard.
|
||||
func newLoopGuard() *loopGuard {
|
||||
return &loopGuard{inflight: make(map[string]struct{})}
|
||||
}
|
||||
|
||||
// loopGuard guards against DNS loop, ensuring only one query
|
||||
// for a given domain is processed at a time.
|
||||
type loopGuard struct {
|
||||
mu sync.Mutex
|
||||
inflight map[string]struct{}
|
||||
}
|
||||
|
||||
// TryLock marks the domain as being processed.
|
||||
func (lg *loopGuard) TryLock(domain string) bool {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
if _, inflight := lg.inflight[domain]; !inflight {
|
||||
lg.inflight[domain] = struct{}{}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Unlock marks the domain as being done.
|
||||
func (lg *loopGuard) Unlock(domain string) {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
delete(lg.inflight, domain)
|
||||
}
|
||||
|
||||
// isLoop reports whether the given upstream config is detected as having DNS loop.
|
||||
func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool {
|
||||
p.loopMu.Lock()
|
||||
@@ -56,7 +87,15 @@ func (p *prog) checkDnsLoop() {
|
||||
mainLog.Load().Debug().Msg("start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
for n, uc := range p.cfg.Upstream {
|
||||
if p.um.isDown("upstream." + n) {
|
||||
continue
|
||||
}
|
||||
// Do not send test query to external upstream.
|
||||
if !canBeLocalUpstream(uc.Domain) {
|
||||
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
|
||||
continue
|
||||
}
|
||||
uid := uc.UID()
|
||||
p.loop[uid] = false
|
||||
upstream[uid] = uc
|
||||
@@ -79,13 +118,15 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
func (p *prog) checkDnsLoopTicker() {
|
||||
func (p *prog) checkDnsLoopTicker(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
p.checkDnsLoop()
|
||||
}
|
||||
|
||||
42
cmd/cli/loop_test.go
Normal file
42
cmd/cli/loop_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_loopGuard(t *testing.T) {
|
||||
lg := newLoopGuard()
|
||||
key := "foo"
|
||||
|
||||
var i atomic.Int64
|
||||
var started atomic.Int64
|
||||
n := 1000
|
||||
do := func() {
|
||||
locked := lg.TryLock(key)
|
||||
defer lg.Unlock(key)
|
||||
started.Add(1)
|
||||
for started.Load() < 2 {
|
||||
// Wait until at least 2 goroutines started, otherwise, on system with heavy load,
|
||||
// or having only 1 CPU, all goroutines can be scheduled to run consequently.
|
||||
}
|
||||
if locked {
|
||||
i.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
do()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if i.Load() == int64(n) {
|
||||
t.Fatalf("i must not be increased %d times", n)
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,8 @@ var (
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
nextdns string
|
||||
cdUpstreamProto string
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
@@ -39,8 +41,9 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
nextdnsFlagName = "nextdns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -93,6 +96,7 @@ func initConsoleLogging() {
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func initLogging() {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(true)
|
||||
}
|
||||
|
||||
@@ -131,7 +135,7 @@ func initLoggingWithBackup(doBackup bool) {
|
||||
}
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState() {
|
||||
func (p *prog) watchLinkState(ctx context.Context) {
|
||||
ch := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
@@ -13,14 +15,19 @@ func (p *prog) watchLinkState() {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
|
||||
return
|
||||
}
|
||||
for lu := range ch {
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case lu := <-ch:
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,4 +2,6 @@
|
||||
|
||||
package cli
|
||||
|
||||
func (p *prog) watchLinkState() {}
|
||||
import "context"
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {}
|
||||
|
||||
31
cmd/cli/nextdns.go
Normal file
31
cmd/cli/nextdns.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const nextdnsURL = "https://dns.nextdns.io"
|
||||
|
||||
func generateNextDNSConfig(uid string) {
|
||||
if uid == "" {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver")
|
||||
cfg = ctrld.Config{
|
||||
Listener: map[string]*ctrld.ListenerConfig{
|
||||
"0": {
|
||||
IP: "0.0.0.0",
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Upstream: map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
Type: ctrld.ResolverTypeDOH3,
|
||||
Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, uid),
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -9,10 +9,12 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/client6"
|
||||
@@ -23,7 +25,10 @@ import (
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
const (
|
||||
resolvConfPath = "/etc/resolv.conf"
|
||||
resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
)
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ip a add 127.0.0.2/24 dev lo
|
||||
@@ -64,6 +69,11 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
defer func() {
|
||||
if r.Mode() == "direct" {
|
||||
go watchResolveConf(osConfig)
|
||||
}
|
||||
}()
|
||||
|
||||
trySystemdResolve := false
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
@@ -299,3 +309,59 @@ func sliceIndex[S ~[]E, E comparable](s S, v E) int {
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// watchResolveConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func watchResolveConf(oc dns.OSConfig) {
|
||||
mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file")
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
return
|
||||
}
|
||||
|
||||
// We watch /etc instead of /etc/resolv.conf directly,
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list")
|
||||
return
|
||||
}
|
||||
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter.
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Name != resolvConfPath { // skip if not /etc/resolv.conf changes.
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
if err := r.SetDNS(oc); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
327
cmd/cli/prog.go
327
cmd/cli/prog.go
@@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
@@ -16,7 +17,9 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/viper"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
||||
@@ -30,6 +33,7 @@ const (
|
||||
ctrldControlUnixSock = "ctrld_control.sock"
|
||||
upstreamPrefix = "upstream."
|
||||
upstreamOS = upstreamPrefix + "os"
|
||||
upstreamPrivate = upstreamPrefix + "private"
|
||||
)
|
||||
|
||||
var logf = func(format string, args ...any) {
|
||||
@@ -45,19 +49,25 @@ var svcConfig = &service.Config{
|
||||
var useSystemdResolved = false
|
||||
|
||||
type prog struct {
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
reloadCh chan struct{} // For Windows.
|
||||
reloadDoneCh chan struct{}
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
|
||||
cfg *ctrld.Config
|
||||
appCallback *AppCallback
|
||||
cache dnscache.Cacher
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
um *upstreamMonitor
|
||||
router router.Router
|
||||
cfg *ctrld.Config
|
||||
localUpstreams []string
|
||||
ptrNameservers []string
|
||||
appCallback *AppCallback
|
||||
cache dnscache.Cacher
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
um *upstreamMonitor
|
||||
router router.Router
|
||||
ptrLoopGuard *loopGuard
|
||||
lanLoopGuard *loopGuard
|
||||
|
||||
loopMu sync.Mutex
|
||||
loop map[string]bool
|
||||
@@ -69,11 +79,106 @@ type prog struct {
|
||||
}
|
||||
|
||||
func (p *prog) Start(s service.Service) error {
|
||||
p.cfg = &cfg
|
||||
go p.run()
|
||||
go p.runWait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// runWait runs ctrld components, waiting for signal to reload.
|
||||
func (p *prog) runWait() {
|
||||
p.mu.Lock()
|
||||
p.cfg = &cfg
|
||||
p.mu.Unlock()
|
||||
reloadSigCh := make(chan os.Signal, 1)
|
||||
notifyReloadSigCh(reloadSigCh)
|
||||
|
||||
reload := false
|
||||
logger := mainLog.Load()
|
||||
for {
|
||||
reloadCh := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
p.run(reload, reloadCh)
|
||||
reload = true
|
||||
}()
|
||||
select {
|
||||
case sig := <-reloadSigCh:
|
||||
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
||||
case <-p.reloadCh:
|
||||
logger.Notice().Msg("reloading...")
|
||||
case <-p.stopCh:
|
||||
close(reloadCh)
|
||||
return
|
||||
}
|
||||
|
||||
waitOldRunDone := func() {
|
||||
close(reloadCh)
|
||||
<-done
|
||||
}
|
||||
newCfg := &ctrld.Config{}
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
logger.Err(err).Msg("could not read new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if err := v.Unmarshal(&newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not unmarshal new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if cdUID != "" {
|
||||
if err := processCDFlags(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not fetch ControlD config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
waitOldRunDone()
|
||||
|
||||
p.mu.Lock()
|
||||
curListener := p.cfg.Listener
|
||||
p.mu.Unlock()
|
||||
|
||||
for n, lc := range newCfg.Listener {
|
||||
curLc := curListener[n]
|
||||
if curLc == nil {
|
||||
continue
|
||||
}
|
||||
if lc.IP == "" {
|
||||
lc.IP = curLc.IP
|
||||
}
|
||||
if lc.Port == 0 {
|
||||
lc.Port = curLc.Port
|
||||
}
|
||||
}
|
||||
if err := validateConfig(newCfg); err != nil {
|
||||
logger.Err(err).Msg("invalid config")
|
||||
continue
|
||||
}
|
||||
|
||||
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
||||
// upstream config because its initialization function have not been called yet.
|
||||
mainLog.Load().Debug().Msg("setup upstream with new config")
|
||||
p.setupUpstream(newCfg)
|
||||
|
||||
p.mu.Lock()
|
||||
*p.cfg = *newCfg
|
||||
p.mu.Unlock()
|
||||
|
||||
logger.Notice().Msg("reloading config successfully")
|
||||
select {
|
||||
case p.reloadDoneCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if !service.Interactive() {
|
||||
p.setDNS()
|
||||
@@ -87,14 +192,54 @@ func (p *prog) preRun() {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) run() {
|
||||
func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
||||
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
||||
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
||||
for n := range cfg.Upstream {
|
||||
uc := cfg.Upstream[n]
|
||||
uc.Init()
|
||||
if uc.BootstrapIP == "" {
|
||||
uc.SetupBootstrapIP()
|
||||
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
} else {
|
||||
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
go uc.Ping()
|
||||
|
||||
if canBeLocalUpstream(uc.Domain) {
|
||||
localUpstreams = append(localUpstreams, upstreamPrefix+n)
|
||||
}
|
||||
if uc.IsDiscoverable() {
|
||||
ptrNameservers = append(ptrNameservers, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
p.localUpstreams = localUpstreams
|
||||
p.ptrNameservers = ptrNameservers
|
||||
}
|
||||
|
||||
// run runs the ctrld main components.
|
||||
//
|
||||
// The reload boolean indicates that the function is run when ctrld first start
|
||||
// or when ctrld receive reloading signal. Platform specifics setup is only done
|
||||
// on started, mean reload is "false".
|
||||
//
|
||||
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
|
||||
// so all listeners could be terminated and re-spawned again.
|
||||
func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
// Wait the caller to signal that we can do our logic.
|
||||
<-p.waitCh
|
||||
p.preRun()
|
||||
if !reload {
|
||||
p.preRun()
|
||||
}
|
||||
numListeners := len(p.cfg.Listener)
|
||||
p.started = make(chan struct{}, numListeners)
|
||||
if !reload {
|
||||
p.started = make(chan struct{}, numListeners)
|
||||
}
|
||||
p.onStartedDone = make(chan struct{})
|
||||
p.loop = make(map[string]bool)
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
if p.cfg.Service.CacheEnable {
|
||||
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
||||
if err != nil {
|
||||
@@ -103,15 +248,7 @@ func (p *prog) run() {
|
||||
p.cache = cacher
|
||||
}
|
||||
}
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
|
||||
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
|
||||
n := *mcr
|
||||
if n == 0 {
|
||||
p.sema = &noopSemaphore{}
|
||||
} else {
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(p.cfg.Listener))
|
||||
|
||||
@@ -127,74 +264,102 @@ func (p *prog) run() {
|
||||
}
|
||||
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
for n := range p.cfg.Upstream {
|
||||
uc := p.cfg.Upstream[n]
|
||||
uc.Init()
|
||||
if uc.BootstrapIP == "" {
|
||||
uc.SetupBootstrapIP()
|
||||
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
} else {
|
||||
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
||||
|
||||
if !reload {
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
|
||||
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
|
||||
n := *mcr
|
||||
if n == 0 {
|
||||
p.sema = &noopSemaphore{}
|
||||
} else {
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
|
||||
}
|
||||
}
|
||||
p.setupUpstream(p.cfg)
|
||||
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers)
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
go uc.Ping()
|
||||
}
|
||||
|
||||
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID)
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
// context for managing spawn goroutines.
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
// Newer versions of android and iOS denies permission which breaks connectivity.
|
||||
if !isMobile() {
|
||||
if !isMobile() && !reload {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
p.ciTable.Init()
|
||||
p.ciTable.RefreshLoop(p.stopCh)
|
||||
p.ciTable.RefreshLoop(ctx)
|
||||
}()
|
||||
go p.watchLinkState()
|
||||
go p.watchLinkState(ctx)
|
||||
}
|
||||
|
||||
for listenerNum := range p.cfg.Listener {
|
||||
p.cfg.Listener[listenerNum].Init()
|
||||
go func(listenerNum string) {
|
||||
defer wg.Done()
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
upstreamConfig := p.cfg.Upstream[listenerNum]
|
||||
if upstreamConfig == nil {
|
||||
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
||||
if !reload {
|
||||
go func(listenerNum string) {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
upstreamConfig := p.cfg.Upstream[listenerNum]
|
||||
if upstreamConfig == nil {
|
||||
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
||||
}
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
if err := p.serveDNS(listenerNum); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
}(listenerNum)
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
cancelFunc()
|
||||
wg.Done()
|
||||
}()
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
case <-ctx.Done():
|
||||
case <-reloadCh:
|
||||
}
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
if err := p.serveDNS(listenerNum); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
}(listenerNum)
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < numListeners; i++ {
|
||||
<-p.started
|
||||
if !reload {
|
||||
for i := 0; i < numListeners; i++ {
|
||||
<-p.started
|
||||
}
|
||||
for _, f := range p.onStarted {
|
||||
f()
|
||||
}
|
||||
}
|
||||
for _, f := range p.onStarted {
|
||||
f()
|
||||
}
|
||||
// Check for possible DNS loop.
|
||||
p.checkDnsLoop()
|
||||
|
||||
close(p.onStartedDone)
|
||||
|
||||
// Start check DNS loop ticker.
|
||||
go p.checkDnsLoopTicker()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// Check for possible DNS loop.
|
||||
p.checkDnsLoop()
|
||||
// Start check DNS loop ticker.
|
||||
p.checkDnsLoopTicker(ctx)
|
||||
}()
|
||||
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
initLoggingWithBackup(false)
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
if p.cs != nil {
|
||||
p.registerControlServerHandler()
|
||||
if err := p.cs.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
||||
if !reload {
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
initLoggingWithBackup(false)
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
if p.cs != nil {
|
||||
p.registerControlServerHandler()
|
||||
if err := p.cs.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
@@ -276,7 +441,7 @@ func (p *prog) setDNS() {
|
||||
|
||||
nameservers := []string{ns}
|
||||
if needRFC1918Listeners(lc) {
|
||||
nameservers = append(nameservers, rfc1918Addresses()...)
|
||||
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
|
||||
}
|
||||
if err := setDNS(netIface, nameservers); err != nil {
|
||||
logger.Error().Err(err).Msgf("could not set DNS for interface")
|
||||
@@ -462,3 +627,11 @@ func defaultRouteIP() string {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
|
||||
return ip
|
||||
}
|
||||
|
||||
// canBeLocalUpstream reports whether the IP address can be used as a local upstream.
|
||||
func canBeLocalUpstream(addr string) bool {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ func setDependencies(svc *service.Config) {
|
||||
"Wants=NetworkManager-wait-online.service",
|
||||
"After=NetworkManager-wait-online.service",
|
||||
"Wants=systemd-networkd-wait-online.service",
|
||||
"After=systemd-networkd-wait-online.service",
|
||||
"Wants=nss-lookup.target",
|
||||
"After=nss-lookup.target",
|
||||
}
|
||||
|
||||
17
cmd/cli/reload_others.go
Normal file
17
cmd/cli/reload_others.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}
|
||||
18
cmd/cli/reload_windows.go
Normal file
18
cmd/cli/reload_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
select {
|
||||
case p.reloadCh <- struct{}{}:
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
return errors.New("timeout while sending reload signal")
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
@@ -15,8 +14,8 @@ import (
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 100
|
||||
// checkUpstreamMaxBackoff is the max backoff time when checking upstream status.
|
||||
checkUpstreamMaxBackoff = 2 * time.Minute
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
@@ -76,7 +75,6 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf
|
||||
um.checking[upstream] = true
|
||||
um.mu.Unlock()
|
||||
|
||||
bo := backoff.NewBackoff("checkUpstream", logf, checkUpstreamMaxBackoff)
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
|
||||
@@ -84,15 +82,20 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
ctx := context.Background()
|
||||
|
||||
for {
|
||||
check := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
uc.ReBootstrap()
|
||||
_, err := resolver.Resolve(ctx, msg)
|
||||
if err == nil {
|
||||
return err
|
||||
}
|
||||
for {
|
||||
if err := check(); err == nil {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
|
||||
um.reset(upstream)
|
||||
return
|
||||
}
|
||||
bo.BackOff(ctx, err)
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
}
|
||||
}
|
||||
|
||||
58
config.go
58
config.go
@@ -11,6 +11,7 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
@@ -82,6 +84,16 @@ func InitConfig(v *viper.Viper, name string) {
|
||||
"0": {
|
||||
IP: "",
|
||||
Port: 0,
|
||||
Policy: &ListenerPolicyConfig{
|
||||
Name: "Main Policy",
|
||||
Networks: []Rule{
|
||||
{"network.0": []string{"upstream.0"}},
|
||||
},
|
||||
Rules: []Rule{
|
||||
{"example.com": []string{"upstream.0"}},
|
||||
{"*.ads.com": []string{"upstream.1"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
v.SetDefault("network", map[string]*NetworkConfig{
|
||||
@@ -181,6 +193,7 @@ type ServiceConfig struct {
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
@@ -204,6 +217,9 @@ type UpstreamConfig struct {
|
||||
// The caller should not access this field directly.
|
||||
// Use UpstreamSendClientInfo instead.
|
||||
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
|
||||
// The caller should not access this field directly.
|
||||
// Use IsDiscoverable instead.
|
||||
Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"`
|
||||
|
||||
g singleflight.Group
|
||||
rebootstrap atomic.Bool
|
||||
@@ -224,10 +240,11 @@ type UpstreamConfig struct {
|
||||
|
||||
// ListenerConfig specifies the networks configuration that ctrld will run on.
|
||||
type ListenerConfig struct {
|
||||
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
|
||||
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
|
||||
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
|
||||
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
|
||||
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
|
||||
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
|
||||
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
|
||||
AllowWanClients bool `mapstructure:"allow_wan_clients" toml:"allow_wan_clients,omitempty"`
|
||||
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
|
||||
}
|
||||
|
||||
// IsDirectDnsListener reports whether ctrld can be a direct listener on port 53.
|
||||
@@ -253,6 +270,7 @@ type ListenerPolicyConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
|
||||
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
@@ -322,13 +340,28 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
if uc.isControlD() {
|
||||
if uc.isControlD() || uc.isNextDNS() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDiscoverable reports whether the upstream can be used for PTR discovery.
|
||||
// The caller must ensure uc.Init() was called before calling this.
|
||||
func (uc *UpstreamConfig) IsDiscoverable() bool {
|
||||
if uc.Discoverable != nil {
|
||||
return *uc.Discoverable
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate:
|
||||
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BootstrapIPs returns the bootstrap IPs list of upstreams.
|
||||
func (uc *UpstreamConfig) BootstrapIPs() []string {
|
||||
return uc.bootstrapIPs
|
||||
@@ -394,8 +427,9 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
return
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
uc.rebootstrap.Store(true)
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
@@ -519,6 +553,16 @@ func (uc *UpstreamConfig) isControlD() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) isNextDNS() bool {
|
||||
domain := uc.Domain
|
||||
if domain == "" {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
domain = u.Hostname()
|
||||
}
|
||||
}
|
||||
return domain == "dns.nextdns.io"
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
|
||||
@@ -279,6 +279,61 @@ func TestUpstreamConfig_UpstreamSendClientInfo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
discoverable bool
|
||||
}{
|
||||
{
|
||||
"loopback",
|
||||
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"rfc1918",
|
||||
&UpstreamConfig{Endpoint: "192.168.1.1", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"CGNAT",
|
||||
&UpstreamConfig{Endpoint: "100.66.67.68", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Public IP",
|
||||
&UpstreamConfig{Endpoint: "8.8.8.8", Type: ResolverTypeLegacy},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"override discoverable",
|
||||
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(false)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"override non-public",
|
||||
&UpstreamConfig{Endpoint: "1.1.1.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(true)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-legacy upstream",
|
||||
&UpstreamConfig{Endpoint: "https://192.168.1.1/custom-doh", Type: ResolverTypeDOH},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ptrBool(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
@@ -10,13 +10,10 @@ import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
@@ -29,9 +26,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
if hasIPv6() {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
@@ -127,11 +122,6 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
go func(addr string) {
|
||||
defer wg.Done()
|
||||
@@ -140,6 +130,11 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
|
||||
ch <- ¶llelDialerResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
ch <- ¶llelDialerResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
||||
select {
|
||||
case ch <- ¶llelDialerResult{conn: conn, err: err}:
|
||||
@@ -147,6 +142,9 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
|
||||
if conn != nil {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
|
||||
}
|
||||
if udpConn != nil {
|
||||
udpConn.Close()
|
||||
}
|
||||
}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
@@ -54,7 +54,12 @@ func TestLoadDefaultConfig(t *testing.T) {
|
||||
cfg := defaultConfig(t)
|
||||
validate := validator.New()
|
||||
require.NoError(t, ctrld.ValidateConfig(validate, cfg))
|
||||
assert.Len(t, cfg.Listener, 1)
|
||||
if assert.Len(t, cfg.Listener, 1) {
|
||||
l0 := cfg.Listener["0"]
|
||||
require.NotNil(t, l0.Policy)
|
||||
assert.Len(t, l0.Policy.Networks, 1)
|
||||
assert.Len(t, l0.Policy.Rules, 2)
|
||||
}
|
||||
assert.Len(t, cfg.Upstream, 2)
|
||||
}
|
||||
|
||||
@@ -96,6 +101,7 @@ func TestConfigValidation(t *testing.T) {
|
||||
{"lease file format required if lease file exist", configWithExistedLeaseFile(t), true},
|
||||
{"invalid lease file format", configWithInvalidLeaseFileFormat(t), true},
|
||||
{"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true},
|
||||
{"invalid client id pref", configWithInvalidClientIDPref(t), true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -233,3 +239,9 @@ func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
|
||||
cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.ClientIDPref = "foo"
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -215,6 +215,17 @@ DHCP leases file format.
|
||||
- Valid values: `dnsmasq`, `isc-dhcp`
|
||||
- Default: ""
|
||||
|
||||
### client_id_preference
|
||||
Decide how the client ID is generated
|
||||
|
||||
If `host` -> client id will only use the hostname i.e.`hash(hostname)`.
|
||||
If `mac` -> client id will only use the MAC address `hash(mac)`.
|
||||
Else -> client ID will use both Mac and Hostname i.e. `hash(mac + host)
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Valid values: `mac`, `host`
|
||||
- Default: ""
|
||||
|
||||
## Upstream
|
||||
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
||||
|
||||
@@ -319,6 +330,24 @@ If `ip_stack` is empty, or undefined:
|
||||
- Default value is `both` for non-Control D resolvers.
|
||||
- Default value is `split` for Control D resolvers.
|
||||
|
||||
### send_client_info
|
||||
Specifying whether to include client info when sending query to upstream.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default:
|
||||
- `true` for ControlD upstreams.
|
||||
- `false` for other upstreams.
|
||||
|
||||
### discoverable
|
||||
Specifying whether the upstream can be used for PTR discovery.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default:
|
||||
- `true` for loopback/RFC1918/CGNAT IP address.
|
||||
- `false` for public IP address.
|
||||
|
||||
## Network
|
||||
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.
|
||||
|
||||
@@ -376,7 +405,14 @@ Port number that the listener will listen on for incoming requests. If `port` is
|
||||
- Default: 0 or 53 or 5354 (depending on platform)
|
||||
|
||||
### restricted
|
||||
If set to `true` makes the listener `REFUSE` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
|
||||
If set to `true`, makes the listener `REFUSED` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
|
||||
|
||||
- Type: bool
|
||||
- Required: no
|
||||
- Default: false
|
||||
|
||||
### allow_wan_clients
|
||||
The listener `REFUSED` DNS queries from WAN clients by default. If set to `true`, makes the listener replies to them.
|
||||
|
||||
- Type: bool
|
||||
- Required: no
|
||||
@@ -386,7 +422,15 @@ If set to `true` makes the listener `REFUSE` DNS queries from all source IP addr
|
||||
Allows `ctrld` to set policy rules to determine which upstreams the requests will be forwarded to.
|
||||
If no `policy` is defined or the requests do not match any policy rules, it will be forwarded to corresponding upstream of the listener. For example, the request to `listener.0` will be forwarded to `upstream.0`.
|
||||
|
||||
The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either the `network` or a domain. Value is the list of the upstreams. For example:
|
||||
The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either:
|
||||
|
||||
- Network.
|
||||
- Domain.
|
||||
- Mac Address.
|
||||
|
||||
Value is the list of the upstreams.
|
||||
|
||||
For example:
|
||||
|
||||
```toml
|
||||
[listener.0.policy]
|
||||
@@ -400,12 +444,18 @@ rules = [
|
||||
{"*.local" = ["upstream.1"]},
|
||||
{"test.com" = ["upstream.2", "upstream.1"]},
|
||||
]
|
||||
|
||||
macs = [
|
||||
{"14:54:4a:8e:08:2d" = ["upstream.3"]},
|
||||
]
|
||||
```
|
||||
|
||||
Above policy will:
|
||||
- Forward requests on `listener.0` from `network.0` to `upstream.1`.
|
||||
|
||||
- Forward requests on `listener.0` for `.local` suffixed domains to `upstream.1`.
|
||||
- Forward requests on `listener.0` for `test.com` to `upstream.2`. If timeout is reached, retry on `upstream.1`.
|
||||
- Forward requests on `listener.0` from client with Mac `14:54:4a:8e:08:2d` to `upstream.3`.
|
||||
- Forward requests on `listener.0` from `network.0` to `upstream.1`.
|
||||
- All other requests on `listener.0` that do not match above conditions will be forwarded to `upstream.0`.
|
||||
|
||||
An empty upstream would not route the request to any defined upstreams, and use the OS default resolver.
|
||||
@@ -419,6 +469,18 @@ rules = [
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Note that the order of matching preference:
|
||||
|
||||
```
|
||||
rules => macs => networks
|
||||
```
|
||||
|
||||
And within each policy, the rules are processed from top to bottom.
|
||||
|
||||
---
|
||||
|
||||
#### name
|
||||
`name` is the name for the policy.
|
||||
|
||||
@@ -440,6 +502,13 @@ rules = [
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
### macs:
|
||||
`macs` is the list of mac rules within the policy. Mac address value is case-insensitive.
|
||||
|
||||
- Type: array of macs
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
### failover_rcodes
|
||||
For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`.
|
||||
|
||||
|
||||
75
doh.go
75
doh.go
@@ -18,11 +18,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dohMacHeader = "x-cd-mac"
|
||||
dohIPHeader = "x-cd-ip"
|
||||
dohHostHeader = "x-cd-host"
|
||||
dohOsHeader = "x-cd-os"
|
||||
headerApplicationDNS = "application/dns-message"
|
||||
dohMacHeader = "x-cd-mac"
|
||||
dohIPHeader = "x-cd-ip"
|
||||
dohHostHeader = "x-cd-host"
|
||||
dohOsHeader = "x-cd-os"
|
||||
dohClientIDPrefHeader = "x-cd-cpref"
|
||||
headerApplicationDNS = "application/dns-message"
|
||||
)
|
||||
|
||||
// EncodeOsNameMap provides mapping from OS name to a shorter string, used for encoding x-cd-os value.
|
||||
@@ -76,7 +77,6 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver {
|
||||
endpoint: uc.u,
|
||||
isDoH3: uc.Type == ResolverTypeDOH3,
|
||||
http3RoundTripper: uc.http3RoundTripper,
|
||||
sendClientInfo: uc.UpstreamSendClientInfo(),
|
||||
uc: uc,
|
||||
}
|
||||
return r
|
||||
@@ -87,9 +87,9 @@ type dohResolver struct {
|
||||
endpoint *url.URL
|
||||
isDoH3 bool
|
||||
http3RoundTripper http.RoundTripper
|
||||
sendClientInfo bool
|
||||
}
|
||||
|
||||
// Resolve performs DNS query with given DNS message using DOH protocol.
|
||||
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
@@ -106,7 +106,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create request: %w", err)
|
||||
}
|
||||
addHeader(ctx, req, r.sendClientInfo)
|
||||
addHeader(ctx, req, r.uc)
|
||||
dnsTyp := uint16(0)
|
||||
if len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
@@ -146,26 +146,19 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
|
||||
func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
|
||||
req.Header.Set("Content-Type", headerApplicationDNS)
|
||||
req.Header.Set("Accept", headerApplicationDNS)
|
||||
req.Header.Set(dohOsHeader, dohOsHeaderValue())
|
||||
|
||||
printed := false
|
||||
if sendClientInfo {
|
||||
if uc.UpstreamSendClientInfo() {
|
||||
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
|
||||
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
|
||||
if ci.Mac != "" {
|
||||
req.Header.Set(dohMacHeader, ci.Mac)
|
||||
}
|
||||
if ci.IP != "" {
|
||||
req.Header.Set(dohIPHeader, ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set(dohHostHeader, ci.Hostname)
|
||||
}
|
||||
if ci.Self {
|
||||
req.Header.Set(dohOsHeader, dohOsHeaderValue())
|
||||
switch {
|
||||
case uc.isControlD():
|
||||
addControlDHeaders(req, ci)
|
||||
case uc.isNextDNS():
|
||||
addNextDNSHeaders(req, ci)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -173,3 +166,41 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
|
||||
Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
|
||||
}
|
||||
}
|
||||
|
||||
// addControlDHeaders set DoH/Doh3 HTTP request headers for ControlD upstream.
|
||||
func addControlDHeaders(req *http.Request, ci *ClientInfo) {
|
||||
req.Header.Set(dohOsHeader, dohOsHeaderValue())
|
||||
if ci.Mac != "" {
|
||||
req.Header.Set(dohMacHeader, ci.Mac)
|
||||
}
|
||||
if ci.IP != "" {
|
||||
req.Header.Set(dohIPHeader, ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set(dohHostHeader, ci.Hostname)
|
||||
}
|
||||
if ci.Self {
|
||||
req.Header.Set(dohOsHeader, dohOsHeaderValue())
|
||||
}
|
||||
switch ci.ClientIDPref {
|
||||
case "mac":
|
||||
req.Header.Set(dohClientIDPrefHeader, "1")
|
||||
case "host":
|
||||
req.Header.Set(dohClientIDPrefHeader, "2")
|
||||
}
|
||||
}
|
||||
|
||||
// addNextDNSHeaders set DoH/Doh3 HTTP request headers for nextdns upstream.
|
||||
// https://github.com/nextdns/nextdns/blob/v1.41.0/resolver/doh.go#L100
|
||||
func addNextDNSHeaders(req *http.Request, ci *ClientInfo) {
|
||||
if ci.Mac != "" {
|
||||
// https: //github.com/nextdns/nextdns/blob/v1.41.0/run.go#L543
|
||||
req.Header.Set("X-Device-Model", "mac:"+ci.Mac[:8])
|
||||
}
|
||||
if ci.IP != "" {
|
||||
req.Header.Set("X-Device-Ip", ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set("X-Device-Name", ci.Hostname)
|
||||
}
|
||||
}
|
||||
|
||||
9
go.mod
9
go.mod
@@ -25,9 +25,9 @@ require (
|
||||
github.com/spf13/viper v1.16.0
|
||||
github.com/stretchr/testify v1.8.3
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/net v0.10.0
|
||||
golang.org/x/net v0.17.0
|
||||
golang.org/x/sync v0.2.0
|
||||
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a
|
||||
golang.org/x/sys v0.13.0
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
tailscale.com v1.44.0
|
||||
)
|
||||
@@ -70,11 +70,10 @@ require (
|
||||
github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
|
||||
golang.org/x/crypto v0.9.0 // indirect
|
||||
golang.org/x/crypto v0.14.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect
|
||||
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda // indirect
|
||||
golang.org/x/mod v0.10.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/tools v0.9.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
20
go.sum
20
go.sum
@@ -55,8 +55,6 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf h1:40DHYsri+d1bnroFDU2FQAeq68f3kAlOzlQ93kCf26Q=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -302,8 +300,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -331,8 +329,6 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI
|
||||
golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
|
||||
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda h1:O+EUvnBNPwI4eLthn8W5K+cS8zQZfgTABPLNm6Bna34=
|
||||
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda/go.mod h1:aAjjkJNdrh3PMckS4B10TGS2nag27cbKR1y2BpUxsiY=
|
||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
@@ -378,8 +374,8 @@ golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -452,8 +448,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a h1:qMsju+PNttu/NMbq8bQ9waDdxgJMu9QNoUDuhnBaYt0=
|
||||
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -463,8 +459,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -68,25 +71,27 @@ type Table struct {
|
||||
refreshers []refresher
|
||||
initOnce sync.Once
|
||||
|
||||
dhcp *dhcp
|
||||
merlin *merlinDiscover
|
||||
arp *arpDiscover
|
||||
ptr *ptrDiscover
|
||||
mdns *mdns
|
||||
hf *hostsFile
|
||||
vni *virtualNetworkIface
|
||||
cfg *ctrld.Config
|
||||
quitCh chan struct{}
|
||||
selfIP string
|
||||
cdUID string
|
||||
dhcp *dhcp
|
||||
merlin *merlinDiscover
|
||||
arp *arpDiscover
|
||||
ptr *ptrDiscover
|
||||
mdns *mdns
|
||||
hf *hostsFile
|
||||
vni *virtualNetworkIface
|
||||
svcCfg ctrld.ServiceConfig
|
||||
quitCh chan struct{}
|
||||
selfIP string
|
||||
cdUID string
|
||||
ptrNameservers []string
|
||||
}
|
||||
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table {
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
|
||||
return &Table{
|
||||
cfg: cfg,
|
||||
quitCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
svcCfg: cfg.Service,
|
||||
quitCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
ptrNameservers: ns,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,7 +102,7 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) {
|
||||
clientInfoFiles[name] = format
|
||||
}
|
||||
|
||||
func (t *Table) RefreshLoop(stopCh chan struct{}) {
|
||||
func (t *Table) RefreshLoop(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute * 5)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
@@ -106,7 +111,7 @@ func (t *Table) RefreshLoop(stopCh chan struct{}) {
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
case <-stopCh:
|
||||
case <-ctx.Done():
|
||||
close(t.quitCh)
|
||||
return
|
||||
}
|
||||
@@ -182,6 +187,26 @@ func (t *Table) init() {
|
||||
// PTR lookup.
|
||||
if t.discoverPTR() {
|
||||
t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()}
|
||||
if len(t.ptrNameservers) > 0 {
|
||||
nss := make([]string, 0, len(t.ptrNameservers))
|
||||
for _, ns := range t.ptrNameservers {
|
||||
host, port := ns, "53"
|
||||
if h, p, err := net.SplitHostPort(ns); err == nil {
|
||||
host, port = h, p
|
||||
}
|
||||
// Only use valid ip:port pair.
|
||||
if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil {
|
||||
nss = append(nss, net.JoinHostPort(host, port))
|
||||
} else {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns)
|
||||
}
|
||||
}
|
||||
if len(nss) > 0 {
|
||||
t.ptr.resolver = ctrld.NewResolverWithNameserver(nss)
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss)
|
||||
}
|
||||
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery")
|
||||
if err := t.ptr.refresh(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover")
|
||||
@@ -240,6 +265,21 @@ func (t *Table) LookupHostname(ip, mac string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// LookupRFC1918IPv4 returns the RFC1918 IPv4 address for the given MAC address, if any.
|
||||
func (t *Table) LookupRFC1918IPv4(mac string) string {
|
||||
t.initOnce.Do(t.init)
|
||||
for _, r := range t.ipResolvers {
|
||||
ip, err := netip.ParseAddr(r.LookupIP(mac))
|
||||
if err != nil || ip.Is6() {
|
||||
continue
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type macEntry struct {
|
||||
mac string
|
||||
src string
|
||||
@@ -338,39 +378,60 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) {
|
||||
t.vni.ip2name.Store(ci.IP, ci.Hostname)
|
||||
}
|
||||
|
||||
// ipFinder is the interface for retrieving IP address from hostname.
|
||||
type ipFinder interface {
|
||||
lookupIPByHostname(name string, v6 bool) string
|
||||
}
|
||||
|
||||
// LookupIPByHostname returns the ip address of given hostname.
|
||||
// If v6 is true, return IPv6 instead of default IPv4.
|
||||
func (t *Table) LookupIPByHostname(hostname string, v6 bool) *netip.Addr {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
for _, finder := range []ipFinder{t.hf, t.ptr, t.mdns, t.dhcp} {
|
||||
if addr := finder.lookupIPByHostname(hostname, v6); addr != "" {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil {
|
||||
return &ip
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Table) discoverDHCP() bool {
|
||||
if t.cfg.Service.DiscoverDHCP == nil {
|
||||
if t.svcCfg.DiscoverDHCP == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverDHCP
|
||||
return *t.svcCfg.DiscoverDHCP
|
||||
}
|
||||
|
||||
func (t *Table) discoverARP() bool {
|
||||
if t.cfg.Service.DiscoverARP == nil {
|
||||
if t.svcCfg.DiscoverARP == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverARP
|
||||
return *t.svcCfg.DiscoverARP
|
||||
}
|
||||
|
||||
func (t *Table) discoverMDNS() bool {
|
||||
if t.cfg.Service.DiscoverMDNS == nil {
|
||||
if t.svcCfg.DiscoverMDNS == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverMDNS
|
||||
return *t.svcCfg.DiscoverMDNS
|
||||
}
|
||||
|
||||
func (t *Table) discoverPTR() bool {
|
||||
if t.cfg.Service.DiscoverPtr == nil {
|
||||
if t.svcCfg.DiscoverPtr == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverPtr
|
||||
return *t.svcCfg.DiscoverPtr
|
||||
}
|
||||
|
||||
func (t *Table) discoverHosts() bool {
|
||||
if t.cfg.Service.DiscoverHosts == nil {
|
||||
if t.svcCfg.DiscoverHosts == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverHosts
|
||||
return *t.svcCfg.DiscoverHosts
|
||||
}
|
||||
|
||||
// normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file.
|
||||
|
||||
@@ -25,3 +25,22 @@ func Test_normalizeIP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_LookupRFC1918IPv4(t *testing.T) {
|
||||
table := &Table{
|
||||
dhcp: &dhcp{},
|
||||
arp: &arpDiscover{},
|
||||
}
|
||||
|
||||
table.ipResolvers = append(table.ipResolvers, table.dhcp)
|
||||
table.ipResolvers = append(table.ipResolvers, table.arp)
|
||||
|
||||
macAddress := "cc:19:f9:8a:49:e6"
|
||||
rfc1918IPv4 := "10.0.10.245"
|
||||
table.dhcp.ip.Store(macAddress, "127.0.0.1")
|
||||
table.arp.ip.Store(macAddress, rfc1918IPv4)
|
||||
|
||||
if got := table.LookupRFC1918IPv4(macAddress); got != rfc1918IPv4 {
|
||||
t.Fatalf("unexpected result, want: %s, got: %s", rfc1918IPv4, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -134,6 +135,39 @@ func (d *dhcp) List() []string {
|
||||
return ips
|
||||
}
|
||||
|
||||
func (d *dhcp) lookupIPByHostname(name string, v6 bool) string {
|
||||
if d == nil {
|
||||
return ""
|
||||
}
|
||||
var (
|
||||
rfc1918Addrs []netip.Addr
|
||||
others []netip.Addr
|
||||
)
|
||||
d.ip2name.Range(func(key, value any) bool {
|
||||
if value != name {
|
||||
return true
|
||||
}
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
if addr.IsPrivate() {
|
||||
rfc1918Addrs = append(rfc1918Addrs, addr)
|
||||
} else {
|
||||
others = append(others, addr)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
result := [][]netip.Addr{rfc1918Addrs, others}
|
||||
for _, addrs := range result {
|
||||
if len(addrs) > 0 {
|
||||
sort.Slice(addrs, func(i, j int) bool {
|
||||
return addrs[i].Less(addrs[j])
|
||||
})
|
||||
return addrs[0].String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// AddLeaseFile adds given lease file for reading/watching clients info.
|
||||
func (d *dhcp) addLeaseFile(name string, format ctrld.LeaseFileFormat) error {
|
||||
if d.watcher == nil {
|
||||
|
||||
@@ -86,3 +86,15 @@ lease 192.168.1.2 {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dhcp_lookupIPByHostname(t *testing.T) {
|
||||
d := &dhcp{}
|
||||
want := "192.168.1.123"
|
||||
d.ip2name.Store(want, "foo")
|
||||
d.ip2name.Store("127.0.0.1", "foo")
|
||||
d.ip2name.Store("169.254.123.123", "foo")
|
||||
|
||||
if got := d.lookupIPByHostname("foo", false); got != want {
|
||||
t.Fatalf("unexpected result, want: %s, got: %s", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
@@ -109,6 +110,24 @@ func (hf *hostsFile) String() string {
|
||||
return "hosts"
|
||||
}
|
||||
|
||||
func (hf *hostsFile) lookupIPByHostname(name string, v6 bool) string {
|
||||
if hf == nil {
|
||||
return ""
|
||||
}
|
||||
hf.mu.Lock()
|
||||
defer hf.mu.Unlock()
|
||||
for addr, names := range hf.m {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLoopback() {
|
||||
for _, n := range names {
|
||||
if n == name && ip.Is6() == v6 {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isLocalhostName reports whether the given hostname represents localhost.
|
||||
func isLocalhostName(hostname string) bool {
|
||||
switch hostname {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -59,6 +60,27 @@ func (m *mdns) List() []string {
|
||||
return ips
|
||||
}
|
||||
|
||||
func (m *mdns) lookupIPByHostname(name string, v6 bool) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
var ip string
|
||||
m.name.Range(func(key, value any) bool {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
//lint:ignore S1008 This is used for readable.
|
||||
if addr.IsLoopback() { // Continue searching if this is loopback address.
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return ip
|
||||
}
|
||||
|
||||
func (m *mdns) init(quitCh chan struct{}) error {
|
||||
ifaces, err := multicastInterfaces()
|
||||
if err != nil {
|
||||
@@ -123,6 +145,10 @@ func (m *mdns) readLoop(conn *net.UDPConn) {
|
||||
if err, ok := err.(*net.OpError); ok && (err.Timeout() || err.Temporary()) {
|
||||
continue
|
||||
}
|
||||
// Do not complain about use of closed network connection.
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package clientinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -72,15 +73,16 @@ func (p *ptrDiscover) lookupHostname(ip string) string {
|
||||
msg := new(dns.Msg)
|
||||
addr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
|
||||
ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
|
||||
return ""
|
||||
}
|
||||
msg.SetQuestion(addr, dns.TypePTR)
|
||||
ans, err := p.resolver.Resolve(ctx, msg)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
|
||||
p.serverDown.Store(true)
|
||||
go p.checkServer()
|
||||
if p.serverDown.CompareAndSwap(false, true) {
|
||||
ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
|
||||
go p.checkServer()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
for _, rr := range ans.Answer {
|
||||
@@ -93,6 +95,27 @@ func (p *ptrDiscover) lookupHostname(ip string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
var ip string
|
||||
p.hostname.Range(func(key, value any) bool {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
//lint:ignore S1008 This is used for readable.
|
||||
if addr.IsLoopback() { // Continue searching if this is loopback address.
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return ip
|
||||
}
|
||||
|
||||
// checkServer monitors if the resolver can reach its nameserver. When the nameserver
|
||||
// is reachable, set p.serverDown to false, so p.lookupHostname can continue working.
|
||||
func (p *ptrDiscover) checkServer() {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -119,7 +120,7 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig
|
||||
return d.DialContext(ctx, network, addrs)
|
||||
}
|
||||
|
||||
if router.Name() == ddwrt.Name {
|
||||
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
|
||||
}
|
||||
client := http.Client{
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package dnsmasq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -19,6 +22,9 @@ server={{ .IP }}#{{ .Port }}
|
||||
add-mac
|
||||
add-subnet=32,128
|
||||
{{- end}}
|
||||
{{- if .CacheDisabled}}
|
||||
cache-size=0
|
||||
{{- end}}
|
||||
`
|
||||
|
||||
const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf"
|
||||
@@ -47,6 +53,8 @@ if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then
|
||||
{{- end}}
|
||||
pc_delete "dnssec" "$config_file" # disable DNSSEC
|
||||
pc_delete "trust-anchor=" "$config_file" # disable DNSSEC
|
||||
pc_delete "cache-size=" "$config_file"
|
||||
pc_append "cache-size=0" "$config_file" # disable cache
|
||||
|
||||
# For John fork
|
||||
pc_delete "resolv-file" "$config_file" # no WAN DNS settings
|
||||
@@ -65,6 +73,10 @@ type Upstream struct {
|
||||
}
|
||||
|
||||
func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
|
||||
return ConfTmplWitchCacheDisabled(tmplText, cfg, true)
|
||||
}
|
||||
|
||||
func ConfTmplWitchCacheDisabled(tmplText string, cfg *ctrld.Config, cacheDisabled bool) (string, error) {
|
||||
listener := cfg.FirstListener()
|
||||
if listener == nil {
|
||||
return "", errors.New("missing listener")
|
||||
@@ -74,24 +86,26 @@ func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
|
||||
ip = "127.0.0.1"
|
||||
}
|
||||
upstreams := []Upstream{{IP: ip, Port: listener.Port}}
|
||||
return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo())
|
||||
return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo(), cacheDisabled)
|
||||
}
|
||||
|
||||
func FirewallaConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
|
||||
if lc := cfg.FirstListener(); lc != nil && (lc.IP == "0.0.0.0" || lc.IP == "") {
|
||||
return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo())
|
||||
return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo(), true)
|
||||
}
|
||||
return ConfTmpl(tmplText, cfg)
|
||||
}
|
||||
|
||||
func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo bool) (string, error) {
|
||||
func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo, cacheDisabled bool) (string, error) {
|
||||
tmpl := template.Must(template.New("").Parse(tmplText))
|
||||
var to = &struct {
|
||||
SendClientInfo bool
|
||||
Upstreams []Upstream
|
||||
CacheDisabled bool
|
||||
}{
|
||||
SendClientInfo: sendClientInfo,
|
||||
Upstreams: upstreams,
|
||||
CacheDisabled: cacheDisabled,
|
||||
}
|
||||
var sb strings.Builder
|
||||
if err := tmpl.Execute(&sb, to); err != nil {
|
||||
@@ -117,9 +131,28 @@ func firewallaUpstreams(port int) []Upstream {
|
||||
return upstreams
|
||||
}
|
||||
|
||||
// firewallaDnsmasqConfFiles returns dnsmasq config files of all firewalla interfaces.
|
||||
func firewallaDnsmasqConfFiles() ([]string, error) {
|
||||
return filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf")
|
||||
}
|
||||
|
||||
// firewallUpdateConf updates all firewall config files using given function.
|
||||
func firewallUpdateConf(update func(conf string) error) error {
|
||||
confFiles, err := firewallaDnsmasqConfFiles()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, conf := range confFiles {
|
||||
if err := update(conf); err != nil {
|
||||
return fmt.Errorf("%s: %w", conf, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FirewallaSelfInterfaces returns list of interfaces that will be configured with default dnsmasq setup on Firewalla.
|
||||
func FirewallaSelfInterfaces() []*net.Interface {
|
||||
matches, err := filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf")
|
||||
matches, err := firewallaDnsmasqConfFiles()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -133,3 +166,32 @@ func FirewallaSelfInterfaces() []*net.Interface {
|
||||
}
|
||||
return ifaces
|
||||
}
|
||||
|
||||
// FirewallaDisableCache comments out "cache-size" line in all firewalla dnsmasq config files.
|
||||
func FirewallaDisableCache() error {
|
||||
return firewallUpdateConf(DisableCache)
|
||||
}
|
||||
|
||||
// FirewallaEnableCache un-comments out "cache-size" line in all firewalla dnsmasq config files.
|
||||
func FirewallaEnableCache() error {
|
||||
return firewallUpdateConf(EnableCache)
|
||||
}
|
||||
|
||||
// DisableCache comments out "cache-size" line in dnsmasq config file.
|
||||
func DisableCache(conf string) error {
|
||||
return replaceFileContent(conf, "\ncache-size=", "\n#cache-size=")
|
||||
}
|
||||
|
||||
// EnableCache un-comments "cache-size" line in dnsmasq config file.
|
||||
func EnableCache(conf string) error {
|
||||
return replaceFileContent(conf, "\n#cache-size=", "\ncache-size=")
|
||||
}
|
||||
|
||||
func replaceFileContent(filename, old, new string) error {
|
||||
content, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
content = bytes.ReplaceAll(content, []byte(old), []byte(new))
|
||||
return os.WriteFile(filename, content, 0644)
|
||||
}
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -95,7 +95,7 @@ func (e *EdgeOS) setupUSG() error {
|
||||
return fmt.Errorf("setupUSG: backup current config: %w", err)
|
||||
}
|
||||
|
||||
// Removing all configured upstreams.
|
||||
// Removing all configured upstreams and cache config.
|
||||
var sb strings.Builder
|
||||
scanner := bufio.NewScanner(bytes.NewReader(buf))
|
||||
for scanner.Scan() {
|
||||
@@ -109,7 +109,7 @@ func (e *EdgeOS) setupUSG() error {
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg)
|
||||
data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func (e *EdgeOS) setupUSG() error {
|
||||
}
|
||||
|
||||
func (e *EdgeOS) setupUDM() error {
|
||||
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg)
|
||||
data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -65,6 +65,11 @@ func (f *Firewalla) Setup() error {
|
||||
return fmt.Errorf("writing ctrld config: %w", err)
|
||||
}
|
||||
|
||||
// Disable dnsmasq cache.
|
||||
if err := dnsmasq.FirewallaDisableCache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return fmt.Errorf("restartDNSMasq: %w", err)
|
||||
@@ -82,6 +87,11 @@ func (f *Firewalla) Cleanup() error {
|
||||
return fmt.Errorf("removing ctrld config: %w", err)
|
||||
}
|
||||
|
||||
// Enable dnsmasq cache.
|
||||
if err := dnsmasq.FirewallaEnableCache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return fmt.Errorf("restartDNSMasq: %w", err)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -44,8 +45,24 @@ func (m *Merlin) Uninstall(_ *service.Config) error {
|
||||
}
|
||||
|
||||
func (m *Merlin) PreRun() error {
|
||||
// Wait NTP ready.
|
||||
_ = m.Cleanup()
|
||||
return ntp.WaitNvram()
|
||||
if err := ntp.WaitNvram(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Wait until directories mounted.
|
||||
for _, dir := range []string{"/tmp", "/proc"} {
|
||||
waitDirExists(dir)
|
||||
}
|
||||
// Wait dnsmasq started.
|
||||
for {
|
||||
out, _ := exec.Command("pidof", "dnsmasq").CombinedOutput()
|
||||
if len(bytes.TrimSpace(out)) > 0 {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Merlin) Setup() error {
|
||||
@@ -56,9 +73,6 @@ func (m *Merlin) Setup() error {
|
||||
if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" {
|
||||
return nil
|
||||
}
|
||||
if _, err := nvram.Run("set", nvram.CtrldSetupKey+"=1"); err != nil {
|
||||
return err
|
||||
}
|
||||
buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath)
|
||||
// Already setup.
|
||||
if bytes.Contains(buf, []byte(dnsmasq.MerlinPostConfMarker)) {
|
||||
@@ -140,3 +154,12 @@ func merlinParsePostConf(buf []byte) []byte {
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func waitDirExists(dir string) {
|
||||
for {
|
||||
if _, err := os.Stat(dir); !os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,11 +8,10 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,10 +19,9 @@ const (
|
||||
openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
|
||||
)
|
||||
|
||||
var errUCIEntryNotFound = errors.New("uci: Entry not found")
|
||||
|
||||
type Openwrt struct {
|
||||
cfg *ctrld.Config
|
||||
cfg *ctrld.Config
|
||||
dnsmasqCacheSize string
|
||||
}
|
||||
|
||||
// New returns a router.Router for configuring/setup/run ctrld on Openwrt routers.
|
||||
@@ -52,6 +50,19 @@ func (o *Openwrt) Setup() error {
|
||||
if o.cfg.FirstListener().IsDirectDnsListener() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save current dnsmasq config cache size if present.
|
||||
if cs, err := uci("get", "dhcp.@dnsmasq[0].cachesize"); err == nil {
|
||||
o.dnsmasqCacheSize = cs
|
||||
if _, err := uci("delete", "dhcp.@dnsmasq[0].cachesize"); err != nil {
|
||||
return err
|
||||
}
|
||||
// Commit.
|
||||
if _, err := uci("commit", "dhcp"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, o.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -59,10 +70,6 @@ func (o *Openwrt) Setup() error {
|
||||
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Commit.
|
||||
if _, err := uci("commit"); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
@@ -78,6 +85,18 @@ func (o *Openwrt) Cleanup() error {
|
||||
if err := os.Remove(openwrtDNSMasqConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore original value if present.
|
||||
if o.dnsmasqCacheSize != "" {
|
||||
if _, err := uci("set", fmt.Sprintf("dhcp.@dnsmasq[0].cachesize=%s", o.dnsmasqCacheSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Commit.
|
||||
if _, err := uci("commit", "dhcp"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
@@ -92,6 +111,8 @@ func restartDNSMasq() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUCIEntryNotFound = errors.New("uci: Entry not found")
|
||||
|
||||
func uci(args ...string) (string, error) {
|
||||
cmd := exec.Command("uci", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
@@ -5,16 +5,17 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/edgeos"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "ubios"
|
||||
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
|
||||
Name = "ubios"
|
||||
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
|
||||
ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf"
|
||||
)
|
||||
|
||||
type Ubios struct {
|
||||
@@ -57,6 +58,10 @@ func (u *Ubios) Setup() error {
|
||||
if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Disable dnsmasq cache.
|
||||
if err := dnsmasq.DisableCache(ubiosDNSMasqDnsConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
@@ -72,6 +77,10 @@ func (u *Ubios) Cleanup() error {
|
||||
if err := os.Remove(ubiosDNSMasqConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Enable dnsmasq cache.
|
||||
if err := dnsmasq.EnableCache(ubiosDNSMasqDnsConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
|
||||
35
net.go
35
net.go
@@ -2,13 +2,10 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
@@ -17,30 +14,36 @@ var (
|
||||
ipv6Available atomic.Bool
|
||||
)
|
||||
|
||||
const ipv6ProbingInterval = 10 * time.Second
|
||||
|
||||
func hasIPv6() bool {
|
||||
hasIPv6Once.Do(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
val := ctrldnet.IPv6Available(ctx)
|
||||
ipv6Available.Store(val)
|
||||
go probingIPv6(val)
|
||||
go probingIPv6(context.TODO(), val)
|
||||
})
|
||||
return ipv6Available.Load()
|
||||
}
|
||||
|
||||
// TODO(cuonglm): doing poll check natively for supported platforms.
|
||||
func probingIPv6(old bool) {
|
||||
b := backoff.NewBackoff("probingIPv6", func(format string, args ...any) {}, 30*time.Second)
|
||||
bCtx := context.Background()
|
||||
func probingIPv6(ctx context.Context, old bool) {
|
||||
ticker := time.NewTicker(ipv6ProbingInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
cur := ctrldnet.IPv6Available(ctx)
|
||||
if ipv6Available.CompareAndSwap(old, cur) {
|
||||
old = cur
|
||||
}
|
||||
}()
|
||||
b.BackOff(bCtx, errors.New("no change"))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
cur := ctrldnet.IPv6Available(ctx)
|
||||
if ipv6Available.CompareAndSwap(old, cur) {
|
||||
old = cur
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
51
resolver.go
51
resolver.go
@@ -5,10 +5,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"tailscale.com/net/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,6 +26,8 @@ const (
|
||||
ResolverTypeOS = "os"
|
||||
// ResolverTypeLegacy specifies legacy resolver.
|
||||
ResolverTypeLegacy = "legacy"
|
||||
// ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only.
|
||||
ResolverTypePrivate = "private"
|
||||
)
|
||||
|
||||
var bootstrapDNS = "76.76.2.0"
|
||||
@@ -61,6 +65,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
return or, nil
|
||||
case ResolverTypeLegacy:
|
||||
return &legacyResolver{uc: uc}, nil
|
||||
case ResolverTypePrivate:
|
||||
return NewPrivateResolver(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
|
||||
}
|
||||
@@ -74,8 +80,9 @@ type osResolverResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from
|
||||
// available nameservers with a roundrobin algorithm.
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
numServers := len(o.nameservers)
|
||||
if numServers == 0 {
|
||||
@@ -240,12 +247,16 @@ func NewBootstrapResolver(servers ...string) Resolver {
|
||||
}
|
||||
|
||||
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers,
|
||||
// excluding nameservers from /etc/resolv.conf file.
|
||||
// excluding:
|
||||
//
|
||||
// - Nameservers from /etc/resolv.conf file.
|
||||
// - Nameservers which is local RFC1918 addresses.
|
||||
//
|
||||
// This is useful for doing PTR lookup in LAN network.
|
||||
func NewPrivateResolver() Resolver {
|
||||
nss := nameservers()
|
||||
resolveConfNss := nameserversFromResolvconf()
|
||||
localRfc1918Addrs := Rfc1918Addresses()
|
||||
n := 0
|
||||
for _, ns := range nss {
|
||||
host, _, _ := net.SplitHostPort(ns)
|
||||
@@ -258,6 +269,10 @@ func NewPrivateResolver() Resolver {
|
||||
if sliceContains(resolveConfNss, host) {
|
||||
continue
|
||||
}
|
||||
// Ignoring local RFC 1918 addresses.
|
||||
if sliceContains(localRfc1918Addrs, host) {
|
||||
continue
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.IsPrivate() && !ip.IsLoopback() {
|
||||
nss[n] = ns
|
||||
@@ -265,11 +280,35 @@ func NewPrivateResolver() Resolver {
|
||||
}
|
||||
}
|
||||
nss = nss[:n]
|
||||
if len(nss) == 0 {
|
||||
return NewResolverWithNameserver(nss)
|
||||
}
|
||||
|
||||
// NewResolverWithNameserver returns an OS resolver which uses the given nameservers
|
||||
// for resolving DNS queries. If nameservers is empty, a dummy resolver will be returned.
|
||||
//
|
||||
// Each nameserver must be form "host:port". It's the caller responsibility to ensure all
|
||||
// nameservers are well formatted by using net.JoinHostPort function.
|
||||
func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
if len(nameservers) == 0 {
|
||||
return &dummyResolver{}
|
||||
}
|
||||
resolver := &osResolver{nameservers: nss}
|
||||
return resolver
|
||||
return &osResolver{nameservers: nameservers}
|
||||
}
|
||||
|
||||
// Rfc1918Addresses returns the list of local interfaces private IP addresses
|
||||
func Rfc1918Addresses() []string {
|
||||
var res []string
|
||||
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
addrs, _ := i.Addrs()
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok || !ipNet.IP.IsPrivate() {
|
||||
continue
|
||||
}
|
||||
res = append(res, ipNet.IP.String())
|
||||
}
|
||||
})
|
||||
return res
|
||||
}
|
||||
|
||||
func newDialer(dnsAddress string) *net.Dialer {
|
||||
|
||||
@@ -82,4 +82,8 @@ rules = [
|
||||
{"*.ru" = ["upstream.1"]},
|
||||
{"*.local.host" = ["upstream.2", "upstream.0"]},
|
||||
]
|
||||
macs = [
|
||||
{"14:45:A0:67:83:0A" = ["upstream.2"]},
|
||||
{"14:54:4a:8e:08:2d" = ["upstream.2"]},
|
||||
]
|
||||
`
|
||||
|
||||
Reference in New Issue
Block a user