mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Merge pull request #93 from Control-D-Inc/release-branch-v1.3.1
Release branch v1.3.1
This commit is contained in:
@@ -76,7 +76,7 @@ $ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
or
|
||||
|
||||
```
|
||||
$ docker build -t controldns/ctrld .
|
||||
$ docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
$ docker run -d --name=ctrld -p 53:53/tcp -p 53:53/udp controldns/ctrld --cd=RESOLVER_ID_GOES_HERE -vv
|
||||
```
|
||||
|
||||
@@ -188,8 +188,8 @@ See [Configuration Docs](docs/config.md).
|
||||
[listener]
|
||||
|
||||
[listener.0]
|
||||
ip = "127.0.0.1"
|
||||
port = 53
|
||||
ip = ""
|
||||
port = 0
|
||||
restricted = false
|
||||
|
||||
[network]
|
||||
@@ -220,6 +220,8 @@ See [Configuration Docs](docs/config.md).
|
||||
|
||||
```
|
||||
|
||||
`ctrld` will pick a working config for `listener.0` then writing the default config to disk for the first run.
|
||||
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ type ClientInfo struct {
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Self bool
|
||||
}
|
||||
|
||||
// LeaseFileFormat specifies the format of DHCP lease file.
|
||||
|
||||
482
cmd/cli/cli.go
482
cmd/cli/cli.go
@@ -124,185 +124,7 @@ func initCLI() {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
waitCh := make(chan struct{})
|
||||
stopCh := make(chan struct{})
|
||||
p := &prog{
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
cfg: &cfg,
|
||||
}
|
||||
if homedir == "" {
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
homedir = dir
|
||||
}
|
||||
}
|
||||
sockPath := filepath.Join(homedir, ctrldLogUnixSock)
|
||||
if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil {
|
||||
if conn, err := net.Dial(addr.Network(), addr.String()); err == nil {
|
||||
lc := &logConn{conn: conn}
|
||||
consoleWriter.Out = io.MultiWriter(os.Stdout, lc)
|
||||
p.logConn = lc
|
||||
}
|
||||
}
|
||||
|
||||
if daemon && runtime.GOOS == "windows" {
|
||||
mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
|
||||
}
|
||||
|
||||
if !daemon {
|
||||
// We need to call s.Run() as soon as possible to response to the OS manager, so it
|
||||
// can see ctrld is running and don't mark ctrld as failed service.
|
||||
go func() {
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed create new service")
|
||||
}
|
||||
if err := s.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to start service")
|
||||
}
|
||||
}()
|
||||
}
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
readBase64Config(configBase64)
|
||||
processNoConfigFlags(noConfigStart)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
processLogAndCacheFlags()
|
||||
|
||||
// Log config do not have thing to validate, so it's safe to init log here,
|
||||
// so it's able to log information in processCDFlags.
|
||||
initLogging()
|
||||
|
||||
mainLog.Load().Info().Msgf("starting ctrld %s", curVersion())
|
||||
mainLog.Load().Info().Msgf("os: %s", osVersion())
|
||||
|
||||
// Wait for network up.
|
||||
if !ctrldnet.Up() {
|
||||
mainLog.Load().Fatal().Msg("network is not up yet")
|
||||
}
|
||||
|
||||
p.router = router.New(&cfg, cdUID != "")
|
||||
cs, err := newControlServer(filepath.Join(homedir, ctrldControlUnixSock))
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create control server")
|
||||
}
|
||||
p.cs = cs
|
||||
|
||||
// Processing --cd flag require connecting to ControlD API, which needs valid
|
||||
// time for validating server certificate. Some routers need NTP synchronization
|
||||
// to set the current time, so this check must happen before processCDFlags.
|
||||
if err := p.router.PreRun(); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check")
|
||||
}
|
||||
|
||||
oldLogPath := cfg.Service.LogPath
|
||||
if uid := cdUIDFromProvToken(); uid != "" {
|
||||
cdUID = uid
|
||||
}
|
||||
if cdUID != "" {
|
||||
processCDFlags()
|
||||
}
|
||||
|
||||
updated := updateListenerConfig()
|
||||
|
||||
if cdUID != "" {
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := writeConfigFile(); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to write config file")
|
||||
} else {
|
||||
mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile)
|
||||
}
|
||||
}
|
||||
|
||||
if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath {
|
||||
// After processCDFlags, log config may change, so reset mainLog and re-init logging.
|
||||
l := zerolog.New(io.Discard)
|
||||
mainLog.Store(&l)
|
||||
|
||||
// Copy logs written so far to new log file if possible.
|
||||
if buf, err := os.ReadFile(oldLogPath); err == nil {
|
||||
if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not copy old log file")
|
||||
}
|
||||
}
|
||||
initLoggingWithBackup(false)
|
||||
}
|
||||
|
||||
validateConfig(&cfg)
|
||||
initCache()
|
||||
|
||||
if daemon {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to find the binary")
|
||||
os.Exit(1)
|
||||
}
|
||||
curDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get current working directory")
|
||||
os.Exit(1)
|
||||
}
|
||||
// If running as daemon, re-run the command in background, with daemon off.
|
||||
cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...)
|
||||
cmd.Dir = curDir
|
||||
if err := cmd.Start(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to start process as daemon")
|
||||
os.Exit(1)
|
||||
}
|
||||
mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
p.onStarted = append(p.onStarted, func() {
|
||||
for _, lc := range p.cfg.Listener {
|
||||
if shouldAllocateLoopbackIP(lc.IP) {
|
||||
if err := allocateIP(lc.IP); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
for _, lc := range p.cfg.Listener {
|
||||
if shouldAllocateLoopbackIP(lc.IP) {
|
||||
if err := deAllocateIP(lc.IP); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
if platform := router.Name(); platform != "" {
|
||||
if cp := router.CertPool(); cp != nil {
|
||||
rootCertPool = cp
|
||||
}
|
||||
p.onStarted = append(p.onStarted, func() {
|
||||
mainLog.Load().Debug().Msg("router setup on start")
|
||||
if err := p.router.Setup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not configure router")
|
||||
}
|
||||
})
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
mainLog.Load().Debug().Msg("router cleanup on stop")
|
||||
if err := p.router.Cleanup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
p.resetDNS()
|
||||
})
|
||||
}
|
||||
|
||||
close(waitCh)
|
||||
<-stopCh
|
||||
for _, f := range p.onStopped {
|
||||
f()
|
||||
}
|
||||
RunCobraCommand(cmd)
|
||||
},
|
||||
}
|
||||
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
|
||||
@@ -314,8 +136,8 @@ func initCLI() {
|
||||
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
runCmd.Flags().StringVarP(&cdOrg, "cd-org", "", "", "Control D provision token")
|
||||
runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
|
||||
runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
|
||||
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = runCmd.Flags().MarkHidden("dev")
|
||||
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
|
||||
@@ -334,6 +156,8 @@ func initCLI() {
|
||||
Short: "Install and start the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
checkStrFlagEmpty(cmd, cdUidFlagName)
|
||||
checkStrFlagEmpty(cmd, cdOrgFlagName)
|
||||
sc := &service.Config{}
|
||||
*sc = *svcConfig
|
||||
osArgs := os.Args[2:]
|
||||
@@ -466,8 +290,8 @@ func initCLI() {
|
||||
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
startCmd.Flags().StringVarP(&cdOrg, "cd-org", "", "", "Control D provision token")
|
||||
startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
|
||||
startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
|
||||
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = startCmd.Flags().MarkHidden("dev")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
@@ -804,6 +628,9 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
map2Slice := func(m map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
if k == "" { // skip empty source from output.
|
||||
continue
|
||||
}
|
||||
s = append(s, k)
|
||||
}
|
||||
sort.Strings(s)
|
||||
@@ -838,6 +665,222 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
rootCmd.AddCommand(clientsCmd)
|
||||
}
|
||||
|
||||
// isMobile reports whether the current OS is a mobile platform.
|
||||
func isMobile() bool {
|
||||
return runtime.GOOS == "android" || runtime.GOOS == "ios"
|
||||
}
|
||||
|
||||
// RunCobraCommand runs ctrld cli.
|
||||
func RunCobraCommand(cmd *cobra.Command) {
|
||||
noConfigStart = isNoConfigStart(cmd)
|
||||
checkStrFlagEmpty(cmd, cdUidFlagName)
|
||||
checkStrFlagEmpty(cmd, cdOrgFlagName)
|
||||
run(nil, make(chan struct{}))
|
||||
}
|
||||
|
||||
// RunMobile runs the ctrld cli on mobile platforms.
|
||||
func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struct{}) {
|
||||
if appConfig == nil {
|
||||
panic("appConfig is nil")
|
||||
}
|
||||
initConsoleLogging()
|
||||
noConfigStart = false
|
||||
homedir = appConfig.HomeDir
|
||||
verbose = appConfig.Verbose
|
||||
cdUID = appConfig.CdUID
|
||||
logPath = appConfig.LogPath
|
||||
run(appCallback, stopCh)
|
||||
}
|
||||
|
||||
// run runs ctrld cli with given app callback and stop channel.
|
||||
func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
if stopCh == nil {
|
||||
mainLog.Load().Fatal().Msg("stopCh is nil")
|
||||
}
|
||||
waitCh := make(chan struct{})
|
||||
p := &prog{
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
cfg: &cfg,
|
||||
appCallback: appCallback,
|
||||
}
|
||||
if homedir == "" {
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
homedir = dir
|
||||
}
|
||||
}
|
||||
sockPath := filepath.Join(homedir, ctrldLogUnixSock)
|
||||
if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil {
|
||||
if conn, err := net.Dial(addr.Network(), addr.String()); err == nil {
|
||||
lc := &logConn{conn: conn}
|
||||
consoleWriter.Out = io.MultiWriter(os.Stdout, lc)
|
||||
p.logConn = lc
|
||||
}
|
||||
}
|
||||
|
||||
if daemon && runtime.GOOS == "windows" {
|
||||
mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
|
||||
}
|
||||
|
||||
if !daemon {
|
||||
// We need to call s.Run() as soon as possible to response to the OS manager, so it
|
||||
// can see ctrld is running and don't mark ctrld as failed service.
|
||||
go func() {
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed create new service")
|
||||
}
|
||||
if err := s.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to start service")
|
||||
}
|
||||
}()
|
||||
}
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
readBase64Config(configBase64)
|
||||
processNoConfigFlags(noConfigStart)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
processLogAndCacheFlags()
|
||||
|
||||
// Log config do not have thing to validate, so it's safe to init log here,
|
||||
// so it's able to log information in processCDFlags.
|
||||
initLogging()
|
||||
|
||||
mainLog.Load().Info().Msgf("starting ctrld %s", curVersion())
|
||||
mainLog.Load().Info().Msgf("os: %s", osVersion())
|
||||
|
||||
// Wait for network up.
|
||||
if !ctrldnet.Up() {
|
||||
mainLog.Load().Fatal().Msg("network is not up yet")
|
||||
}
|
||||
|
||||
p.router = router.New(&cfg, cdUID != "")
|
||||
cs, err := newControlServer(filepath.Join(homedir, ctrldControlUnixSock))
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create control server")
|
||||
}
|
||||
p.cs = cs
|
||||
|
||||
// Processing --cd flag require connecting to ControlD API, which needs valid
|
||||
// time for validating server certificate. Some routers need NTP synchronization
|
||||
// to set the current time, so this check must happen before processCDFlags.
|
||||
if err := p.router.PreRun(); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check")
|
||||
}
|
||||
|
||||
oldLogPath := cfg.Service.LogPath
|
||||
if uid := cdUIDFromProvToken(); uid != "" {
|
||||
cdUID = uid
|
||||
}
|
||||
if cdUID != "" {
|
||||
err := processCDFlags()
|
||||
if err != nil {
|
||||
appCallback.Exit(err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
updated := updateListenerConfig()
|
||||
|
||||
if cdUID != "" {
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := writeConfigFile(); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to write config file")
|
||||
} else {
|
||||
mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile)
|
||||
}
|
||||
}
|
||||
|
||||
if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath {
|
||||
// After processCDFlags, log config may change, so reset mainLog and re-init logging.
|
||||
l := zerolog.New(io.Discard)
|
||||
mainLog.Store(&l)
|
||||
|
||||
// Copy logs written so far to new log file if possible.
|
||||
if buf, err := os.ReadFile(oldLogPath); err == nil {
|
||||
if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not copy old log file")
|
||||
}
|
||||
}
|
||||
initLoggingWithBackup(false)
|
||||
}
|
||||
|
||||
validateConfig(&cfg)
|
||||
initCache()
|
||||
|
||||
if daemon {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to find the binary")
|
||||
os.Exit(1)
|
||||
}
|
||||
curDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get current working directory")
|
||||
os.Exit(1)
|
||||
}
|
||||
// If running as daemon, re-run the command in background, with daemon off.
|
||||
cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...)
|
||||
cmd.Dir = curDir
|
||||
if err := cmd.Start(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to start process as daemon")
|
||||
os.Exit(1)
|
||||
}
|
||||
mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
p.onStarted = append(p.onStarted, func() {
|
||||
for _, lc := range p.cfg.Listener {
|
||||
if shouldAllocateLoopbackIP(lc.IP) {
|
||||
if err := allocateIP(lc.IP); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
for _, lc := range p.cfg.Listener {
|
||||
if shouldAllocateLoopbackIP(lc.IP) {
|
||||
if err := deAllocateIP(lc.IP); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
if platform := router.Name(); platform != "" {
|
||||
if cp := router.CertPool(); cp != nil {
|
||||
rootCertPool = cp
|
||||
}
|
||||
p.onStarted = append(p.onStarted, func() {
|
||||
mainLog.Load().Debug().Msg("router setup on start")
|
||||
if err := p.router.Setup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not configure router")
|
||||
}
|
||||
})
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
mainLog.Load().Debug().Msg("router cleanup on stop")
|
||||
if err := p.router.Cleanup(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
p.resetDNS()
|
||||
})
|
||||
}
|
||||
|
||||
close(waitCh)
|
||||
<-stopCh
|
||||
for _, f := range p.onStopped {
|
||||
f()
|
||||
}
|
||||
}
|
||||
|
||||
func writeConfigFile() error {
|
||||
if cfu := v.ConfigFileUsed(); cfu != "" {
|
||||
defaultConfigFile = cfu
|
||||
@@ -882,6 +925,7 @@ func readConfigFile(writeDefaultConfig bool) bool {
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
_ = updateListenerConfig()
|
||||
if err := writeConfigFile(); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
|
||||
} else {
|
||||
@@ -971,7 +1015,7 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
v.Set("upstream", upstream)
|
||||
}
|
||||
|
||||
func processCDFlags() {
|
||||
func processCDFlags() error {
|
||||
logger := mainLog.Load().With().Str("mode", "cd").Logger()
|
||||
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
|
||||
bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second)
|
||||
@@ -991,12 +1035,12 @@ func processCDFlags() {
|
||||
s, err := newService(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed to create new service")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
if netIface, _ := netInterface(iface); netIface != nil {
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
@@ -1010,11 +1054,16 @@ func processCDFlags() {
|
||||
if doTasks(tasks) {
|
||||
logger.Info().Msg("uninstalled service")
|
||||
}
|
||||
logger.Fatal().Err(uer).Msg("failed to fetch resolver config")
|
||||
event := logger.Fatal()
|
||||
if isMobile() {
|
||||
event = logger.Warn()
|
||||
}
|
||||
event.Err(uer).Msg("failed to fetch resolver config")
|
||||
return uer
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info().Msg("generating ctrld config from Control-D configuration")
|
||||
@@ -1058,6 +1107,7 @@ func processCDFlags() {
|
||||
"0": {IP: "", Port: 0},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processListenFlag() {
|
||||
@@ -1266,7 +1316,17 @@ func userHomeDir() (string, error) {
|
||||
}
|
||||
// viper will expand for us.
|
||||
if runtime.GOOS == "windows" {
|
||||
return os.UserHomeDir()
|
||||
// If we're on windows, use the install path for this.
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Dir(exePath), nil
|
||||
}
|
||||
// Mobile platform should provide a rw dir path for this.
|
||||
if isMobile() {
|
||||
return homedir, nil
|
||||
}
|
||||
dir = "/etc/controld"
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
@@ -1412,6 +1472,14 @@ type listenerConfigCheck struct {
|
||||
Port bool
|
||||
}
|
||||
|
||||
// mobileListenerPort returns hardcoded port for mobile platforms.
|
||||
func mobileListenerPort() int {
|
||||
if runtime.GOOS == "ios" {
|
||||
return 53
|
||||
}
|
||||
return 5354
|
||||
}
|
||||
|
||||
// updateListenerConfig updates the config for listeners if not defined,
|
||||
// or defined but invalid to be used, e.g: using loopback address other
|
||||
// than 127.0.0.1 with systemd-resolved.
|
||||
@@ -1435,7 +1503,25 @@ func updateListenerConfig() (updated bool) {
|
||||
}
|
||||
updated = updated || lcc[n].IP || lcc[n].Port
|
||||
}
|
||||
|
||||
if isMobile() {
|
||||
// On Mobile, only use first listener, ignore others.
|
||||
firstLn := cfg.FirstListener()
|
||||
for k := range cfg.Listener {
|
||||
if cfg.Listener[k] != firstLn {
|
||||
delete(cfg.Listener, k)
|
||||
}
|
||||
}
|
||||
// In cd mode, always use 127.0.0.1:5354.
|
||||
if cdMode {
|
||||
firstLn.IP = "127.0.0.1" // Mobile platforms allows running listener only on loop back address.
|
||||
firstLn.Port = mobileListenerPort()
|
||||
// TODO: use clear(lcc) once upgrading to go 1.21
|
||||
for k := range lcc {
|
||||
delete(lcc, k)
|
||||
}
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
var closers []io.Closer
|
||||
defer func() {
|
||||
for _, closer := range closers {
|
||||
@@ -1656,12 +1742,12 @@ func removeProvTokenFromArgs(sc *service.Config) {
|
||||
continue
|
||||
}
|
||||
// For "--cd-org XXX", skip it and mark next arg skipped.
|
||||
if x == "--cd-org" {
|
||||
if x == cdOrgFlagName {
|
||||
skip = true
|
||||
continue
|
||||
}
|
||||
// For "--cd-org=XXX", just skip it.
|
||||
if strings.HasPrefix(x, "--cd-org=") {
|
||||
if strings.HasPrefix(x, cdOrgFlagName+"=") {
|
||||
continue
|
||||
}
|
||||
a = append(a, x)
|
||||
@@ -1700,3 +1786,15 @@ func newSocketControlClient(s service.Service, dir string) *controlClient {
|
||||
|
||||
return cc
|
||||
}
|
||||
|
||||
// checkStrFlagEmpty validates if a string flag was set to an empty string.
|
||||
// If yes, emitting a fatal error message.
|
||||
func checkStrFlagEmpty(cmd *cobra.Command, flagName string) {
|
||||
fl := cmd.Flags().Lookup(flagName)
|
||||
if !fl.Changed || fl.Value.Type() != "string" {
|
||||
return
|
||||
}
|
||||
if fl.Value.String() == "" {
|
||||
mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,10 +5,8 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -16,10 +14,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"go4.org/mem"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/util/lineread"
|
||||
"tailscale.com/net/netaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
@@ -54,12 +51,12 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
p.sema.acquire()
|
||||
defer p.sema.release()
|
||||
go p.detectLoop(m)
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
reqId := requestID()
|
||||
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
||||
mac := macFromMsg(m)
|
||||
ci := p.getClientInfo(remoteIP, mac)
|
||||
ci := p.getClientInfo(remoteIP, m)
|
||||
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci)
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String())
|
||||
t := time.Now()
|
||||
@@ -121,7 +118,8 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
})
|
||||
}
|
||||
g.Go(func() error {
|
||||
s, errCh := runDNSServer(dnsListenAddress(listenerConfig), proto, handler)
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
s, errCh := runDNSServer(addr, proto, handler)
|
||||
defer s.Shutdown()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
@@ -149,7 +147,7 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
// processed later, because policy logging want to know whether a network rule
|
||||
// is disregarded in favor of the domain level rule.
|
||||
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
|
||||
upstreams := []string{"upstream." + defaultUpstreamNum}
|
||||
upstreams := []string{upstreamPrefix + defaultUpstreamNum}
|
||||
matchedPolicy := "no policy"
|
||||
matchedNetwork := "no network"
|
||||
matchedRule := "no rule"
|
||||
@@ -233,7 +231,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{"upstream.os"}
|
||||
upstreams = []string{upstreamOS}
|
||||
}
|
||||
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
|
||||
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
|
||||
@@ -277,6 +275,12 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
answer, err := resolve1(n, upstreamConfig, msg)
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
|
||||
if errNetworkError(err) {
|
||||
p.um.increaseFailureCount(upstreams[n])
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
go p.um.checkUpstream(upstreams[n], upstreamConfig)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return answer
|
||||
@@ -285,6 +289,14 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
if p.isLoop(upstreamConfig) {
|
||||
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint)
|
||||
continue
|
||||
}
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
@@ -316,7 +328,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
}
|
||||
return answer
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all upstreams failed")
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(msg, dns.RcodeServerFailure)
|
||||
return answer
|
||||
@@ -325,7 +337,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
|
||||
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
|
||||
for _, upstream := range upstreams {
|
||||
upstreamNum := strings.TrimPrefix(upstream, "upstream.")
|
||||
upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix)
|
||||
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
|
||||
}
|
||||
return upstreamConfigs
|
||||
@@ -422,29 +434,24 @@ func needLocalIPv6Listener() bool {
|
||||
return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
|
||||
}
|
||||
|
||||
func dnsListenAddress(lc *ctrld.ListenerConfig) string {
|
||||
// If we are inside container and the listener loopback address, change
|
||||
// the address to something like 0.0.0.0:53, so user can expose the port to outside.
|
||||
if inContainer() {
|
||||
if ip := net.ParseIP(lc.IP); ip != nil && ip.IsLoopback() {
|
||||
return net.JoinHostPort("0.0.0.0", strconv.Itoa(lc.Port))
|
||||
}
|
||||
}
|
||||
return net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
|
||||
}
|
||||
|
||||
func macFromMsg(msg *dns.Msg) string {
|
||||
// ipAndMacFromMsg extracts IP and MAC information included in a DNS message, if any.
|
||||
func ipAndMacFromMsg(msg *dns.Msg) (string, string) {
|
||||
ip, mac := "", ""
|
||||
if opt := msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
switch e := s.(type) {
|
||||
case *dns.EDNS0_LOCAL:
|
||||
if e.Code == EDNS0_OPTION_MAC {
|
||||
return net.HardwareAddr(e.Data).String()
|
||||
mac = net.HardwareAddr(e.Data).String()
|
||||
}
|
||||
case *dns.EDNS0_SUBNET:
|
||||
if len(e.Address) > 0 && !e.Address.IsLoopback() {
|
||||
ip = e.Address.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
return ip, mac
|
||||
}
|
||||
|
||||
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
|
||||
@@ -498,55 +505,73 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha
|
||||
return s, errCh
|
||||
}
|
||||
|
||||
// inContainer reports whether we're running in a container.
|
||||
//
|
||||
// Copied from https://github.com/tailscale/tailscale/blob/v1.42.0/hostinfo/hostinfo.go#L260
|
||||
// with modification for ctrld usage.
|
||||
func inContainer() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
|
||||
ci := &ctrld.ClientInfo{}
|
||||
if p.appCallback != nil {
|
||||
ci.IP = p.appCallback.LanIp()
|
||||
ci.Mac = p.appCallback.MacAddress()
|
||||
ci.Hostname = p.appCallback.HostName()
|
||||
ci.Self = true
|
||||
return ci
|
||||
}
|
||||
ci.IP, ci.Mac = ipAndMacFromMsg(msg)
|
||||
switch {
|
||||
case ci.IP != "" && ci.Mac != "":
|
||||
// Nothing to do.
|
||||
case ci.IP == "" && ci.Mac != "":
|
||||
// Have MAC, no IP.
|
||||
ci.IP = p.ciTable.LookupIP(ci.Mac)
|
||||
case ci.IP == "" && ci.Mac == "":
|
||||
// Have nothing, use remote IP then lookup MAC.
|
||||
ci.IP = remoteIP
|
||||
fallthrough
|
||||
case ci.IP != "" && ci.Mac == "":
|
||||
// Have IP, no MAC.
|
||||
ci.Mac = p.ciTable.LookupMac(ci.IP)
|
||||
}
|
||||
|
||||
var ret bool
|
||||
if _, err := os.Stat("/.dockerenv"); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := os.Stat("/run/.containerenv"); err == nil {
|
||||
// See https://github.com/cri-o/cri-o/issues/5461
|
||||
return true
|
||||
}
|
||||
lineread.File("/proc/1/cgroup", func(line []byte) error {
|
||||
if mem.Contains(mem.B(line), mem.S("/docker/")) ||
|
||||
mem.Contains(mem.B(line), mem.S("/lxc/")) {
|
||||
ret = true
|
||||
return io.EOF // arbitrary non-nil error to stop loop
|
||||
// If MAC is still empty here, that mean the requests are made from virtual interface,
|
||||
// like VPN/Wireguard clients, so we use whatever MAC address associated with remoteIP
|
||||
// (most likely 127.0.0.1), and ci.IP as hostname, so we can distinguish those clients.
|
||||
if ci.Mac == "" {
|
||||
ci.Mac = p.ciTable.LookupMac(remoteIP)
|
||||
if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" {
|
||||
ci.Hostname = hostname
|
||||
} else {
|
||||
ci.Hostname = ci.IP
|
||||
p.ciTable.StoreVPNClient(ci)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
lineread.File("/proc/mounts", func(line []byte) error {
|
||||
if mem.Contains(mem.B(line), mem.S("lxcfs /proc/cpuinfo fuse.lxcfs")) {
|
||||
ret = true
|
||||
return io.EOF
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return ret
|
||||
} else {
|
||||
ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
|
||||
}
|
||||
ci.Self = queryFromSelf(ci.IP)
|
||||
return ci
|
||||
}
|
||||
|
||||
func (p *prog) getClientInfo(ip, mac string) *ctrld.ClientInfo {
|
||||
ci := &ctrld.ClientInfo{}
|
||||
if mac != "" {
|
||||
ci.Mac = mac
|
||||
ci.IP = p.ciTable.LookupIP(mac)
|
||||
} else {
|
||||
ci.IP = ip
|
||||
ci.Mac = p.ciTable.LookupMac(ip)
|
||||
if ip == "127.0.0.1" || ip == "::1" {
|
||||
ci.IP = p.ciTable.LookupIP(ci.Mac)
|
||||
// queryFromSelf reports whether the input IP is from device running ctrld.
|
||||
func queryFromSelf(ip string) bool {
|
||||
netIP := netip.MustParseAddr(ip)
|
||||
ifaces, err := interfaces.GetList()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not get interfaces list")
|
||||
return false
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not get interfaces addresses: %s", iface.Name)
|
||||
continue
|
||||
}
|
||||
for _, a := range addrs {
|
||||
switch v := a.(type) {
|
||||
case *net.IPNet:
|
||||
if pfx, ok := netaddr.FromStdIPNet(v); ok && pfx.Addr().Compare(netIP) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
|
||||
return ci
|
||||
return false
|
||||
}
|
||||
|
||||
func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
|
||||
|
||||
@@ -156,19 +156,27 @@ func TestCache(t *testing.T) {
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
}
|
||||
|
||||
func Test_macFromMsg(t *testing.T) {
|
||||
func Test_ipAndMacFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
wantIp bool
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has mac", "4c:20:b8:ab:87:1b", true},
|
||||
{"no mac", "4c:20:b8:ab:87:1b", false},
|
||||
{"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
{"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatal("missing IP")
|
||||
}
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -180,13 +188,23 @@ func Test_macFromMsg(t *testing.T) {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
got := macFromMsg(m)
|
||||
if tc.wantMac && got != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, got)
|
||||
if tc.wantIp {
|
||||
ec2 := &dns.EDNS0_SUBNET{Address: ip}
|
||||
o.Option = append(o.Option, ec2)
|
||||
}
|
||||
if !tc.wantMac && got != "" {
|
||||
t.Errorf("unexpected mac: %q", got)
|
||||
m.Extra = append(m.Extra, o)
|
||||
gotIP, gotMac := ipAndMacFromMsg(m)
|
||||
if tc.wantMac && gotMac != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac)
|
||||
}
|
||||
if !tc.wantMac && gotMac != "" {
|
||||
t.Errorf("unexpected mac: %q", gotMac)
|
||||
}
|
||||
if tc.wantIp && gotIP != tc.ip {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP)
|
||||
}
|
||||
if !tc.wantIp && gotIP != "" {
|
||||
t.Errorf("unexpected ip: %q", gotIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
18
cmd/cli/library.go
Normal file
18
cmd/cli/library.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
type AppCallback struct {
|
||||
HostName func() string
|
||||
LanIp func() string
|
||||
MacAddress func() string
|
||||
Exit func(error string)
|
||||
}
|
||||
|
||||
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
|
||||
type AppConfig struct {
|
||||
CdUID string
|
||||
HomeDir string
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
100
cmd/cli/loop.go
Normal file
100
cmd/cli/loop.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
loopTestDomain = ".test"
|
||||
loopTestQtype = dns.TypeTXT
|
||||
)
|
||||
|
||||
// isLoop reports whether the given upstream config is detected as having DNS loop.
|
||||
func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool {
|
||||
p.loopMu.Lock()
|
||||
defer p.loopMu.Unlock()
|
||||
return p.loop[uc.UID()]
|
||||
}
|
||||
|
||||
// detectLoop checks if the given DNS message is initialized sent by ctrld.
|
||||
// If yes, marking the corresponding upstream as loop, prevent infinite DNS
|
||||
// forwarding loop.
|
||||
//
|
||||
// See p.checkDnsLoop for more details how it works.
|
||||
func (p *prog) detectLoop(msg *dns.Msg) {
|
||||
if len(msg.Question) != 1 {
|
||||
return
|
||||
}
|
||||
q := msg.Question[0]
|
||||
if q.Qtype != loopTestQtype {
|
||||
return
|
||||
}
|
||||
unFQDNname := strings.TrimSuffix(q.Name, ".")
|
||||
uid := strings.TrimSuffix(unFQDNname, loopTestDomain)
|
||||
p.loopMu.Lock()
|
||||
if _, loop := p.loop[uid]; loop {
|
||||
p.loop[uid] = loop
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
}
|
||||
|
||||
// checkDnsLoop sends a message to check if there's any DNS forwarding loop
|
||||
// with all the upstreams. The way it works based on dnsmasq --dns-loop-detect.
|
||||
//
|
||||
// - Generating a TXT test query and sending it to all upstream.
|
||||
// - The test query is formed by upstream UID and test domain: <uid>.test
|
||||
// - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop).
|
||||
//
|
||||
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
|
||||
func (p *prog) checkDnsLoop() {
|
||||
mainLog.Load().Debug().Msg("start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uid := uc.UID()
|
||||
p.loop[uid] = false
|
||||
upstream[uid] = uc
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
}
|
||||
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msg("end checking DNS loop")
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
func (p *prog) checkDnsLoopTicker() {
|
||||
timer := time.NewTicker(time.Minute)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-timer.C:
|
||||
p.checkDnsLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loopTestMsg generates DNS message for checking loop.
|
||||
func loopTestMsg(uid string) *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)
|
||||
return msg
|
||||
}
|
||||
@@ -35,6 +35,12 @@ var (
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
noConfigStart bool
|
||||
)
|
||||
|
||||
const (
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -65,6 +71,7 @@ func normalizeLogFilePath(logFilePath string) string {
|
||||
return filepath.Join(dir, logFilePath)
|
||||
}
|
||||
|
||||
// initConsoleLogging initializes console logging, then storing to mainLog.
|
||||
func initConsoleLogging() {
|
||||
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
|
||||
@@ -3,6 +3,7 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
@@ -16,13 +17,21 @@ const (
|
||||
dns=none
|
||||
systemd-resolved=false
|
||||
`
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
systemdEnabledState = "enabled"
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
)
|
||||
|
||||
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
|
||||
|
||||
// hasNetworkManager reports whether NetworkManager executable found.
|
||||
func hasNetworkManager() bool {
|
||||
exe, _ := exec.LookPath("NetworkManager")
|
||||
return exe != ""
|
||||
}
|
||||
|
||||
func setupNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
|
||||
mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
return nil
|
||||
@@ -43,6 +52,9 @@ func setupNetworkManager() error {
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
err := os.Remove(networkManagerCtrldConfFile)
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
@@ -71,6 +83,7 @@ func reloadNetworkManager() {
|
||||
waitCh := make(chan string)
|
||||
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -85,8 +84,13 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
if useSystemdResolved {
|
||||
if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out))
|
||||
}
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if reflect.DeepEqual(currentNS, nameservers) {
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -104,7 +108,7 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if reflect.DeepEqual(currentNS, nameservers) {
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -265,3 +269,33 @@ func ignoringEINTR(fn func() error) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isSubSet reports whether s2 contains all elements of s1.
|
||||
func isSubSet(s1, s2 []string) bool {
|
||||
ok := true
|
||||
for _, ns := range s1 {
|
||||
// TODO(cuonglm): use slices.Contains once upgrading to go1.21
|
||||
if sliceContains(s2, ns) {
|
||||
continue
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// sliceContains reports whether v is present in s.
|
||||
func sliceContains[S ~[]E, E comparable](s S, v E) bool {
|
||||
return sliceIndex(s, v) >= 0
|
||||
}
|
||||
|
||||
// sliceIndex returns the index of the first occurrence of v in s,
|
||||
// or -1 if not present.
|
||||
func sliceIndex[S ~[]E, E comparable](s S, v E) int {
|
||||
for i := range s {
|
||||
if v == s[i] {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
162
cmd/cli/prog.go
162
cmd/cli/prog.go
@@ -1,13 +1,16 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -25,6 +28,8 @@ const (
|
||||
defaultSemaphoreCap = 256
|
||||
ctrldLogUnixSock = "ctrld_start.sock"
|
||||
ctrldControlUnixSock = "ctrld_control.sock"
|
||||
upstreamPrefix = "upstream."
|
||||
upstreamOS = upstreamPrefix + "os"
|
||||
)
|
||||
|
||||
var logf = func(format string, args ...any) {
|
||||
@@ -46,11 +51,16 @@ type prog struct {
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
|
||||
cfg *ctrld.Config
|
||||
cache dnscache.Cacher
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
router router.Router
|
||||
cfg *ctrld.Config
|
||||
appCallback *AppCallback
|
||||
cache dnscache.Cacher
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
um *upstreamMonitor
|
||||
router router.Router
|
||||
|
||||
loopMu sync.Mutex
|
||||
loop map[string]bool
|
||||
|
||||
started chan struct{}
|
||||
onStartedDone chan struct{}
|
||||
@@ -84,6 +94,7 @@ func (p *prog) run() {
|
||||
numListeners := len(p.cfg.Listener)
|
||||
p.started = make(chan struct{}, numListeners)
|
||||
p.onStartedDone = make(chan struct{})
|
||||
p.loop = make(map[string]bool)
|
||||
if p.cfg.Service.CacheEnable {
|
||||
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
||||
if err != nil {
|
||||
@@ -114,6 +125,8 @@ func (p *prog) run() {
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
for n := range p.cfg.Upstream {
|
||||
uc := p.cfg.Upstream[n]
|
||||
uc.Init()
|
||||
@@ -133,12 +146,14 @@ func (p *prog) run() {
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.ciTable.Init()
|
||||
p.ciTable.RefreshLoop(p.stopCh)
|
||||
}()
|
||||
go p.watchLinkState()
|
||||
// Newer versions of android and iOS denies permission which breaks connectivity.
|
||||
if !isMobile() {
|
||||
go func() {
|
||||
p.ciTable.Init()
|
||||
p.ciTable.RefreshLoop(p.stopCh)
|
||||
}()
|
||||
go p.watchLinkState()
|
||||
}
|
||||
|
||||
for listenerNum := range p.cfg.Listener {
|
||||
p.cfg.Listener[listenerNum].Init()
|
||||
@@ -163,8 +178,13 @@ func (p *prog) run() {
|
||||
for _, f := range p.onStarted {
|
||||
f()
|
||||
}
|
||||
// Check for possible DNS loop.
|
||||
p.checkDnsLoop()
|
||||
close(p.onStartedDone)
|
||||
|
||||
// Start check DNS loop ticker.
|
||||
go p.checkDnsLoopTicker()
|
||||
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
initLoggingWithBackup(false)
|
||||
@@ -345,48 +365,100 @@ var (
|
||||
func errUrlNetworkError(err error) bool {
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
var opErr *net.OpError
|
||||
if errors.As(urlErr.Err, &opErr) {
|
||||
if opErr.Temporary() {
|
||||
return true
|
||||
}
|
||||
switch {
|
||||
case errors.Is(opErr.Err, syscall.ECONNREFUSED),
|
||||
errors.Is(opErr.Err, syscall.EINVAL),
|
||||
errors.Is(opErr.Err, syscall.ENETUNREACH),
|
||||
errors.Is(opErr.Err, windowsENETUNREACH),
|
||||
errors.Is(opErr.Err, windowsEINVAL),
|
||||
errors.Is(opErr.Err, windowsECONNREFUSED):
|
||||
return true
|
||||
}
|
||||
return errNetworkError(urlErr.Err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func errNetworkError(err error) bool {
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
if opErr.Temporary() {
|
||||
return true
|
||||
}
|
||||
switch {
|
||||
case errors.Is(opErr.Err, syscall.ECONNREFUSED),
|
||||
errors.Is(opErr.Err, syscall.EINVAL),
|
||||
errors.Is(opErr.Err, syscall.ENETUNREACH),
|
||||
errors.Is(opErr.Err, windowsENETUNREACH),
|
||||
errors.Is(opErr.Err, windowsEINVAL),
|
||||
errors.Is(opErr.Err, windowsECONNREFUSED):
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// defaultRouteIP returns IP string of the default route if present, prefer IPv4 over IPv6.
|
||||
func defaultRouteIP() string {
|
||||
if dr, err := interfaces.DefaultRoute(); err == nil {
|
||||
if netIface, err := netInterface(dr.InterfaceName); err == nil {
|
||||
addrs, _ := netIface.Addrs()
|
||||
do := func(v4 bool) net.IP {
|
||||
for _, addr := range addrs {
|
||||
if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() {
|
||||
if v4 {
|
||||
return netIP.IP.To4()
|
||||
}
|
||||
return netIP.IP
|
||||
}
|
||||
func ifaceFirstPrivateIP(iface *net.Interface) string {
|
||||
if iface == nil {
|
||||
return ""
|
||||
}
|
||||
do := func(addrs []net.Addr, v4 bool) net.IP {
|
||||
for _, addr := range addrs {
|
||||
if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() {
|
||||
if v4 {
|
||||
return netIP.IP.To4()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ip := do(true); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
if ip := do(false); ip != nil {
|
||||
return ip.String()
|
||||
return netIP.IP
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
addrs, _ := iface.Addrs()
|
||||
if ip := do(addrs, true); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
if ip := do(addrs, false); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6.
|
||||
func defaultRouteIP() string {
|
||||
dr, err := interfaces.DefaultRoute()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
drNetIface, err := netInterface(dr.InterfaceName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface")
|
||||
if ip := ifaceFirstPrivateIP(drNetIface); ip != "" {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface")
|
||||
return ip
|
||||
}
|
||||
|
||||
// If we reach here, it means the default route interface is connected directly to ISP.
|
||||
// We need to find the LAN interface with the same Mac address with drNetIface.
|
||||
//
|
||||
// There could be multiple LAN interfaces with the same Mac address, so we find all private
|
||||
// IPs then using the smallest one.
|
||||
var addrs []netip.Addr
|
||||
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
if i.Name == drNetIface.Name {
|
||||
return
|
||||
}
|
||||
if bytes.Equal(i.HardwareAddr, drNetIface.HardwareAddr) {
|
||||
for _, pfx := range prefixes {
|
||||
addr := pfx.Addr()
|
||||
if addr.IsPrivate() {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if len(addrs) == 0 {
|
||||
mainLog.Load().Warn().Msg("no default route IP found")
|
||||
return ""
|
||||
}
|
||||
sort.Slice(addrs, func(i, j int) bool {
|
||||
return addrs[i].Less(addrs[j])
|
||||
})
|
||||
|
||||
ip := addrs[0].String()
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
|
||||
return ip
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -21,9 +20,8 @@ func setDependencies(svc *service.Config) {
|
||||
"After=NetworkManager-wait-online.service",
|
||||
"Wants=systemd-networkd-wait-online.service",
|
||||
"After=systemd-networkd-wait-online.service",
|
||||
}
|
||||
if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 {
|
||||
svc.Dependencies = append(svc.Dependencies, routerDeps...)
|
||||
"Wants=nss-lookup.target",
|
||||
"After=nss-lookup.target",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
98
cmd/cli/upstream_monitor.go
Normal file
98
cmd/cli/upstream_monitor.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 100
|
||||
// checkUpstreamMaxBackoff is the max backoff time when checking upstream status.
|
||||
checkUpstreamMaxBackoff = 2 * time.Minute
|
||||
)
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
|
||||
down map[string]*atomic.Bool
|
||||
failureReq map[string]*atomic.Uint64
|
||||
|
||||
mu sync.Mutex
|
||||
checking map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
down: make(map[string]*atomic.Bool),
|
||||
failureReq: make(map[string]*atomic.Uint64),
|
||||
checking: make(map[string]bool),
|
||||
}
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
um.down[upstream] = new(atomic.Bool)
|
||||
um.failureReq[upstream] = new(atomic.Uint64)
|
||||
}
|
||||
um.down[upstreamOS] = new(atomic.Bool)
|
||||
um.failureReq[upstreamOS] = new(atomic.Uint64)
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increase failed queries count for an upstream by 1.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
failedCount := um.failureReq[upstream].Add(1)
|
||||
um.down[upstream].Store(failedCount >= maxFailureRequest)
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
return um.down[upstream].Load()
|
||||
}
|
||||
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.failureReq[upstream].Store(0)
|
||||
um.down[upstream].Store(false)
|
||||
}
|
||||
|
||||
// checkUpstream checks the given upstream status, periodically sending query to upstream
|
||||
// until successfully. An upstream status/counter will be reset once it becomes reachable.
|
||||
func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
|
||||
um.mu.Lock()
|
||||
isChecking := um.checking[upstream]
|
||||
if isChecking {
|
||||
um.mu.Unlock()
|
||||
return
|
||||
}
|
||||
um.checking[upstream] = true
|
||||
um.mu.Unlock()
|
||||
|
||||
bo := backoff.NewBackoff("checkUpstream", logf, checkUpstreamMaxBackoff)
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
|
||||
return
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
ctx := context.Background()
|
||||
|
||||
for {
|
||||
_, err := resolver.Resolve(ctx, msg)
|
||||
if err == nil {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
|
||||
um.reset(upstream)
|
||||
return
|
||||
}
|
||||
bo.BackOff(ctx, err)
|
||||
}
|
||||
}
|
||||
74
cmd/ctrld_library/main.go
Normal file
74
cmd/ctrld_library/main.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package ctrld_library
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
// Controller holds global state
|
||||
type Controller struct {
|
||||
stopCh chan struct{}
|
||||
AppCallback AppCallback
|
||||
Config cli.AppConfig
|
||||
}
|
||||
|
||||
// NewController provides reference to global state to be managed by android vpn service and iOS network extension.
|
||||
// reference is not safe for concurrent use.
|
||||
func NewController(appCallback AppCallback) *Controller {
|
||||
return &Controller{AppCallback: appCallback}
|
||||
}
|
||||
|
||||
// AppCallback provides access to app instance.
|
||||
type AppCallback interface {
|
||||
Hostname() string
|
||||
LanIp() string
|
||||
MacAddress() string
|
||||
Exit(error string)
|
||||
}
|
||||
|
||||
// Start configures utility with config.toml from provided directory.
|
||||
// This function will block until Stop is called
|
||||
// Check port availability prior to calling it.
|
||||
func (c *Controller) Start(CdUID string, HomeDir string, logLevel int, logPath string) {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
c.Config = cli.AppConfig{
|
||||
CdUID: CdUID,
|
||||
HomeDir: HomeDir,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
}
|
||||
appCallback := mapCallback(c.AppCallback)
|
||||
cli.RunMobile(&c.Config, &appCallback, c.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
// As workaround to avoid circular dependency between cli and ctrld_library module
|
||||
func mapCallback(callback AppCallback) cli.AppCallback {
|
||||
return cli.AppCallback{
|
||||
HostName: func() string {
|
||||
return callback.Hostname()
|
||||
},
|
||||
LanIp: func() string {
|
||||
return callback.LanIp()
|
||||
},
|
||||
MacAddress: func() string {
|
||||
return callback.MacAddress()
|
||||
},
|
||||
Exit: func(err string) {
|
||||
callback.Exit(err)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) Stop() bool {
|
||||
if c.stopCh != nil {
|
||||
close(c.stopCh)
|
||||
c.stopCh = nil
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Controller) IsRunning() bool {
|
||||
return c.stopCh != nil
|
||||
}
|
||||
26
config.go
26
config.go
@@ -2,8 +2,10 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
@@ -78,8 +80,8 @@ func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
|
||||
func InitConfig(v *viper.Viper, name string) {
|
||||
v.SetDefault("listener", map[string]*ListenerConfig{
|
||||
"0": {
|
||||
IP: "127.0.0.1",
|
||||
Port: 53,
|
||||
IP: "",
|
||||
Port: 0,
|
||||
},
|
||||
})
|
||||
v.SetDefault("network", map[string]*NetworkConfig{
|
||||
@@ -178,6 +180,7 @@ type ServiceConfig struct {
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
@@ -216,6 +219,7 @@ type UpstreamConfig struct {
|
||||
http3RoundTripper6 http.RoundTripper
|
||||
certPool *x509.CertPool
|
||||
u *url.URL
|
||||
uid string
|
||||
}
|
||||
|
||||
// ListenerConfig specifies the networks configuration that ctrld will run on.
|
||||
@@ -260,6 +264,7 @@ type Rule map[string][]string
|
||||
|
||||
// Init initialized necessary values for an UpstreamConfig.
|
||||
func (uc *UpstreamConfig) Init() {
|
||||
uc.uid = upstreamUID()
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
uc.Domain = u.Host
|
||||
switch uc.Type {
|
||||
@@ -340,6 +345,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
uc.setupBootstrapIP(true)
|
||||
}
|
||||
|
||||
// UID returns the unique identifier of the upstream.
|
||||
func (uc *UpstreamConfig) UID() string {
|
||||
return uc.uid
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
@@ -679,3 +689,15 @@ func ResolverTypeFromEndpoint(endpoint string) string {
|
||||
func pick(s []string) string {
|
||||
return s[rand.Intn(len(s))]
|
||||
}
|
||||
|
||||
// upstreamUID generates an unique identifier for an upstream.
|
||||
func upstreamUID() string {
|
||||
b := make([]byte, 4)
|
||||
for {
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
|
||||
continue
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +185,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.uid = "" // we don't care about the uid.
|
||||
assert.Equal(t, tc.expected, tc.uc)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -43,7 +44,6 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
rt := &http3.RoundTripper{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
domain := addr
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
@@ -57,20 +57,23 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
|
||||
return quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
||||
}
|
||||
dialAddrs := make([]string, len(addrs))
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
pd := &quicParallelDialer{}
|
||||
conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg)
|
||||
conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) {
|
||||
rt.CloseIdleConnections()
|
||||
})
|
||||
return rt
|
||||
}
|
||||
|
||||
@@ -107,13 +110,15 @@ type parallelDialerResult struct {
|
||||
type quicParallelDialer struct{}
|
||||
|
||||
// Dial performs parallel dialing to the given address list.
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
ch := make(chan *parallelDialerResult, len(addrs))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(addrs))
|
||||
@@ -135,9 +140,14 @@ func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []st
|
||||
ch <- ¶llelDialerResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
|
||||
ch <- ¶llelDialerResult{conn: conn, err: err}
|
||||
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
||||
select {
|
||||
case ch <- ¶llelDialerResult{conn: conn, err: err}:
|
||||
case <-done:
|
||||
if conn != nil {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
|
||||
}
|
||||
}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
|
||||
32
docker/Dockerfile.debug
Normal file
32
docker/Dockerfile.debug
Normal file
@@ -0,0 +1,32 @@
|
||||
# Using Debian bullseye for building regular image.
|
||||
# Using scratch image for minimal image size.
|
||||
# The final image has:
|
||||
#
|
||||
# - Timezone info file.
|
||||
# - CA certs file.
|
||||
# - /etc/{passwd,group} file.
|
||||
# - Non-cgo ctrld binary.
|
||||
#
|
||||
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
||||
FROM golang:1.20-bullseye as base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y upx-ucl
|
||||
|
||||
COPY . .
|
||||
|
||||
ARG tag=master
|
||||
ENV CI_COMMIT_TAG=$tag
|
||||
RUN CTRLD_NO_QF=yes CGO_ENABLED=0 ./scripts/build.sh
|
||||
|
||||
FROM alpine
|
||||
|
||||
COPY --from=base /usr/share/zoneinfo /usr/share/zoneinfo
|
||||
COPY --from=base /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
COPY --from=base /etc/passwd /etc/passwd
|
||||
COPY --from=base /etc/group /etc/group
|
||||
|
||||
COPY --from=base /app/ctrld-linux-*-nocgo ctrld
|
||||
|
||||
ENTRYPOINT ["./ctrld", "run"]
|
||||
@@ -193,6 +193,13 @@ Perform LAN client discovery using PTR queries.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
### discover_hosts
|
||||
Perform LAN client discovery using hosts file.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
### dhcp_lease_file_path
|
||||
Relative or absolute path to a custom DHCP leases file location.
|
||||
|
||||
|
||||
63
doh.go
63
doh.go
@@ -8,6 +8,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cuonglm/osinfo"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -16,9 +21,56 @@ const (
|
||||
dohMacHeader = "x-cd-mac"
|
||||
dohIPHeader = "x-cd-ip"
|
||||
dohHostHeader = "x-cd-host"
|
||||
dohOsHeader = "x-cd-os"
|
||||
headerApplicationDNS = "application/dns-message"
|
||||
)
|
||||
|
||||
// EncodeOsNameMap provides mapping from OS name to a shorter string, used for encoding x-cd-os value.
|
||||
var EncodeOsNameMap = map[string]string{
|
||||
"windows": "1",
|
||||
"darwin": "2",
|
||||
"linux": "3",
|
||||
"freebsd": "4",
|
||||
}
|
||||
|
||||
// DecodeOsNameMap provides mapping from encoded OS name to real value, used for decoding x-cd-os value.
|
||||
var DecodeOsNameMap = map[string]string{}
|
||||
|
||||
// EncodeArchNameMap provides mapping from OS arch to a shorter string, used for encoding x-cd-os value.
|
||||
var EncodeArchNameMap = map[string]string{
|
||||
"amd64": "1",
|
||||
"arm64": "2",
|
||||
"arm": "3",
|
||||
"386": "4",
|
||||
"mips": "5",
|
||||
"mipsle": "6",
|
||||
"mips64": "7",
|
||||
}
|
||||
|
||||
// DecodeArchNameMap provides mapping from encoded OS arch to real value, used for decoding x-cd-os value.
|
||||
var DecodeArchNameMap = map[string]string{}
|
||||
|
||||
func init() {
|
||||
for k, v := range EncodeOsNameMap {
|
||||
DecodeOsNameMap[v] = k
|
||||
}
|
||||
for k, v := range EncodeArchNameMap {
|
||||
DecodeArchNameMap[v] = k
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: use sync.OnceValue when upgrading to go1.21
|
||||
var xCdOsValueOnce sync.Once
|
||||
var xCdOsValue string
|
||||
|
||||
func dohOsHeaderValue() string {
|
||||
xCdOsValueOnce.Do(func() {
|
||||
oi := osinfo.New()
|
||||
xCdOsValue = strings.Join([]string{EncodeOsNameMap[runtime.GOOS], EncodeArchNameMap[runtime.GOARCH], oi.Dist}, "-")
|
||||
})
|
||||
return xCdOsValue
|
||||
}
|
||||
|
||||
func newDohResolver(uc *UpstreamConfig) *dohResolver {
|
||||
r := &dohResolver{
|
||||
endpoint: uc.u,
|
||||
@@ -97,8 +149,12 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
|
||||
req.Header.Set("Content-Type", headerApplicationDNS)
|
||||
req.Header.Set("Accept", headerApplicationDNS)
|
||||
req.Header.Set(dohOsHeader, dohOsHeaderValue())
|
||||
|
||||
printed := false
|
||||
if sendClientInfo {
|
||||
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
|
||||
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
|
||||
if ci.Mac != "" {
|
||||
req.Header.Set(dohMacHeader, ci.Mac)
|
||||
}
|
||||
@@ -108,7 +164,12 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set(dohHostHeader, ci.Hostname)
|
||||
}
|
||||
if ci.Self {
|
||||
req.Header.Set(dohOsHeader, dohOsHeaderValue())
|
||||
}
|
||||
}
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
|
||||
if printed {
|
||||
Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
|
||||
}
|
||||
}
|
||||
|
||||
23
doh_test.go
Normal file
23
doh_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_dohOsHeaderValue(t *testing.T) {
|
||||
val := dohOsHeaderValue()
|
||||
if val == "" {
|
||||
t.Fatalf("empty %s", dohOsHeader)
|
||||
}
|
||||
t.Log(val)
|
||||
|
||||
encodedOs := EncodeOsNameMap[runtime.GOOS]
|
||||
if encodedOs == "" {
|
||||
t.Fatalf("missing encoding value for: %q", runtime.GOOS)
|
||||
}
|
||||
decodedOs := DecodeOsNameMap[encodedOs]
|
||||
if decodedOs == "" {
|
||||
t.Fatalf("missing decoding value for: %q", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
2
doq.go
2
doq.go
@@ -51,7 +51,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
|
||||
}
|
||||
|
||||
func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) {
|
||||
session, err := quic.DialAddr(endpoint, tlsConfig, nil)
|
||||
session, err := quic.DialAddr(ctx, endpoint, tlsConfig, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
16
go.mod
16
go.mod
@@ -4,7 +4,7 @@ go 1.20
|
||||
|
||||
require (
|
||||
github.com/coreos/go-systemd/v22 v22.5.0
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf
|
||||
github.com/frankban/quicktest v1.14.5
|
||||
github.com/fsnotify/fsnotify v1.6.0
|
||||
github.com/go-playground/validator/v10 v10.11.1
|
||||
@@ -12,19 +12,19 @@ require (
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.1
|
||||
github.com/illarion/gonotify v1.0.1
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16
|
||||
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
|
||||
github.com/kardianos/service v1.2.1
|
||||
github.com/miekg/dns v1.1.55
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/pelletier/go-toml/v2 v2.0.8
|
||||
github.com/quic-go/quic-go v0.32.0
|
||||
github.com/quic-go/quic-go v0.38.0
|
||||
github.com/rs/zerolog v1.28.0
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.16.0
|
||||
github.com/stretchr/testify v1.8.3
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13
|
||||
golang.org/x/net v0.10.0
|
||||
golang.org/x/sync v0.2.0
|
||||
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a
|
||||
@@ -37,7 +37,7 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-playground/locales v0.14.0 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
@@ -56,13 +56,11 @@ require (
|
||||
github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.2.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.17 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qtls-go1-18 v0.2.0 // indirect
|
||||
github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.3.2 // indirect
|
||||
github.com/rivo/uniseg v0.4.4 // indirect
|
||||
github.com/rogpeppe/go-internal v1.10.0 // indirect
|
||||
github.com/spf13/afero v1.9.5 // indirect
|
||||
@@ -71,8 +69,10 @@ require (
|
||||
github.com/subosito/gotenv v1.4.2 // indirect
|
||||
github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
|
||||
golang.org/x/crypto v0.9.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect
|
||||
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda // indirect
|
||||
golang.org/x/mod v0.10.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
golang.org/x/tools v0.9.1 // indirect
|
||||
|
||||
29
go.sum
29
go.sum
@@ -57,6 +57,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf h1:40DHYsri+d1bnroFDU2FQAeq68f3kAlOzlQ93kCf26Q=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -73,6 +75,7 @@ github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbS
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
|
||||
@@ -81,8 +84,8 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j
|
||||
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
|
||||
github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ=
|
||||
github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
@@ -162,6 +165,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16 h1:+aAGyK41KRn8jbF2Q7PLL0Sxwg6dShGcQSeCC7nZQ8E=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16/go.mod h1:IKrnDWs3/Mqq5n0lI+RxA2sB7MvN/vbMBP3ehXg65UI=
|
||||
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c h1:kbTQ8oGf+BVFvt/fM+ECI+NbZDCqoi0vtZTfB2p2hrI=
|
||||
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c/go.mod h1:k6+89xKz7BSMJ+DzIerBdtpEUeTlBMugO/hcVSzahog=
|
||||
github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk=
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8=
|
||||
@@ -211,9 +216,9 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI=
|
||||
github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk=
|
||||
github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
@@ -227,14 +232,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwqqj7U=
|
||||
github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc=
|
||||
github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk=
|
||||
github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
||||
github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI=
|
||||
github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
|
||||
github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA=
|
||||
github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.2 h1:rRgN3WfnKbyik4dBV8A6girlJVxGand/d+jVKbQq5GI=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.2/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/quic-go/quic-go v0.38.0 h1:T45lASr5q/TrVwt+jrVccmqHhPL2XuSyoCLVCpfOSLc=
|
||||
github.com/quic-go/quic-go v0.38.0/go.mod h1:MPCuRq7KBK2hNcfKj/1iD1BGuN3eAYMeNxp3T42LRUg=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
|
||||
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
@@ -330,6 +331,8 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI
|
||||
golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
|
||||
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda h1:O+EUvnBNPwI4eLthn8W5K+cS8zQZfgTABPLNm6Bna34=
|
||||
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda/go.mod h1:aAjjkJNdrh3PMckS4B10TGS2nag27cbKR1y2BpUxsiY=
|
||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
|
||||
@@ -33,6 +33,9 @@ func (a *arpDiscover) String() string {
|
||||
}
|
||||
|
||||
func (a *arpDiscover) List() []string {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
a.ip.Range(func(key, value any) bool {
|
||||
ips = append(ips, value.(string))
|
||||
|
||||
@@ -73,6 +73,8 @@ type Table struct {
|
||||
arp *arpDiscover
|
||||
ptr *ptrDiscover
|
||||
mdns *mdns
|
||||
hf *hostsFile
|
||||
vni *virtualNetworkIface
|
||||
cfg *ctrld.Config
|
||||
quitCh chan struct{}
|
||||
selfIP string
|
||||
@@ -116,6 +118,7 @@ func (t *Table) Init() {
|
||||
}
|
||||
|
||||
func (t *Table) init() {
|
||||
// Custom client ID presents, use it as the only source.
|
||||
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery")
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP}
|
||||
@@ -125,6 +128,11 @@ func (t *Table) init() {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.dhcp)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, process all possible sources in order, that means
|
||||
// the first result of IP/MAC/Hostname lookup will be used.
|
||||
//
|
||||
// Merlin custom clients.
|
||||
if t.discoverDHCP() || t.discoverARP() {
|
||||
t.merlin = &merlinDiscover{}
|
||||
if err := t.merlin.refresh(); err != nil {
|
||||
@@ -134,6 +142,19 @@ func (t *Table) init() {
|
||||
t.refreshers = append(t.refreshers, t.merlin)
|
||||
}
|
||||
}
|
||||
// Hosts file mapping.
|
||||
if t.discoverHosts() {
|
||||
t.hf = &hostsFile{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery")
|
||||
if err := t.hf.init(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init hosts file discover")
|
||||
} else {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.hf)
|
||||
t.refreshers = append(t.refreshers, t.hf)
|
||||
}
|
||||
go t.hf.watchChanges()
|
||||
}
|
||||
// DHCP lease files.
|
||||
if t.discoverDHCP() {
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery")
|
||||
@@ -146,6 +167,7 @@ func (t *Table) init() {
|
||||
}
|
||||
go t.dhcp.watchChanges()
|
||||
}
|
||||
// ARP table.
|
||||
if t.discoverARP() {
|
||||
t.arp = &arpDiscover{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery")
|
||||
@@ -157,6 +179,7 @@ func (t *Table) init() {
|
||||
t.refreshers = append(t.refreshers, t.arp)
|
||||
}
|
||||
}
|
||||
// PTR lookup.
|
||||
if t.discoverPTR() {
|
||||
t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery")
|
||||
@@ -167,6 +190,7 @@ func (t *Table) init() {
|
||||
t.refreshers = append(t.refreshers, t.ptr)
|
||||
}
|
||||
}
|
||||
// mdns.
|
||||
if t.discoverMDNS() {
|
||||
t.mdns = &mdns{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery")
|
||||
@@ -176,6 +200,11 @@ func (t *Table) init() {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.mdns)
|
||||
}
|
||||
}
|
||||
// VPN clients.
|
||||
if t.discoverDHCP() || t.discoverARP() {
|
||||
t.vni = &virtualNetworkIface{}
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.vni)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) LookupIP(mac string) string {
|
||||
@@ -259,7 +288,7 @@ func (t *Table) ListClients() []*Client {
|
||||
_ = r.refresh()
|
||||
}
|
||||
ipMap := make(map[string]*Client)
|
||||
il := []ipLister{t.dhcp, t.arp, t.ptr, t.mdns}
|
||||
il := []ipLister{t.dhcp, t.arp, t.ptr, t.mdns, t.vni}
|
||||
for _, ir := range il {
|
||||
for _, ip := range ir.List() {
|
||||
c, ok := ipMap[ip]
|
||||
@@ -300,6 +329,15 @@ func (t *Table) ListClients() []*Client {
|
||||
return clients
|
||||
}
|
||||
|
||||
// StoreVPNClient stores client info for VPN clients.
|
||||
func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) {
|
||||
if ci == nil || t.vni == nil {
|
||||
return
|
||||
}
|
||||
t.vni.mac.Store(ci.IP, ci.Mac)
|
||||
t.vni.ip2name.Store(ci.IP, ci.Hostname)
|
||||
}
|
||||
|
||||
func (t *Table) discoverDHCP() bool {
|
||||
if t.cfg.Service.DiscoverDHCP == nil {
|
||||
return true
|
||||
@@ -328,6 +366,13 @@ func (t *Table) discoverPTR() bool {
|
||||
return *t.cfg.Service.DiscoverPtr
|
||||
}
|
||||
|
||||
func (t *Table) discoverHosts() bool {
|
||||
if t.cfg.Service.DiscoverHosts == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverHosts
|
||||
}
|
||||
|
||||
// normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file.
|
||||
func normalizeIP(in string) string {
|
||||
// dnsmasq may put ip with interface index in lease file, strip it here.
|
||||
|
||||
@@ -47,12 +47,25 @@ func (d *dhcp) watchChanges() {
|
||||
if d.watcher == nil {
|
||||
return
|
||||
}
|
||||
if dir := router.LeaseFilesDir(); dir != "" {
|
||||
if err := d.watcher.Add(dir); err != nil {
|
||||
ctrld.ProxyLogger.Load().Err(err).Str("dir", dir).Msg("could not watch lease dir")
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-d.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Create) {
|
||||
if format, ok := clientInfoFiles[event.Name]; ok {
|
||||
if err := d.addLeaseFile(event.Name, format); err != nil {
|
||||
ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("could not add lease file")
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) {
|
||||
format := clientInfoFiles[event.Name]
|
||||
if err := d.readLeaseFile(event.Name, format); err != nil && !os.IsNotExist(err) {
|
||||
@@ -106,6 +119,9 @@ func (d *dhcp) String() string {
|
||||
}
|
||||
|
||||
func (d *dhcp) List() []string {
|
||||
if d == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
d.ip.Range(func(key, value any) bool {
|
||||
ips = append(ips, value.(string))
|
||||
|
||||
120
internal/clientinfo/hostsfile.go
Normal file
120
internal/clientinfo/hostsfile.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/jaytaylor/go-hostsfile"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4LocalhostName = "localhost"
|
||||
ipv6LocalhostName = "ip6-localhost"
|
||||
ipv6LoopbackName = "ip6-loopback"
|
||||
)
|
||||
|
||||
// hostsFile provides client discovery functionality using system hosts file.
|
||||
type hostsFile struct {
|
||||
watcher *fsnotify.Watcher
|
||||
mu sync.Mutex
|
||||
m map[string][]string
|
||||
}
|
||||
|
||||
// init performs initialization works, which is necessary before hostsFile can be fully operated.
|
||||
func (hf *hostsFile) init() error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hf.watcher = watcher
|
||||
if err := hf.watcher.Add(hostsfile.HostsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
m, err := hostsfile.ParseHosts(hostsfile.ReadHostsFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hf.mu.Lock()
|
||||
hf.m = m
|
||||
hf.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// refresh reloads hosts file entries.
|
||||
func (hf *hostsFile) refresh() error {
|
||||
m, err := hostsfile.ParseHosts(hostsfile.ReadHostsFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hf.mu.Lock()
|
||||
hf.m = m
|
||||
hf.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// watchChanges watches and updates hosts file data if any changes happens.
|
||||
func (hf *hostsFile) watchChanges() {
|
||||
if hf.watcher == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-hf.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) {
|
||||
if err := hf.refresh(); err != nil && !os.IsNotExist(err) {
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("hosts file changed but failed to update client info")
|
||||
}
|
||||
}
|
||||
case err, ok := <-hf.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// LookupHostnameByIP returns hostname for given IP from current hosts file entries.
|
||||
func (hf *hostsFile) LookupHostnameByIP(ip string) string {
|
||||
hf.mu.Lock()
|
||||
defer hf.mu.Unlock()
|
||||
if names := hf.m[ip]; len(names) > 0 {
|
||||
isLoopback := ip == "127.0.0.1" || ip == "::1"
|
||||
for _, hostname := range names {
|
||||
name := normalizeHostname(hostname)
|
||||
// Ignoring ipv4/ipv6 loopback entry.
|
||||
if isLoopback && isLocalhostName(name) {
|
||||
continue
|
||||
}
|
||||
return name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// LookupHostnameByMac returns hostname for given Mac from current hosts file entries.
|
||||
func (hf *hostsFile) LookupHostnameByMac(mac string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// String returns human-readable format of hostsFile.
|
||||
func (hf *hostsFile) String() string {
|
||||
return "hosts"
|
||||
}
|
||||
|
||||
// isLocalhostName reports whether the given hostname represents localhost.
|
||||
func isLocalhostName(hostname string) bool {
|
||||
switch hostname {
|
||||
case ipv4LocalhostName, ipv6LocalhostName, ipv6LoopbackName:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
33
internal/clientinfo/hostsfile_test.go
Normal file
33
internal/clientinfo/hostsfile_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_hostsFile_LookupHostnameByIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
hostnames []string
|
||||
expectedHostname string
|
||||
}{
|
||||
{"ipv4 loopback", "127.0.0.1", []string{ipv4LocalhostName}, ""},
|
||||
{"ipv6 loopback", "::1", []string{ipv6LocalhostName, ipv6LoopbackName}, ""},
|
||||
{"non-localhost", "::1", []string{"foo"}, "foo"},
|
||||
{"multiple hostnames", "::1", []string{ipv4LocalhostName, "foo"}, "foo"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hf := &hostsFile{m: make(map[string][]string)}
|
||||
hf.mu.Lock()
|
||||
hf.m[tc.ip] = tc.hostnames
|
||||
hf.mu.Unlock()
|
||||
if got := hf.LookupHostnameByIP(tc.ip); got != tc.expectedHostname {
|
||||
t.Errorf("unpexpected result, want: %q, got: %q", tc.expectedHostname, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -48,6 +48,9 @@ func (m *mdns) String() string {
|
||||
}
|
||||
|
||||
func (m *mdns) List() []string {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
m.name.Range(func(key, value any) bool {
|
||||
ips = append(ips, key.(string))
|
||||
|
||||
@@ -3,16 +3,19 @@ package clientinfo
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
type ptrDiscover struct {
|
||||
hostname sync.Map // ip => hostname
|
||||
resolver ctrld.Resolver
|
||||
hostname sync.Map // ip => hostname
|
||||
resolver ctrld.Resolver
|
||||
serverDown atomic.Bool
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) refresh() error {
|
||||
@@ -41,6 +44,9 @@ func (p *ptrDiscover) String() string {
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) List() []string {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
p.hostname.Range(func(key, value any) bool {
|
||||
ips = append(ips, key.(string))
|
||||
@@ -57,18 +63,24 @@ func (p *ptrDiscover) lookupHostnameFromCache(ip string) string {
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) lookupHostname(ip string) string {
|
||||
// If nameserver is down, do nothing.
|
||||
if p.serverDown.Load() {
|
||||
return ""
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
msg := new(dns.Msg)
|
||||
addr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("invalid ip address")
|
||||
ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
|
||||
return ""
|
||||
}
|
||||
msg.SetQuestion(addr, dns.TypePTR)
|
||||
ans, err := p.resolver.Resolve(ctx, msg)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not lookup IP")
|
||||
ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
|
||||
p.serverDown.Store(true)
|
||||
go p.checkServer()
|
||||
return ""
|
||||
}
|
||||
for _, rr := range ans.Answer {
|
||||
@@ -80,3 +92,25 @@ func (p *ptrDiscover) lookupHostname(ip string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// checkServer monitors if the resolver can reach its nameserver. When the nameserver
|
||||
// is reachable, set p.serverDown to false, so p.lookupHostname can continue working.
|
||||
func (p *ptrDiscover) checkServer() {
|
||||
bo := backoff.NewBackoff("ptrDiscover", func(format string, args ...any) {}, time.Minute*5)
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(".", dns.TypeNS)
|
||||
ping := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, err := p.resolver.Resolve(ctx, m)
|
||||
return err
|
||||
}
|
||||
for {
|
||||
if err := ping(); err != nil {
|
||||
bo.BackOff(context.Background(), err)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
p.serverDown.Store(false)
|
||||
}
|
||||
|
||||
43
internal/clientinfo/virtual_iface.go
Normal file
43
internal/clientinfo/virtual_iface.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// virtualNetworkIface is the manager for clients from virtual network interface.
|
||||
type virtualNetworkIface struct {
|
||||
ip2name sync.Map // ip => name
|
||||
mac sync.Map // ip => mac
|
||||
}
|
||||
|
||||
// LookupHostnameByIP returns hostname of the given VPN client ip.
|
||||
func (v *virtualNetworkIface) LookupHostnameByIP(ip string) string {
|
||||
val, ok := v.ip2name.Load(ip)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return val.(string)
|
||||
}
|
||||
|
||||
// LookupHostnameByMac always returns empty string.
|
||||
func (v *virtualNetworkIface) LookupHostnameByMac(mac string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// String returns the string representation of virtualNetworkIface struct.
|
||||
func (v *virtualNetworkIface) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// List lists all known VPN clients IP.
|
||||
func (v *virtualNetworkIface) List() []string {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
v.mac.Range(func(key, value any) bool {
|
||||
ips = append(ips, key.(string))
|
||||
return true
|
||||
})
|
||||
return ips
|
||||
}
|
||||
@@ -17,6 +17,7 @@ server={{ .IP }}#{{ .Port }}
|
||||
{{- end}}
|
||||
{{- if .SendClientInfo}}
|
||||
add-mac
|
||||
add-subnet=32,128
|
||||
{{- end}}
|
||||
`
|
||||
|
||||
@@ -39,7 +40,10 @@ if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then
|
||||
pc_append "server={{ .IP }}#{{ .Port }}" "$config_file"
|
||||
{{- end}}
|
||||
{{- if .SendClientInfo}}
|
||||
pc_delete "add-mac" "$config_file"
|
||||
pc_delete "add-subnet" "$config_file"
|
||||
pc_append "add-mac" "$config_file" # add client mac
|
||||
pc_append "add-subnet=32,128" "$config_file" # add client ip
|
||||
{{- end}}
|
||||
pc_delete "dnssec" "$config_file" # disable DNSSEC
|
||||
pc_delete "trust-anchor=" "$config_file" # disable DNSSEC
|
||||
|
||||
@@ -169,9 +169,16 @@ func ContentFilteringEnabled() bool {
|
||||
return err == nil && !st.IsDir()
|
||||
}
|
||||
|
||||
func LeaseFileDir() string {
|
||||
if checkUSG() {
|
||||
return ""
|
||||
}
|
||||
return "/run"
|
||||
}
|
||||
|
||||
func checkUSG() bool {
|
||||
out, _ := exec.Command("mca-cli-op", "info").Output()
|
||||
return bytes.Contains(out, []byte("UniFi-Gateway-"))
|
||||
out, _ := os.ReadFile("/etc/version")
|
||||
return bytes.HasPrefix(out, []byte("UniFiSecurityGateway."))
|
||||
}
|
||||
|
||||
func restartDNSMasq() error {
|
||||
|
||||
@@ -173,20 +173,6 @@ func CanListenLocalhost() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceDependencies returns list of dependencies that ctrld services needs on this router.
|
||||
// See https://pkg.go.dev/github.com/kardianos/service#Config for list format.
|
||||
func ServiceDependencies() []string {
|
||||
if Name() == edgeos.Name {
|
||||
// On EdeOS, ctrld needs to start after vyatta-dhcpd, so it can read leases file.
|
||||
return []string{
|
||||
"Wants=vyatta-dhcpd.service",
|
||||
"After=vyatta-dhcpd.service",
|
||||
"Wants=dnsmasq.service",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SelfInterfaces return list of *net.Interface that will be source of requests from router itself.
|
||||
func SelfInterfaces() []*net.Interface {
|
||||
switch Name() {
|
||||
@@ -197,6 +183,14 @@ func SelfInterfaces() []*net.Interface {
|
||||
}
|
||||
}
|
||||
|
||||
// LeaseFilesDir is the directory which contains lease files.
|
||||
func LeaseFilesDir() string {
|
||||
if Name() == edgeos.Name {
|
||||
edgeos.LeaseFileDir()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func distroName() string {
|
||||
switch {
|
||||
case bytes.HasPrefix(unameO(), []byte("DD-WRT")):
|
||||
|
||||
9
nameservers_unix.go
Normal file
9
nameservers_unix.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build unix
|
||||
|
||||
package ctrld
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
|
||||
func nameserversFromResolvconf() []string {
|
||||
return resolvconffile.NameServers("")
|
||||
}
|
||||
@@ -58,3 +58,7 @@ func dnsFromAdapter() []string {
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
func nameserversFromResolvconf() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
45
resolver.go
45
resolver.go
@@ -27,13 +27,15 @@ const (
|
||||
)
|
||||
|
||||
var bootstrapDNS = "76.76.2.0"
|
||||
var or = &osResolver{nameservers: nameservers()}
|
||||
|
||||
func init() {
|
||||
if len(or.nameservers) == 0 {
|
||||
// Add bootstrap DNS in case we did not find any.
|
||||
or.nameservers = []string{net.JoinHostPort(bootstrapDNS, "53")}
|
||||
}
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = &osResolver{nameservers: defaultNameservers()}
|
||||
|
||||
// defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver.
|
||||
func defaultNameservers() []string {
|
||||
ns := nameservers()
|
||||
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
|
||||
return ns
|
||||
}
|
||||
|
||||
// Resolver is the interface that wraps the basic DNS operations.
|
||||
@@ -237,13 +239,25 @@ func NewBootstrapResolver(servers ...string) Resolver {
|
||||
return resolver
|
||||
}
|
||||
|
||||
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers.
|
||||
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers,
|
||||
// excluding nameservers from /etc/resolv.conf file.
|
||||
//
|
||||
// This is useful for doing PTR lookup in LAN network.
|
||||
func NewPrivateResolver() Resolver {
|
||||
nss := nameservers()
|
||||
resolveConfNss := nameserversFromResolvconf()
|
||||
n := 0
|
||||
for _, ns := range nss {
|
||||
host, _, _ := net.SplitHostPort(ns)
|
||||
// Ignore nameserver from resolve.conf file, because the nameserver can be either:
|
||||
//
|
||||
// - ctrld itself.
|
||||
// - Direct listener that has ctrld as an upstream (e.g: dnsmasq).
|
||||
//
|
||||
// causing the query always succeed.
|
||||
if sliceContains(resolveConfNss, host) {
|
||||
continue
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.IsPrivate() && !ip.IsLoopback() {
|
||||
nss[n] = ns
|
||||
@@ -269,3 +283,20 @@ func newDialer(dnsAddress string) *net.Dialer {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(cuonglm): use slices.Contains once upgrading to go1.21
|
||||
// sliceContains reports whether v is present in s.
|
||||
func sliceContains[S ~[]E, E comparable](s S, v E) bool {
|
||||
return sliceIndex(s, v) >= 0
|
||||
}
|
||||
|
||||
// sliceIndex returns the index of the first occurrence of v in s,
|
||||
// or -1 if not present.
|
||||
func sliceIndex[S ~[]E, E comparable](s S, v E) int {
|
||||
for i := range s {
|
||||
if v == s[i] {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ compress() {
|
||||
build() {
|
||||
goos=$1
|
||||
goarch=$2
|
||||
ldflags="-s -w -X github.com/Windscribe/ctrld/cmd/cli.version="${CI_COMMIT_TAG:-dev}" -X github.com/Windscribe/ctrld/cmd/cli.commit=$(git rev-parse HEAD)"
|
||||
ldflags="-s -w -X github.com/Control-D-Inc/ctrld/cmd/cli.version="${CI_COMMIT_TAG:-dev}" -X github.com/Control-D-Inc/ctrld/cmd/cli.commit=$(git rev-parse HEAD)"
|
||||
|
||||
case $3 in
|
||||
5 | 6 | 7)
|
||||
|
||||
Reference in New Issue
Block a user