diff --git a/README.md b/README.md index 74ed0cb..09543b7 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ $ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest or ``` -$ docker build -t controldns/ctrld . +$ docker build -t controldns/ctrld . -f docker/Dockerfile $ docker run -d --name=ctrld -p 53:53/tcp -p 53:53/udp controldns/ctrld --cd=RESOLVER_ID_GOES_HERE -vv ``` @@ -188,8 +188,8 @@ See [Configuration Docs](docs/config.md). [listener] [listener.0] - ip = "127.0.0.1" - port = 53 + ip = "" + port = 0 restricted = false [network] @@ -220,6 +220,8 @@ See [Configuration Docs](docs/config.md). ``` +`ctrld` will pick a working config for `listener.0` then writing the default config to disk for the first run. + ## Advanced Configuration The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file. diff --git a/client_info.go b/client_info.go index c4494f7..f32526a 100644 --- a/client_info.go +++ b/client_info.go @@ -8,6 +8,7 @@ type ClientInfo struct { Mac string IP string Hostname string + Self bool } // LeaseFileFormat specifies the format of DHCP lease file. diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 8f54d1f..b67504c 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -124,185 +124,7 @@ func initCLI() { initConsoleLogging() }, Run: func(cmd *cobra.Command, args []string) { - waitCh := make(chan struct{}) - stopCh := make(chan struct{}) - p := &prog{ - waitCh: waitCh, - stopCh: stopCh, - cfg: &cfg, - } - if homedir == "" { - if dir, err := userHomeDir(); err == nil { - homedir = dir - } - } - sockPath := filepath.Join(homedir, ctrldLogUnixSock) - if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil { - if conn, err := net.Dial(addr.Network(), addr.String()); err == nil { - lc := &logConn{conn: conn} - consoleWriter.Out = io.MultiWriter(os.Stdout, lc) - p.logConn = lc - } - } - - if daemon && runtime.GOOS == "windows" { - mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.") - } - - if !daemon { - // We need to call s.Run() as soon as possible to response to the OS manager, so it - // can see ctrld is running and don't mark ctrld as failed service. - go func() { - s, err := newService(p, svcConfig) - if err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed create new service") - } - if err := s.Run(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start service") - } - }() - } - noConfigStart := isNoConfigStart(cmd) - writeDefaultConfig := !noConfigStart && configBase64 == "" - tryReadingConfig(writeDefaultConfig) - - readBase64Config(configBase64) - processNoConfigFlags(noConfigStart) - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) - } - - processLogAndCacheFlags() - - // Log config do not have thing to validate, so it's safe to init log here, - // so it's able to log information in processCDFlags. - initLogging() - - mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) - mainLog.Load().Info().Msgf("os: %s", osVersion()) - - // Wait for network up. - if !ctrldnet.Up() { - mainLog.Load().Fatal().Msg("network is not up yet") - } - - p.router = router.New(&cfg, cdUID != "") - cs, err := newControlServer(filepath.Join(homedir, ctrldControlUnixSock)) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("could not create control server") - } - p.cs = cs - - // Processing --cd flag require connecting to ControlD API, which needs valid - // time for validating server certificate. Some routers need NTP synchronization - // to set the current time, so this check must happen before processCDFlags. - if err := p.router.PreRun(); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check") - } - - oldLogPath := cfg.Service.LogPath - if uid := cdUIDFromProvToken(); uid != "" { - cdUID = uid - } - if cdUID != "" { - processCDFlags() - } - - updated := updateListenerConfig() - - if cdUID != "" { - processLogAndCacheFlags() - } - - if updated { - if err := writeConfigFile(); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to write config file") - } else { - mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) - } - } - - if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath { - // After processCDFlags, log config may change, so reset mainLog and re-init logging. - l := zerolog.New(io.Discard) - mainLog.Store(&l) - - // Copy logs written so far to new log file if possible. - if buf, err := os.ReadFile(oldLogPath); err == nil { - if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not copy old log file") - } - } - initLoggingWithBackup(false) - } - - validateConfig(&cfg) - initCache() - - if daemon { - exe, err := os.Executable() - if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to find the binary") - os.Exit(1) - } - curDir, err := os.Getwd() - if err != nil { - mainLog.Load().Error().Err(err).Msg("failed to get current working directory") - os.Exit(1) - } - // If running as daemon, re-run the command in background, with daemon off. - cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...) - cmd.Dir = curDir - if err := cmd.Start(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to start process as daemon") - os.Exit(1) - } - mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started") - os.Exit(0) - } - - p.onStarted = append(p.onStarted, func() { - for _, lc := range p.cfg.Listener { - if shouldAllocateLoopbackIP(lc.IP) { - if err := allocateIP(lc.IP); err != nil { - mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP) - } - } - } - }) - p.onStopped = append(p.onStopped, func() { - for _, lc := range p.cfg.Listener { - if shouldAllocateLoopbackIP(lc.IP) { - if err := deAllocateIP(lc.IP); err != nil { - mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP) - } - } - } - }) - if platform := router.Name(); platform != "" { - if cp := router.CertPool(); cp != nil { - rootCertPool = cp - } - p.onStarted = append(p.onStarted, func() { - mainLog.Load().Debug().Msg("router setup on start") - if err := p.router.Setup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not configure router") - } - }) - p.onStopped = append(p.onStopped, func() { - mainLog.Load().Debug().Msg("router cleanup on stop") - if err := p.router.Cleanup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not cleanup router") - } - p.resetDNS() - }) - } - - close(waitCh) - <-stopCh - for _, f := range p.onStopped { - f() - } + RunCobraCommand(cmd) }, } runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon") @@ -314,8 +136,8 @@ func initCLI() { runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") - runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") - runCmd.Flags().StringVarP(&cdOrg, "cd-org", "", "", "Control D provision token") + runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") + runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = runCmd.Flags().MarkHidden("dev") runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") @@ -334,6 +156,8 @@ func initCLI() { Short: "Install and start the ctrld service", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { + checkStrFlagEmpty(cmd, cdUidFlagName) + checkStrFlagEmpty(cmd, cdOrgFlagName) sc := &service.Config{} *sc = *svcConfig osArgs := os.Args[2:] @@ -466,8 +290,8 @@ func initCLI() { startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy") startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") - startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") - startCmd.Flags().StringVarP(&cdOrg, "cd-org", "", "", "Control D provision token") + startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") + startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) @@ -804,6 +628,9 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, map2Slice := func(m map[string]struct{}) []string { s := make([]string, 0, len(m)) for k := range m { + if k == "" { // skip empty source from output. + continue + } s = append(s, k) } sort.Strings(s) @@ -838,6 +665,222 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, rootCmd.AddCommand(clientsCmd) } +// isMobile reports whether the current OS is a mobile platform. +func isMobile() bool { + return runtime.GOOS == "android" || runtime.GOOS == "ios" +} + +// RunCobraCommand runs ctrld cli. +func RunCobraCommand(cmd *cobra.Command) { + noConfigStart = isNoConfigStart(cmd) + checkStrFlagEmpty(cmd, cdUidFlagName) + checkStrFlagEmpty(cmd, cdOrgFlagName) + run(nil, make(chan struct{})) +} + +// RunMobile runs the ctrld cli on mobile platforms. +func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struct{}) { + if appConfig == nil { + panic("appConfig is nil") + } + initConsoleLogging() + noConfigStart = false + homedir = appConfig.HomeDir + verbose = appConfig.Verbose + cdUID = appConfig.CdUID + logPath = appConfig.LogPath + run(appCallback, stopCh) +} + +// run runs ctrld cli with given app callback and stop channel. +func run(appCallback *AppCallback, stopCh chan struct{}) { + if stopCh == nil { + mainLog.Load().Fatal().Msg("stopCh is nil") + } + waitCh := make(chan struct{}) + p := &prog{ + waitCh: waitCh, + stopCh: stopCh, + cfg: &cfg, + appCallback: appCallback, + } + if homedir == "" { + if dir, err := userHomeDir(); err == nil { + homedir = dir + } + } + sockPath := filepath.Join(homedir, ctrldLogUnixSock) + if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil { + if conn, err := net.Dial(addr.Network(), addr.String()); err == nil { + lc := &logConn{conn: conn} + consoleWriter.Out = io.MultiWriter(os.Stdout, lc) + p.logConn = lc + } + } + + if daemon && runtime.GOOS == "windows" { + mainLog.Load().Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.") + } + + if !daemon { + // We need to call s.Run() as soon as possible to response to the OS manager, so it + // can see ctrld is running and don't mark ctrld as failed service. + go func() { + s, err := newService(p, svcConfig) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed create new service") + } + if err := s.Run(); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to start service") + } + }() + } + writeDefaultConfig := !noConfigStart && configBase64 == "" + tryReadingConfig(writeDefaultConfig) + + readBase64Config(configBase64) + processNoConfigFlags(noConfigStart) + if err := v.Unmarshal(&cfg); err != nil { + mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + } + + processLogAndCacheFlags() + + // Log config do not have thing to validate, so it's safe to init log here, + // so it's able to log information in processCDFlags. + initLogging() + + mainLog.Load().Info().Msgf("starting ctrld %s", curVersion()) + mainLog.Load().Info().Msgf("os: %s", osVersion()) + + // Wait for network up. + if !ctrldnet.Up() { + mainLog.Load().Fatal().Msg("network is not up yet") + } + + p.router = router.New(&cfg, cdUID != "") + cs, err := newControlServer(filepath.Join(homedir, ctrldControlUnixSock)) + if err != nil { + mainLog.Load().Warn().Err(err).Msg("could not create control server") + } + p.cs = cs + + // Processing --cd flag require connecting to ControlD API, which needs valid + // time for validating server certificate. Some routers need NTP synchronization + // to set the current time, so this check must happen before processCDFlags. + if err := p.router.PreRun(); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to perform router pre-run check") + } + + oldLogPath := cfg.Service.LogPath + if uid := cdUIDFromProvToken(); uid != "" { + cdUID = uid + } + if cdUID != "" { + err := processCDFlags() + if err != nil { + appCallback.Exit(err.Error()) + return + } + } + + updated := updateListenerConfig() + + if cdUID != "" { + processLogAndCacheFlags() + } + + if updated { + if err := writeConfigFile(); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to write config file") + } else { + mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) + } + } + + if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath { + // After processCDFlags, log config may change, so reset mainLog and re-init logging. + l := zerolog.New(io.Discard) + mainLog.Store(&l) + + // Copy logs written so far to new log file if possible. + if buf, err := os.ReadFile(oldLogPath); err == nil { + if err := os.WriteFile(newLogPath, buf, os.FileMode(0o600)); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not copy old log file") + } + } + initLoggingWithBackup(false) + } + + validateConfig(&cfg) + initCache() + + if daemon { + exe, err := os.Executable() + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to find the binary") + os.Exit(1) + } + curDir, err := os.Getwd() + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to get current working directory") + os.Exit(1) + } + // If running as daemon, re-run the command in background, with daemon off. + cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...) + cmd.Dir = curDir + if err := cmd.Start(); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to start process as daemon") + os.Exit(1) + } + mainLog.Load().Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started") + os.Exit(0) + } + + p.onStarted = append(p.onStarted, func() { + for _, lc := range p.cfg.Listener { + if shouldAllocateLoopbackIP(lc.IP) { + if err := allocateIP(lc.IP); err != nil { + mainLog.Load().Error().Err(err).Msgf("could not allocate IP: %s", lc.IP) + } + } + } + }) + p.onStopped = append(p.onStopped, func() { + for _, lc := range p.cfg.Listener { + if shouldAllocateLoopbackIP(lc.IP) { + if err := deAllocateIP(lc.IP); err != nil { + mainLog.Load().Error().Err(err).Msgf("could not de-allocate IP: %s", lc.IP) + } + } + } + }) + if platform := router.Name(); platform != "" { + if cp := router.CertPool(); cp != nil { + rootCertPool = cp + } + p.onStarted = append(p.onStarted, func() { + mainLog.Load().Debug().Msg("router setup on start") + if err := p.router.Setup(); err != nil { + mainLog.Load().Error().Err(err).Msg("could not configure router") + } + }) + p.onStopped = append(p.onStopped, func() { + mainLog.Load().Debug().Msg("router cleanup on stop") + if err := p.router.Cleanup(); err != nil { + mainLog.Load().Error().Err(err).Msg("could not cleanup router") + } + p.resetDNS() + }) + } + + close(waitCh) + <-stopCh + for _, f := range p.onStopped { + f() + } +} + func writeConfigFile() error { if cfu := v.ConfigFileUsed(); cfu != "" { defaultConfigFile = cfu @@ -882,6 +925,7 @@ func readConfigFile(writeDefaultConfig bool) bool { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } + _ = updateListenerConfig() if err := writeConfigFile(); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) } else { @@ -971,7 +1015,7 @@ func processNoConfigFlags(noConfigStart bool) { v.Set("upstream", upstream) } -func processCDFlags() { +func processCDFlags() error { logger := mainLog.Load().With().Str("mode", "cd").Logger() logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) @@ -991,12 +1035,12 @@ func processCDFlags() { s, err := newService(&prog{}, svcConfig) if err != nil { logger.Warn().Err(err).Msg("failed to create new service") - return + return nil } if netIface, _ := netInterface(iface); netIface != nil { if err := restoreNetworkManager(); err != nil { logger.Error().Err(err).Msg("could not restore NetworkManager") - return + return nil } logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface") if err := resetDNS(netIface); err != nil { @@ -1010,11 +1054,16 @@ func processCDFlags() { if doTasks(tasks) { logger.Info().Msg("uninstalled service") } - logger.Fatal().Err(uer).Msg("failed to fetch resolver config") + event := logger.Fatal() + if isMobile() { + event = logger.Warn() + } + event.Err(uer).Msg("failed to fetch resolver config") + return uer } if err != nil { logger.Warn().Err(err).Msg("could not fetch resolver config") - return + return nil } logger.Info().Msg("generating ctrld config from Control-D configuration") @@ -1058,6 +1107,7 @@ func processCDFlags() { "0": {IP: "", Port: 0}, } } + return nil } func processListenFlag() { @@ -1266,7 +1316,17 @@ func userHomeDir() (string, error) { } // viper will expand for us. if runtime.GOOS == "windows" { - return os.UserHomeDir() + // If we're on windows, use the install path for this. + exePath, err := os.Executable() + if err != nil { + return "", err + } + + return filepath.Dir(exePath), nil + } + // Mobile platform should provide a rw dir path for this. + if isMobile() { + return homedir, nil } dir = "/etc/controld" if err := os.MkdirAll(dir, 0750); err != nil { @@ -1412,6 +1472,14 @@ type listenerConfigCheck struct { Port bool } +// mobileListenerPort returns hardcoded port for mobile platforms. +func mobileListenerPort() int { + if runtime.GOOS == "ios" { + return 53 + } + return 5354 +} + // updateListenerConfig updates the config for listeners if not defined, // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. @@ -1435,7 +1503,25 @@ func updateListenerConfig() (updated bool) { } updated = updated || lcc[n].IP || lcc[n].Port } - + if isMobile() { + // On Mobile, only use first listener, ignore others. + firstLn := cfg.FirstListener() + for k := range cfg.Listener { + if cfg.Listener[k] != firstLn { + delete(cfg.Listener, k) + } + } + // In cd mode, always use 127.0.0.1:5354. + if cdMode { + firstLn.IP = "127.0.0.1" // Mobile platforms allows running listener only on loop back address. + firstLn.Port = mobileListenerPort() + // TODO: use clear(lcc) once upgrading to go 1.21 + for k := range lcc { + delete(lcc, k) + } + updated = true + } + } var closers []io.Closer defer func() { for _, closer := range closers { @@ -1656,12 +1742,12 @@ func removeProvTokenFromArgs(sc *service.Config) { continue } // For "--cd-org XXX", skip it and mark next arg skipped. - if x == "--cd-org" { + if x == cdOrgFlagName { skip = true continue } // For "--cd-org=XXX", just skip it. - if strings.HasPrefix(x, "--cd-org=") { + if strings.HasPrefix(x, cdOrgFlagName+"=") { continue } a = append(a, x) @@ -1700,3 +1786,15 @@ func newSocketControlClient(s service.Service, dir string) *controlClient { return cc } + +// checkStrFlagEmpty validates if a string flag was set to an empty string. +// If yes, emitting a fatal error message. +func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { + fl := cmd.Flags().Lookup(flagName) + if !fl.Changed || fl.Value.Type() != "string" { + return + } + if fl.Value.String() == "" { + mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name) + } +} diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 23ae03e..12cf781 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -5,10 +5,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "io" "net" "net/netip" - "os" "runtime" "strconv" "strings" @@ -16,10 +14,9 @@ import ( "time" "github.com/miekg/dns" - "go4.org/mem" "golang.org/x/sync/errgroup" "tailscale.com/net/interfaces" - "tailscale.com/util/lineread" + "tailscale.com/net/netaddr" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/dnscache" @@ -54,12 +51,12 @@ func (p *prog) serveDNS(listenerNum string) error { handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { p.sema.acquire() defer p.sema.release() + go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) reqId := requestID() remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - mac := macFromMsg(m) - ci := p.getClientInfo(remoteIP, mac) + ci := p.getClientInfo(remoteIP, m) remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String()) t := time.Now() @@ -121,7 +118,8 @@ func (p *prog) serveDNS(listenerNum string) error { }) } g.Go(func() error { - s, errCh := runDNSServer(dnsListenAddress(listenerConfig), proto, handler) + addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) + s, errCh := runDNSServer(addr, proto, handler) defer s.Shutdown() select { case err := <-errCh: @@ -149,7 +147,7 @@ func (p *prog) serveDNS(listenerNum string) error { // processed later, because policy logging want to know whether a network rule // is disregarded in favor of the domain level rule. func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) { - upstreams := []string{"upstream." + defaultUpstreamNum} + upstreams := []string{upstreamPrefix + defaultUpstreamNum} matchedPolicy := "no policy" matchedNetwork := "no network" matchedRule := "no rule" @@ -233,7 +231,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) if len(upstreamConfigs) == 0 { upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - upstreams = []string{"upstream.os"} + upstreams = []string{upstreamOS} } // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { @@ -277,6 +275,12 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i answer, err := resolve1(n, upstreamConfig, msg) if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") + if errNetworkError(err) { + p.um.increaseFailureCount(upstreams[n]) + if p.um.isDown(upstreams[n]) { + go p.um.checkUpstream(upstreams[n], upstreamConfig) + } + } return nil } return answer @@ -285,6 +289,14 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i if upstreamConfig == nil { continue } + if p.isLoop(upstreamConfig) { + mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint) + continue + } + if p.um.isDown(upstreams[n]) { + ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n]) + continue + } answer := resolve(n, upstreamConfig, msg) if answer == nil { if serveStaleCache && staleAnswer != nil { @@ -316,7 +328,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i } return answer } - ctrld.Log(ctx, mainLog.Load().Error(), "all upstreams failed") + ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) answer := new(dns.Msg) answer.SetRcode(msg, dns.RcodeServerFailure) return answer @@ -325,7 +337,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) for _, upstream := range upstreams { - upstreamNum := strings.TrimPrefix(upstream, "upstream.") + upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix) upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum]) } return upstreamConfigs @@ -422,29 +434,24 @@ func needLocalIPv6Listener() bool { return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows" } -func dnsListenAddress(lc *ctrld.ListenerConfig) string { - // If we are inside container and the listener loopback address, change - // the address to something like 0.0.0.0:53, so user can expose the port to outside. - if inContainer() { - if ip := net.ParseIP(lc.IP); ip != nil && ip.IsLoopback() { - return net.JoinHostPort("0.0.0.0", strconv.Itoa(lc.Port)) - } - } - return net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) -} - -func macFromMsg(msg *dns.Msg) string { +// ipAndMacFromMsg extracts IP and MAC information included in a DNS message, if any. +func ipAndMacFromMsg(msg *dns.Msg) (string, string) { + ip, mac := "", "" if opt := msg.IsEdns0(); opt != nil { for _, s := range opt.Option { switch e := s.(type) { case *dns.EDNS0_LOCAL: if e.Code == EDNS0_OPTION_MAC { - return net.HardwareAddr(e.Data).String() + mac = net.HardwareAddr(e.Data).String() + } + case *dns.EDNS0_SUBNET: + if len(e.Address) > 0 && !e.Address.IsLoopback() { + ip = e.Address.String() } } } } - return "" + return ip, mac } func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr { @@ -498,55 +505,73 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha return s, errCh } -// inContainer reports whether we're running in a container. -// -// Copied from https://github.com/tailscale/tailscale/blob/v1.42.0/hostinfo/hostinfo.go#L260 -// with modification for ctrld usage. -func inContainer() bool { - if runtime.GOOS != "linux" { - return false +func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo { + ci := &ctrld.ClientInfo{} + if p.appCallback != nil { + ci.IP = p.appCallback.LanIp() + ci.Mac = p.appCallback.MacAddress() + ci.Hostname = p.appCallback.HostName() + ci.Self = true + return ci + } + ci.IP, ci.Mac = ipAndMacFromMsg(msg) + switch { + case ci.IP != "" && ci.Mac != "": + // Nothing to do. + case ci.IP == "" && ci.Mac != "": + // Have MAC, no IP. + ci.IP = p.ciTable.LookupIP(ci.Mac) + case ci.IP == "" && ci.Mac == "": + // Have nothing, use remote IP then lookup MAC. + ci.IP = remoteIP + fallthrough + case ci.IP != "" && ci.Mac == "": + // Have IP, no MAC. + ci.Mac = p.ciTable.LookupMac(ci.IP) } - var ret bool - if _, err := os.Stat("/.dockerenv"); err == nil { - return true - } - if _, err := os.Stat("/run/.containerenv"); err == nil { - // See https://github.com/cri-o/cri-o/issues/5461 - return true - } - lineread.File("/proc/1/cgroup", func(line []byte) error { - if mem.Contains(mem.B(line), mem.S("/docker/")) || - mem.Contains(mem.B(line), mem.S("/lxc/")) { - ret = true - return io.EOF // arbitrary non-nil error to stop loop + // If MAC is still empty here, that mean the requests are made from virtual interface, + // like VPN/Wireguard clients, so we use whatever MAC address associated with remoteIP + // (most likely 127.0.0.1), and ci.IP as hostname, so we can distinguish those clients. + if ci.Mac == "" { + ci.Mac = p.ciTable.LookupMac(remoteIP) + if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" { + ci.Hostname = hostname + } else { + ci.Hostname = ci.IP + p.ciTable.StoreVPNClient(ci) } - return nil - }) - lineread.File("/proc/mounts", func(line []byte) error { - if mem.Contains(mem.B(line), mem.S("lxcfs /proc/cpuinfo fuse.lxcfs")) { - ret = true - return io.EOF - } - return nil - }) - return ret + } else { + ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac) + } + ci.Self = queryFromSelf(ci.IP) + return ci } -func (p *prog) getClientInfo(ip, mac string) *ctrld.ClientInfo { - ci := &ctrld.ClientInfo{} - if mac != "" { - ci.Mac = mac - ci.IP = p.ciTable.LookupIP(mac) - } else { - ci.IP = ip - ci.Mac = p.ciTable.LookupMac(ip) - if ip == "127.0.0.1" || ip == "::1" { - ci.IP = p.ciTable.LookupIP(ci.Mac) +// queryFromSelf reports whether the input IP is from device running ctrld. +func queryFromSelf(ip string) bool { + netIP := netip.MustParseAddr(ip) + ifaces, err := interfaces.GetList() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("could not get interfaces list") + return false + } + for _, iface := range ifaces { + addrs, err := iface.Addrs() + if err != nil { + mainLog.Load().Warn().Err(err).Msgf("could not get interfaces addresses: %s", iface.Name) + continue + } + for _, a := range addrs { + switch v := a.(type) { + case *net.IPNet: + if pfx, ok := netaddr.FromStdIPNet(v); ok && pfx.Addr().Compare(netIP) == 0 { + return true + } + } } } - ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac) - return ci + return false } func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool { diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index b7b0dbd..674d486 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -156,19 +156,27 @@ func TestCache(t *testing.T) { assert.Equal(t, answer2.Rcode, got2.Rcode) } -func Test_macFromMsg(t *testing.T) { +func Test_ipAndMacFromMsg(t *testing.T) { tests := []struct { name string + ip string + wantIp bool mac string wantMac bool }{ - {"has mac", "4c:20:b8:ab:87:1b", true}, - {"no mac", "4c:20:b8:ab:87:1b", false}, + {"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true}, + {"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true}, + {"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false}, + {"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + ip := net.ParseIP(tc.ip) + if ip == nil { + t.Fatal("missing IP") + } hw, err := net.ParseMAC(tc.mac) if err != nil { t.Fatal(err) @@ -180,13 +188,23 @@ func Test_macFromMsg(t *testing.T) { ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw} o.Option = append(o.Option, ec1) } - m.Extra = append(m.Extra, o) - got := macFromMsg(m) - if tc.wantMac && got != tc.mac { - t.Errorf("mismatch, want: %q, got: %q", tc.mac, got) + if tc.wantIp { + ec2 := &dns.EDNS0_SUBNET{Address: ip} + o.Option = append(o.Option, ec2) } - if !tc.wantMac && got != "" { - t.Errorf("unexpected mac: %q", got) + m.Extra = append(m.Extra, o) + gotIP, gotMac := ipAndMacFromMsg(m) + if tc.wantMac && gotMac != tc.mac { + t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac) + } + if !tc.wantMac && gotMac != "" { + t.Errorf("unexpected mac: %q", gotMac) + } + if tc.wantIp && gotIP != tc.ip { + t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP) + } + if !tc.wantIp && gotIP != "" { + t.Errorf("unexpected ip: %q", gotIP) } }) } diff --git a/cmd/cli/library.go b/cmd/cli/library.go new file mode 100644 index 0000000..80612c9 --- /dev/null +++ b/cmd/cli/library.go @@ -0,0 +1,18 @@ +package cli + +// AppCallback provides hooks for injecting certain functionalities +// from mobile platforms to main ctrld cli. +type AppCallback struct { + HostName func() string + LanIp func() string + MacAddress func() string + Exit func(error string) +} + +// AppConfig allows overwriting ctrld cli flags from mobile platforms. +type AppConfig struct { + CdUID string + HomeDir string + Verbose int + LogPath string +} diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go new file mode 100644 index 0000000..87dabf8 --- /dev/null +++ b/cmd/cli/loop.go @@ -0,0 +1,100 @@ +package cli + +import ( + "context" + "strings" + "time" + + "github.com/miekg/dns" + + "github.com/Control-D-Inc/ctrld" +) + +const ( + loopTestDomain = ".test" + loopTestQtype = dns.TypeTXT +) + +// isLoop reports whether the given upstream config is detected as having DNS loop. +func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool { + p.loopMu.Lock() + defer p.loopMu.Unlock() + return p.loop[uc.UID()] +} + +// detectLoop checks if the given DNS message is initialized sent by ctrld. +// If yes, marking the corresponding upstream as loop, prevent infinite DNS +// forwarding loop. +// +// See p.checkDnsLoop for more details how it works. +func (p *prog) detectLoop(msg *dns.Msg) { + if len(msg.Question) != 1 { + return + } + q := msg.Question[0] + if q.Qtype != loopTestQtype { + return + } + unFQDNname := strings.TrimSuffix(q.Name, ".") + uid := strings.TrimSuffix(unFQDNname, loopTestDomain) + p.loopMu.Lock() + if _, loop := p.loop[uid]; loop { + p.loop[uid] = loop + } + p.loopMu.Unlock() +} + +// checkDnsLoop sends a message to check if there's any DNS forwarding loop +// with all the upstreams. The way it works based on dnsmasq --dns-loop-detect. +// +// - Generating a TXT test query and sending it to all upstream. +// - The test query is formed by upstream UID and test domain: .test +// - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop). +// +// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html +func (p *prog) checkDnsLoop() { + mainLog.Load().Debug().Msg("start checking DNS loop") + upstream := make(map[string]*ctrld.UpstreamConfig) + p.loopMu.Lock() + for _, uc := range p.cfg.Upstream { + uid := uc.UID() + p.loop[uid] = false + upstream[uid] = uc + } + p.loopMu.Unlock() + + for uid := range p.loop { + msg := loopTestMsg(uid) + uc := upstream[uid] + resolver, err := ctrld.NewResolver(uc) + if err != nil { + mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + continue + } + if _, err := resolver.Resolve(context.Background(), msg); err != nil { + mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint) + } + } + mainLog.Load().Debug().Msg("end checking DNS loop") +} + +// checkDnsLoopTicker performs p.checkDnsLoop every minute. +func (p *prog) checkDnsLoopTicker() { + timer := time.NewTicker(time.Minute) + defer timer.Stop() + for { + select { + case <-p.stopCh: + return + case <-timer.C: + p.checkDnsLoop() + } + } +} + +// loopTestMsg generates DNS message for checking loop. +func loopTestMsg(uid string) *dns.Msg { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype) + return msg +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index e7376be..f4439a5 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -35,6 +35,12 @@ var ( mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter + noConfigStart bool +) + +const ( + cdUidFlagName = "cd" + cdOrgFlagName = "cd-org" ) func init() { @@ -65,6 +71,7 @@ func normalizeLogFilePath(logFilePath string) string { return filepath.Join(dir, logFilePath) } +// initConsoleLogging initializes console logging, then storing to mainLog. func initConsoleLogging() { consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) { w.TimeFormat = time.StampMilli diff --git a/cmd/cli/network_manager_linux.go b/cmd/cli/network_manager_linux.go index 5e7b540..1a8c22b 100644 --- a/cmd/cli/network_manager_linux.go +++ b/cmd/cli/network_manager_linux.go @@ -3,6 +3,7 @@ package cli import ( "context" "os" + "os/exec" "path/filepath" "time" @@ -16,13 +17,21 @@ const ( dns=none systemd-resolved=false ` - nmSystemdUnitName = "NetworkManager.service" - systemdEnabledState = "enabled" + nmSystemdUnitName = "NetworkManager.service" ) var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename) +// hasNetworkManager reports whether NetworkManager executable found. +func hasNetworkManager() bool { + exe, _ := exec.LookPath("NetworkManager") + return exe != "" +} + func setupNetworkManager() error { + if !hasNetworkManager() { + return nil + } if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent { mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do") return nil @@ -43,6 +52,9 @@ func setupNetworkManager() error { } func restoreNetworkManager() error { + if !hasNetworkManager() { + return nil + } err := os.Remove(networkManagerCtrldConfFile) if os.IsNotExist(err) { mainLog.Load().Debug().Msg("NetworkManager is not available") @@ -71,6 +83,7 @@ func reloadNetworkManager() { waitCh := make(chan string) if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil { mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager") + return } <-waitCh } diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index 004e863..7fb692c 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -9,7 +9,6 @@ import ( "net" "net/netip" "os/exec" - "reflect" "strings" "syscall" "time" @@ -85,8 +84,13 @@ func setDNS(iface *net.Interface, nameservers []string) error { } return err } + if useSystemdResolved { + if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil { + mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out)) + } + } currentNS := currentDNS(iface) - if reflect.DeepEqual(currentNS, nameservers) { + if isSubSet(nameservers, currentNS) { return nil } } @@ -104,7 +108,7 @@ func setDNS(iface *net.Interface, nameservers []string) error { return fmt.Errorf("%s: %w", string(out), err) } currentNS := currentDNS(iface) - if reflect.DeepEqual(currentNS, nameservers) { + if isSubSet(nameservers, currentNS) { return nil } time.Sleep(time.Second) @@ -265,3 +269,33 @@ func ignoringEINTR(fn func() error) error { } } } + +// isSubSet reports whether s2 contains all elements of s1. +func isSubSet(s1, s2 []string) bool { + ok := true + for _, ns := range s1 { + // TODO(cuonglm): use slices.Contains once upgrading to go1.21 + if sliceContains(s2, ns) { + continue + } + ok = false + break + } + return ok +} + +// sliceContains reports whether v is present in s. +func sliceContains[S ~[]E, E comparable](s S, v E) bool { + return sliceIndex(s, v) >= 0 +} + +// sliceIndex returns the index of the first occurrence of v in s, +// or -1 if not present. +func sliceIndex[S ~[]E, E comparable](s S, v E) int { + for i := range s { + if v == s[i] { + return i + } + } + return -1 +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 4169fb8..e30a03d 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -1,13 +1,16 @@ package cli import ( + "bytes" "errors" "fmt" "math/rand" "net" + "net/netip" "net/url" "os" "runtime" + "sort" "strconv" "sync" "syscall" @@ -25,6 +28,8 @@ const ( defaultSemaphoreCap = 256 ctrldLogUnixSock = "ctrld_start.sock" ctrldControlUnixSock = "ctrld_control.sock" + upstreamPrefix = "upstream." + upstreamOS = upstreamPrefix + "os" ) var logf = func(format string, args ...any) { @@ -46,11 +51,16 @@ type prog struct { logConn net.Conn cs *controlServer - cfg *ctrld.Config - cache dnscache.Cacher - sema semaphore - ciTable *clientinfo.Table - router router.Router + cfg *ctrld.Config + appCallback *AppCallback + cache dnscache.Cacher + sema semaphore + ciTable *clientinfo.Table + um *upstreamMonitor + router router.Router + + loopMu sync.Mutex + loop map[string]bool started chan struct{} onStartedDone chan struct{} @@ -84,6 +94,7 @@ func (p *prog) run() { numListeners := len(p.cfg.Listener) p.started = make(chan struct{}, numListeners) p.onStartedDone = make(chan struct{}) + p.loop = make(map[string]bool) if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { @@ -114,6 +125,8 @@ func (p *prog) run() { nc.IPNets = append(nc.IPNets, ipNet) } } + + p.um = newUpstreamMonitor(p.cfg) for n := range p.cfg.Upstream { uc := p.cfg.Upstream[n] uc.Init() @@ -133,12 +146,14 @@ func (p *prog) run() { format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) p.ciTable.AddLeaseFile(leaseFile, format) } - - go func() { - p.ciTable.Init() - p.ciTable.RefreshLoop(p.stopCh) - }() - go p.watchLinkState() + // Newer versions of android and iOS denies permission which breaks connectivity. + if !isMobile() { + go func() { + p.ciTable.Init() + p.ciTable.RefreshLoop(p.stopCh) + }() + go p.watchLinkState() + } for listenerNum := range p.cfg.Listener { p.cfg.Listener[listenerNum].Init() @@ -163,8 +178,13 @@ func (p *prog) run() { for _, f := range p.onStarted { f() } + // Check for possible DNS loop. + p.checkDnsLoop() close(p.onStartedDone) + // Start check DNS loop ticker. + go p.checkDnsLoopTicker() + // Stop writing log to unix socket. consoleWriter.Out = os.Stdout initLoggingWithBackup(false) @@ -345,48 +365,100 @@ var ( func errUrlNetworkError(err error) bool { var urlErr *url.Error if errors.As(err, &urlErr) { - var opErr *net.OpError - if errors.As(urlErr.Err, &opErr) { - if opErr.Temporary() { - return true - } - switch { - case errors.Is(opErr.Err, syscall.ECONNREFUSED), - errors.Is(opErr.Err, syscall.EINVAL), - errors.Is(opErr.Err, syscall.ENETUNREACH), - errors.Is(opErr.Err, windowsENETUNREACH), - errors.Is(opErr.Err, windowsEINVAL), - errors.Is(opErr.Err, windowsECONNREFUSED): - return true - } + return errNetworkError(urlErr.Err) + } + return false +} + +func errNetworkError(err error) bool { + var opErr *net.OpError + if errors.As(err, &opErr) { + if opErr.Temporary() { + return true + } + switch { + case errors.Is(opErr.Err, syscall.ECONNREFUSED), + errors.Is(opErr.Err, syscall.EINVAL), + errors.Is(opErr.Err, syscall.ENETUNREACH), + errors.Is(opErr.Err, windowsENETUNREACH), + errors.Is(opErr.Err, windowsEINVAL), + errors.Is(opErr.Err, windowsECONNREFUSED): + return true } } return false } -// defaultRouteIP returns IP string of the default route if present, prefer IPv4 over IPv6. -func defaultRouteIP() string { - if dr, err := interfaces.DefaultRoute(); err == nil { - if netIface, err := netInterface(dr.InterfaceName); err == nil { - addrs, _ := netIface.Addrs() - do := func(v4 bool) net.IP { - for _, addr := range addrs { - if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() { - if v4 { - return netIP.IP.To4() - } - return netIP.IP - } +func ifaceFirstPrivateIP(iface *net.Interface) string { + if iface == nil { + return "" + } + do := func(addrs []net.Addr, v4 bool) net.IP { + for _, addr := range addrs { + if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() { + if v4 { + return netIP.IP.To4() } - return nil - } - if ip := do(true); ip != nil { - return ip.String() - } - if ip := do(false); ip != nil { - return ip.String() + return netIP.IP } } + return nil + } + addrs, _ := iface.Addrs() + if ip := do(addrs, true); ip != nil { + return ip.String() + } + if ip := do(addrs, false); ip != nil { + return ip.String() } return "" } + +// defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6. +func defaultRouteIP() string { + dr, err := interfaces.DefaultRoute() + if err != nil { + return "" + } + drNetIface, err := netInterface(dr.InterfaceName) + if err != nil { + return "" + } + mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface") + if ip := ifaceFirstPrivateIP(drNetIface); ip != "" { + mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface") + return ip + } + + // If we reach here, it means the default route interface is connected directly to ISP. + // We need to find the LAN interface with the same Mac address with drNetIface. + // + // There could be multiple LAN interfaces with the same Mac address, so we find all private + // IPs then using the smallest one. + var addrs []netip.Addr + interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { + if i.Name == drNetIface.Name { + return + } + if bytes.Equal(i.HardwareAddr, drNetIface.HardwareAddr) { + for _, pfx := range prefixes { + addr := pfx.Addr() + if addr.IsPrivate() { + addrs = append(addrs, addr) + } + } + } + }) + + if len(addrs) == 0 { + mainLog.Load().Warn().Msg("no default route IP found") + return "" + } + sort.Slice(addrs, func(i, j int) bool { + return addrs[i].Less(addrs[j]) + }) + + ip := addrs[0].String() + mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP") + return ip +} diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index 6f28083..ed28561 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -4,7 +4,6 @@ import ( "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld/internal/dns" - "github.com/Control-D-Inc/ctrld/internal/router" ) func init() { @@ -21,9 +20,8 @@ func setDependencies(svc *service.Config) { "After=NetworkManager-wait-online.service", "Wants=systemd-networkd-wait-online.service", "After=systemd-networkd-wait-online.service", - } - if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 { - svc.Dependencies = append(svc.Dependencies, routerDeps...) + "Wants=nss-lookup.target", + "After=nss-lookup.target", } } diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go new file mode 100644 index 0000000..4b3ee69 --- /dev/null +++ b/cmd/cli/upstream_monitor.go @@ -0,0 +1,98 @@ +package cli + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + "tailscale.com/logtail/backoff" + + "github.com/Control-D-Inc/ctrld" +) + +const ( + // maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down. + maxFailureRequest = 100 + // checkUpstreamMaxBackoff is the max backoff time when checking upstream status. + checkUpstreamMaxBackoff = 2 * time.Minute +) + +// upstreamMonitor performs monitoring upstreams health. +type upstreamMonitor struct { + cfg *ctrld.Config + + down map[string]*atomic.Bool + failureReq map[string]*atomic.Uint64 + + mu sync.Mutex + checking map[string]bool +} + +func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { + um := &upstreamMonitor{ + cfg: cfg, + down: make(map[string]*atomic.Bool), + failureReq: make(map[string]*atomic.Uint64), + checking: make(map[string]bool), + } + for n := range cfg.Upstream { + upstream := upstreamPrefix + n + um.down[upstream] = new(atomic.Bool) + um.failureReq[upstream] = new(atomic.Uint64) + } + um.down[upstreamOS] = new(atomic.Bool) + um.failureReq[upstreamOS] = new(atomic.Uint64) + return um +} + +// increaseFailureCount increase failed queries count for an upstream by 1. +func (um *upstreamMonitor) increaseFailureCount(upstream string) { + failedCount := um.failureReq[upstream].Add(1) + um.down[upstream].Store(failedCount >= maxFailureRequest) +} + +// isDown reports whether the given upstream is being marked as down. +func (um *upstreamMonitor) isDown(upstream string) bool { + return um.down[upstream].Load() +} + +// reset marks an upstream as up and set failed queries counter to zero. +func (um *upstreamMonitor) reset(upstream string) { + um.failureReq[upstream].Store(0) + um.down[upstream].Store(false) +} + +// checkUpstream checks the given upstream status, periodically sending query to upstream +// until successfully. An upstream status/counter will be reset once it becomes reachable. +func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { + um.mu.Lock() + isChecking := um.checking[upstream] + if isChecking { + um.mu.Unlock() + return + } + um.checking[upstream] = true + um.mu.Unlock() + + bo := backoff.NewBackoff("checkUpstream", logf, checkUpstreamMaxBackoff) + resolver, err := ctrld.NewResolver(uc) + if err != nil { + mainLog.Load().Warn().Err(err).Msg("could not check upstream") + return + } + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + ctx := context.Background() + + for { + _, err := resolver.Resolve(ctx, msg) + if err == nil { + mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint) + um.reset(upstream) + return + } + bo.BackOff(ctx, err) + } +} diff --git a/cmd/ctrld_library/main.go b/cmd/ctrld_library/main.go new file mode 100644 index 0000000..526dd3b --- /dev/null +++ b/cmd/ctrld_library/main.go @@ -0,0 +1,74 @@ +package ctrld_library + +import ( + "github.com/Control-D-Inc/ctrld/cmd/cli" +) + +// Controller holds global state +type Controller struct { + stopCh chan struct{} + AppCallback AppCallback + Config cli.AppConfig +} + +// NewController provides reference to global state to be managed by android vpn service and iOS network extension. +// reference is not safe for concurrent use. +func NewController(appCallback AppCallback) *Controller { + return &Controller{AppCallback: appCallback} +} + +// AppCallback provides access to app instance. +type AppCallback interface { + Hostname() string + LanIp() string + MacAddress() string + Exit(error string) +} + +// Start configures utility with config.toml from provided directory. +// This function will block until Stop is called +// Check port availability prior to calling it. +func (c *Controller) Start(CdUID string, HomeDir string, logLevel int, logPath string) { + if c.stopCh == nil { + c.stopCh = make(chan struct{}) + c.Config = cli.AppConfig{ + CdUID: CdUID, + HomeDir: HomeDir, + Verbose: logLevel, + LogPath: logPath, + } + appCallback := mapCallback(c.AppCallback) + cli.RunMobile(&c.Config, &appCallback, c.stopCh) + } +} + +// As workaround to avoid circular dependency between cli and ctrld_library module +func mapCallback(callback AppCallback) cli.AppCallback { + return cli.AppCallback{ + HostName: func() string { + return callback.Hostname() + }, + LanIp: func() string { + return callback.LanIp() + }, + MacAddress: func() string { + return callback.MacAddress() + }, + Exit: func(err string) { + callback.Exit(err) + }, + } +} + +func (c *Controller) Stop() bool { + if c.stopCh != nil { + close(c.stopCh) + c.stopCh = nil + return true + } + return false +} + +func (c *Controller) IsRunning() bool { + return c.stopCh != nil +} diff --git a/config.go b/config.go index eef5af0..21d636c 100644 --- a/config.go +++ b/config.go @@ -2,8 +2,10 @@ package ctrld import ( "context" + crand "crypto/rand" "crypto/tls" "crypto/x509" + "encoding/hex" "errors" "io" "math/rand" @@ -78,8 +80,8 @@ func SetConfigNameWithPath(v *viper.Viper, name, configPath string) { func InitConfig(v *viper.Viper, name string) { v.SetDefault("listener", map[string]*ListenerConfig{ "0": { - IP: "127.0.0.1", - Port: 53, + IP: "", + Port: 0, }, }) v.SetDefault("network", map[string]*NetworkConfig{ @@ -178,6 +180,7 @@ type ServiceConfig struct { DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"` DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` + DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` Daemon bool `mapstructure:"-" toml:"-"` AllocateIP bool `mapstructure:"-" toml:"-"` } @@ -216,6 +219,7 @@ type UpstreamConfig struct { http3RoundTripper6 http.RoundTripper certPool *x509.CertPool u *url.URL + uid string } // ListenerConfig specifies the networks configuration that ctrld will run on. @@ -260,6 +264,7 @@ type Rule map[string][]string // Init initialized necessary values for an UpstreamConfig. func (uc *UpstreamConfig) Init() { + uc.uid = upstreamUID() if u, err := url.Parse(uc.Endpoint); err == nil { uc.Domain = u.Host switch uc.Type { @@ -340,6 +345,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { uc.setupBootstrapIP(true) } +// UID returns the unique identifier of the upstream. +func (uc *UpstreamConfig) UID() string { + return uc.uid +} + // SetupBootstrapIP manually find all available IPs of the upstream. // The first usable IP will be used as bootstrap IP of the upstream. func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) { @@ -679,3 +689,15 @@ func ResolverTypeFromEndpoint(endpoint string) string { func pick(s []string) string { return s[rand.Intn(len(s))] } + +// upstreamUID generates an unique identifier for an upstream. +func upstreamUID() string { + b := make([]byte, 4) + for { + if _, err := crand.Read(b); err != nil { + ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...") + continue + } + return hex.EncodeToString(b) + } +} diff --git a/config_internal_test.go b/config_internal_test.go index 6fc1844..89cec19 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -185,6 +185,7 @@ func TestUpstreamConfig_Init(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() tc.uc.Init() + tc.uc.uid = "" // we don't care about the uid. assert.Equal(t, tc.expected, tc.uc) }) } diff --git a/config_quic.go b/config_quic.go index e953c72..cd3eaee 100644 --- a/config_quic.go +++ b/config_quic.go @@ -8,6 +8,7 @@ import ( "errors" "net" "net/http" + "runtime" "sync" "time" @@ -43,7 +44,6 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { rt := &http3.RoundTripper{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - domain := addr _, port, _ := net.SplitHostPort(addr) // if we have a bootstrap ip set, use it to avoid DNS lookup if uc.BootstrapIP != "" { @@ -57,20 +57,23 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { if err != nil { return nil, err } - return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg) + return quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg) } dialAddrs := make([]string, len(addrs)) for i := range addrs { dialAddrs[i] = net.JoinHostPort(addrs[i], port) } pd := &quicParallelDialer{} - conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg) + conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg) if err != nil { return nil, err } ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) return conn, err } + runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) { + rt.CloseIdleConnections() + }) return rt } @@ -107,13 +110,15 @@ type parallelDialerResult struct { type quicParallelDialer struct{} // Dial performs parallel dialing to the given address list. -func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { +func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { if len(addrs) == 0 { return nil, errors.New("empty addresses") } ctx, cancel := context.WithCancel(ctx) defer cancel() + done := make(chan struct{}) + defer close(done) ch := make(chan *parallelDialerResult, len(addrs)) var wg sync.WaitGroup wg.Add(len(addrs)) @@ -135,9 +140,14 @@ func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []st ch <- ¶llelDialerResult{conn: nil, err: err} return } - - conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg) - ch <- ¶llelDialerResult{conn: conn, err: err} + conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg) + select { + case ch <- ¶llelDialerResult{conn: conn, err: err}: + case <-done: + if conn != nil { + conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "") + } + } }(addr) } diff --git a/Dockerfile b/docker/Dockerfile similarity index 100% rename from Dockerfile rename to docker/Dockerfile diff --git a/docker/Dockerfile.debug b/docker/Dockerfile.debug new file mode 100644 index 0000000..e7ce172 --- /dev/null +++ b/docker/Dockerfile.debug @@ -0,0 +1,32 @@ +# Using Debian bullseye for building regular image. +# Using scratch image for minimal image size. +# The final image has: +# +# - Timezone info file. +# - CA certs file. +# - /etc/{passwd,group} file. +# - Non-cgo ctrld binary. +# +# CI_COMMIT_TAG is used to set the version of ctrld binary. +FROM golang:1.20-bullseye as base + +WORKDIR /app + +RUN apt-get update && apt-get install -y upx-ucl + +COPY . . + +ARG tag=master +ENV CI_COMMIT_TAG=$tag +RUN CTRLD_NO_QF=yes CGO_ENABLED=0 ./scripts/build.sh + +FROM alpine + +COPY --from=base /usr/share/zoneinfo /usr/share/zoneinfo +COPY --from=base /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ +COPY --from=base /etc/passwd /etc/passwd +COPY --from=base /etc/group /etc/group + +COPY --from=base /app/ctrld-linux-*-nocgo ctrld + +ENTRYPOINT ["./ctrld", "run"] diff --git a/docs/config.md b/docs/config.md index f2b5554..35fbda5 100644 --- a/docs/config.md +++ b/docs/config.md @@ -193,6 +193,13 @@ Perform LAN client discovery using PTR queries. - Required: no - Default: true +### discover_hosts +Perform LAN client discovery using hosts file. + +- Type: boolean +- Required: no +- Default: true + ### dhcp_lease_file_path Relative or absolute path to a custom DHCP leases file location. diff --git a/doh.go b/doh.go index 5886881..e0aa363 100644 --- a/doh.go +++ b/doh.go @@ -8,6 +8,11 @@ import ( "io" "net/http" "net/url" + "runtime" + "strings" + "sync" + + "github.com/cuonglm/osinfo" "github.com/miekg/dns" ) @@ -16,9 +21,56 @@ const ( dohMacHeader = "x-cd-mac" dohIPHeader = "x-cd-ip" dohHostHeader = "x-cd-host" + dohOsHeader = "x-cd-os" headerApplicationDNS = "application/dns-message" ) +// EncodeOsNameMap provides mapping from OS name to a shorter string, used for encoding x-cd-os value. +var EncodeOsNameMap = map[string]string{ + "windows": "1", + "darwin": "2", + "linux": "3", + "freebsd": "4", +} + +// DecodeOsNameMap provides mapping from encoded OS name to real value, used for decoding x-cd-os value. +var DecodeOsNameMap = map[string]string{} + +// EncodeArchNameMap provides mapping from OS arch to a shorter string, used for encoding x-cd-os value. +var EncodeArchNameMap = map[string]string{ + "amd64": "1", + "arm64": "2", + "arm": "3", + "386": "4", + "mips": "5", + "mipsle": "6", + "mips64": "7", +} + +// DecodeArchNameMap provides mapping from encoded OS arch to real value, used for decoding x-cd-os value. +var DecodeArchNameMap = map[string]string{} + +func init() { + for k, v := range EncodeOsNameMap { + DecodeOsNameMap[v] = k + } + for k, v := range EncodeArchNameMap { + DecodeArchNameMap[v] = k + } +} + +// TODO: use sync.OnceValue when upgrading to go1.21 +var xCdOsValueOnce sync.Once +var xCdOsValue string + +func dohOsHeaderValue() string { + xCdOsValueOnce.Do(func() { + oi := osinfo.New() + xCdOsValue = strings.Join([]string{EncodeOsNameMap[runtime.GOOS], EncodeArchNameMap[runtime.GOARCH], oi.Dist}, "-") + }) + return xCdOsValue +} + func newDohResolver(uc *UpstreamConfig) *dohResolver { r := &dohResolver{ endpoint: uc.u, @@ -97,8 +149,12 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { req.Header.Set("Content-Type", headerApplicationDNS) req.Header.Set("Accept", headerApplicationDNS) + req.Header.Set(dohOsHeader, dohOsHeaderValue()) + + printed := false if sendClientInfo { if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil { + printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != "" if ci.Mac != "" { req.Header.Set(dohMacHeader, ci.Mac) } @@ -108,7 +164,12 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { if ci.Hostname != "" { req.Header.Set(dohHostHeader, ci.Hostname) } + if ci.Self { + req.Header.Set(dohOsHeader, dohOsHeaderValue()) + } } } - Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header") + if printed { + Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header") + } } diff --git a/doh_test.go b/doh_test.go new file mode 100644 index 0000000..d233498 --- /dev/null +++ b/doh_test.go @@ -0,0 +1,23 @@ +package ctrld + +import ( + "runtime" + "testing" +) + +func Test_dohOsHeaderValue(t *testing.T) { + val := dohOsHeaderValue() + if val == "" { + t.Fatalf("empty %s", dohOsHeader) + } + t.Log(val) + + encodedOs := EncodeOsNameMap[runtime.GOOS] + if encodedOs == "" { + t.Fatalf("missing encoding value for: %q", runtime.GOOS) + } + decodedOs := DecodeOsNameMap[encodedOs] + if decodedOs == "" { + t.Fatalf("missing decoding value for: %q", runtime.GOOS) + } +} diff --git a/doq.go b/doq.go index 365fa10..3c3f9e8 100644 --- a/doq.go +++ b/doq.go @@ -51,7 +51,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls. } func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) { - session, err := quic.DialAddr(endpoint, tlsConfig, nil) + session, err := quic.DialAddr(ctx, endpoint, tlsConfig, nil) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index 1229987..58ba1e4 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.20 require ( github.com/coreos/go-systemd/v22 v22.5.0 - github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 + github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf github.com/frankban/quicktest v1.14.5 github.com/fsnotify/fsnotify v1.6.0 github.com/go-playground/validator/v10 v10.11.1 @@ -12,19 +12,19 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.1 github.com/illarion/gonotify v1.0.1 github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16 + github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 github.com/kardianos/service v1.2.1 github.com/miekg/dns v1.1.55 github.com/olekukonko/tablewriter v0.0.5 github.com/pelletier/go-toml/v2 v2.0.8 - github.com/quic-go/quic-go v0.32.0 + github.com/quic-go/quic-go v0.38.0 github.com/rs/zerolog v1.28.0 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.8.3 github.com/vishvananda/netlink v1.2.1-beta.2 - go4.org/mem v0.0.0-20220726221520-4f986261bf13 golang.org/x/net v0.10.0 golang.org/x/sync v0.2.0 golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a @@ -37,7 +37,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/golang/mock v1.6.0 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect @@ -56,13 +56,11 @@ require ( github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 // indirect github.com/mdlayher/socket v0.4.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/onsi/ginkgo/v2 v2.2.0 // indirect + github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/pierrec/lz4/v4 v4.1.17 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect - github.com/quic-go/qtls-go1-18 v0.2.0 // indirect - github.com/quic-go/qtls-go1-19 v0.2.0 // indirect - github.com/quic-go/qtls-go1-20 v0.1.0 // indirect + github.com/quic-go/qtls-go1-20 v0.3.2 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/spf13/afero v1.9.5 // indirect @@ -71,8 +69,10 @@ require ( github.com/subosito/gotenv v1.4.2 // indirect github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect github.com/vishvananda/netns v0.0.4 // indirect + go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect golang.org/x/crypto v0.9.0 // indirect golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect + golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda // indirect golang.org/x/mod v0.10.0 // indirect golang.org/x/text v0.9.0 // indirect golang.org/x/tools v0.9.1 // indirect diff --git a/go.sum b/go.sum index bdd9bef..409133a 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs= github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= +github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf h1:40DHYsri+d1bnroFDU2FQAeq68f3kAlOzlQ93kCf26Q= +github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -73,6 +75,7 @@ github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbS github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= @@ -81,8 +84,8 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= -github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -162,6 +165,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16 h1:+aAGyK41KRn8jbF2Q7PLL0Sxwg6dShGcQSeCC7nZQ8E= github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16/go.mod h1:IKrnDWs3/Mqq5n0lI+RxA2sB7MvN/vbMBP3ehXg65UI= +github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c h1:kbTQ8oGf+BVFvt/fM+ECI+NbZDCqoi0vtZTfB2p2hrI= +github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c/go.mod h1:k6+89xKz7BSMJ+DzIerBdtpEUeTlBMugO/hcVSzahog= github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk= github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8= @@ -211,9 +216,9 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= -github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI= -github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk= -github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q= +github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= +github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= @@ -227,14 +232,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwqqj7U= -github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc= -github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk= -github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI= -github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= -github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA= -github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo= +github.com/quic-go/qtls-go1-20 v0.3.2 h1:rRgN3WfnKbyik4dBV8A6girlJVxGand/d+jVKbQq5GI= +github.com/quic-go/qtls-go1-20 v0.3.2/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/quic-go/quic-go v0.38.0 h1:T45lASr5q/TrVwt+jrVccmqHhPL2XuSyoCLVCpfOSLc= +github.com/quic-go/quic-go v0.38.0/go.mod h1:MPCuRq7KBK2hNcfKj/1iD1BGuN3eAYMeNxp3T42LRUg= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -330,6 +331,8 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda h1:O+EUvnBNPwI4eLthn8W5K+cS8zQZfgTABPLNm6Bna34= +golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda/go.mod h1:aAjjkJNdrh3PMckS4B10TGS2nag27cbKR1y2BpUxsiY= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= diff --git a/internal/clientinfo/arp.go b/internal/clientinfo/arp.go index 8429b56..f99f783 100644 --- a/internal/clientinfo/arp.go +++ b/internal/clientinfo/arp.go @@ -33,6 +33,9 @@ func (a *arpDiscover) String() string { } func (a *arpDiscover) List() []string { + if a == nil { + return nil + } var ips []string a.ip.Range(func(key, value any) bool { ips = append(ips, value.(string)) diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 9235ca9..3e92fd1 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -73,6 +73,8 @@ type Table struct { arp *arpDiscover ptr *ptrDiscover mdns *mdns + hf *hostsFile + vni *virtualNetworkIface cfg *ctrld.Config quitCh chan struct{} selfIP string @@ -116,6 +118,7 @@ func (t *Table) Init() { } func (t *Table) init() { + // Custom client ID presents, use it as the only source. if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" { ctrld.ProxyLogger.Load().Debug().Msg("start self discovery") t.dhcp = &dhcp{selfIP: t.selfIP} @@ -125,6 +128,11 @@ func (t *Table) init() { t.hostnameResolvers = append(t.hostnameResolvers, t.dhcp) return } + + // Otherwise, process all possible sources in order, that means + // the first result of IP/MAC/Hostname lookup will be used. + // + // Merlin custom clients. if t.discoverDHCP() || t.discoverARP() { t.merlin = &merlinDiscover{} if err := t.merlin.refresh(); err != nil { @@ -134,6 +142,19 @@ func (t *Table) init() { t.refreshers = append(t.refreshers, t.merlin) } } + // Hosts file mapping. + if t.discoverHosts() { + t.hf = &hostsFile{} + ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery") + if err := t.hf.init(); err != nil { + ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init hosts file discover") + } else { + t.hostnameResolvers = append(t.hostnameResolvers, t.hf) + t.refreshers = append(t.refreshers, t.hf) + } + go t.hf.watchChanges() + } + // DHCP lease files. if t.discoverDHCP() { t.dhcp = &dhcp{selfIP: t.selfIP} ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery") @@ -146,6 +167,7 @@ func (t *Table) init() { } go t.dhcp.watchChanges() } + // ARP table. if t.discoverARP() { t.arp = &arpDiscover{} ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery") @@ -157,6 +179,7 @@ func (t *Table) init() { t.refreshers = append(t.refreshers, t.arp) } } + // PTR lookup. if t.discoverPTR() { t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") @@ -167,6 +190,7 @@ func (t *Table) init() { t.refreshers = append(t.refreshers, t.ptr) } } + // mdns. if t.discoverMDNS() { t.mdns = &mdns{} ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery") @@ -176,6 +200,11 @@ func (t *Table) init() { t.hostnameResolvers = append(t.hostnameResolvers, t.mdns) } } + // VPN clients. + if t.discoverDHCP() || t.discoverARP() { + t.vni = &virtualNetworkIface{} + t.hostnameResolvers = append(t.hostnameResolvers, t.vni) + } } func (t *Table) LookupIP(mac string) string { @@ -259,7 +288,7 @@ func (t *Table) ListClients() []*Client { _ = r.refresh() } ipMap := make(map[string]*Client) - il := []ipLister{t.dhcp, t.arp, t.ptr, t.mdns} + il := []ipLister{t.dhcp, t.arp, t.ptr, t.mdns, t.vni} for _, ir := range il { for _, ip := range ir.List() { c, ok := ipMap[ip] @@ -300,6 +329,15 @@ func (t *Table) ListClients() []*Client { return clients } +// StoreVPNClient stores client info for VPN clients. +func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) { + if ci == nil || t.vni == nil { + return + } + t.vni.mac.Store(ci.IP, ci.Mac) + t.vni.ip2name.Store(ci.IP, ci.Hostname) +} + func (t *Table) discoverDHCP() bool { if t.cfg.Service.DiscoverDHCP == nil { return true @@ -328,6 +366,13 @@ func (t *Table) discoverPTR() bool { return *t.cfg.Service.DiscoverPtr } +func (t *Table) discoverHosts() bool { + if t.cfg.Service.DiscoverHosts == nil { + return true + } + return *t.cfg.Service.DiscoverHosts +} + // normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file. func normalizeIP(in string) string { // dnsmasq may put ip with interface index in lease file, strip it here. diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index 27e2bf4..7c1b2cf 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -47,12 +47,25 @@ func (d *dhcp) watchChanges() { if d.watcher == nil { return } + if dir := router.LeaseFilesDir(); dir != "" { + if err := d.watcher.Add(dir); err != nil { + ctrld.ProxyLogger.Load().Err(err).Str("dir", dir).Msg("could not watch lease dir") + } + } for { select { case event, ok := <-d.watcher.Events: if !ok { return } + if event.Has(fsnotify.Create) { + if format, ok := clientInfoFiles[event.Name]; ok { + if err := d.addLeaseFile(event.Name, format); err != nil { + ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("could not add lease file") + } + } + continue + } if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) { format := clientInfoFiles[event.Name] if err := d.readLeaseFile(event.Name, format); err != nil && !os.IsNotExist(err) { @@ -106,6 +119,9 @@ func (d *dhcp) String() string { } func (d *dhcp) List() []string { + if d == nil { + return nil + } var ips []string d.ip.Range(func(key, value any) bool { ips = append(ips, value.(string)) diff --git a/internal/clientinfo/hostsfile.go b/internal/clientinfo/hostsfile.go new file mode 100644 index 0000000..baf05fb --- /dev/null +++ b/internal/clientinfo/hostsfile.go @@ -0,0 +1,120 @@ +package clientinfo + +import ( + "os" + "sync" + + "github.com/fsnotify/fsnotify" + "github.com/jaytaylor/go-hostsfile" + + "github.com/Control-D-Inc/ctrld" +) + +const ( + ipv4LocalhostName = "localhost" + ipv6LocalhostName = "ip6-localhost" + ipv6LoopbackName = "ip6-loopback" +) + +// hostsFile provides client discovery functionality using system hosts file. +type hostsFile struct { + watcher *fsnotify.Watcher + mu sync.Mutex + m map[string][]string +} + +// init performs initialization works, which is necessary before hostsFile can be fully operated. +func (hf *hostsFile) init() error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + hf.watcher = watcher + if err := hf.watcher.Add(hostsfile.HostsPath); err != nil { + return err + } + m, err := hostsfile.ParseHosts(hostsfile.ReadHostsFile()) + if err != nil { + return err + } + hf.mu.Lock() + hf.m = m + hf.mu.Unlock() + return nil +} + +// refresh reloads hosts file entries. +func (hf *hostsFile) refresh() error { + m, err := hostsfile.ParseHosts(hostsfile.ReadHostsFile()) + if err != nil { + return err + } + hf.mu.Lock() + hf.m = m + hf.mu.Unlock() + return nil +} + +// watchChanges watches and updates hosts file data if any changes happens. +func (hf *hostsFile) watchChanges() { + if hf.watcher == nil { + return + } + for { + select { + case event, ok := <-hf.watcher.Events: + if !ok { + return + } + if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) { + if err := hf.refresh(); err != nil && !os.IsNotExist(err) { + ctrld.ProxyLogger.Load().Err(err).Msg("hosts file changed but failed to update client info") + } + } + case err, ok := <-hf.watcher.Errors: + if !ok { + return + } + ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file") + } + } + +} + +// LookupHostnameByIP returns hostname for given IP from current hosts file entries. +func (hf *hostsFile) LookupHostnameByIP(ip string) string { + hf.mu.Lock() + defer hf.mu.Unlock() + if names := hf.m[ip]; len(names) > 0 { + isLoopback := ip == "127.0.0.1" || ip == "::1" + for _, hostname := range names { + name := normalizeHostname(hostname) + // Ignoring ipv4/ipv6 loopback entry. + if isLoopback && isLocalhostName(name) { + continue + } + return name + } + } + return "" +} + +// LookupHostnameByMac returns hostname for given Mac from current hosts file entries. +func (hf *hostsFile) LookupHostnameByMac(mac string) string { + return "" +} + +// String returns human-readable format of hostsFile. +func (hf *hostsFile) String() string { + return "hosts" +} + +// isLocalhostName reports whether the given hostname represents localhost. +func isLocalhostName(hostname string) bool { + switch hostname { + case ipv4LocalhostName, ipv6LocalhostName, ipv6LoopbackName: + return true + default: + return false + } +} diff --git a/internal/clientinfo/hostsfile_test.go b/internal/clientinfo/hostsfile_test.go new file mode 100644 index 0000000..f67fcef --- /dev/null +++ b/internal/clientinfo/hostsfile_test.go @@ -0,0 +1,33 @@ +package clientinfo + +import ( + "testing" +) + +func Test_hostsFile_LookupHostnameByIP(t *testing.T) { + tests := []struct { + name string + ip string + hostnames []string + expectedHostname string + }{ + {"ipv4 loopback", "127.0.0.1", []string{ipv4LocalhostName}, ""}, + {"ipv6 loopback", "::1", []string{ipv6LocalhostName, ipv6LoopbackName}, ""}, + {"non-localhost", "::1", []string{"foo"}, "foo"}, + {"multiple hostnames", "::1", []string{ipv4LocalhostName, "foo"}, "foo"}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + hf := &hostsFile{m: make(map[string][]string)} + hf.mu.Lock() + hf.m[tc.ip] = tc.hostnames + hf.mu.Unlock() + if got := hf.LookupHostnameByIP(tc.ip); got != tc.expectedHostname { + t.Errorf("unpexpected result, want: %q, got: %q", tc.expectedHostname, got) + } + }) + } +} diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index c9d97e5..5875b69 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -48,6 +48,9 @@ func (m *mdns) String() string { } func (m *mdns) List() []string { + if m == nil { + return nil + } var ips []string m.name.Range(func(key, value any) bool { ips = append(ips, key.(string)) diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 9c02fa1..6a9d99b 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -3,16 +3,19 @@ package clientinfo import ( "context" "sync" + "sync/atomic" "time" "github.com/miekg/dns" + "tailscale.com/logtail/backoff" "github.com/Control-D-Inc/ctrld" ) type ptrDiscover struct { - hostname sync.Map // ip => hostname - resolver ctrld.Resolver + hostname sync.Map // ip => hostname + resolver ctrld.Resolver + serverDown atomic.Bool } func (p *ptrDiscover) refresh() error { @@ -41,6 +44,9 @@ func (p *ptrDiscover) String() string { } func (p *ptrDiscover) List() []string { + if p == nil { + return nil + } var ips []string p.hostname.Range(func(key, value any) bool { ips = append(ips, key.(string)) @@ -57,18 +63,24 @@ func (p *ptrDiscover) lookupHostnameFromCache(ip string) string { } func (p *ptrDiscover) lookupHostname(ip string) string { + // If nameserver is down, do nothing. + if p.serverDown.Load() { + return "" + } ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() msg := new(dns.Msg) addr, err := dns.ReverseAddr(ip) if err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("invalid ip address") + ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("invalid ip address") return "" } msg.SetQuestion(addr, dns.TypePTR) ans, err := p.resolver.Resolve(ctx, msg) if err != nil { - ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not lookup IP") + ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") + p.serverDown.Store(true) + go p.checkServer() return "" } for _, rr := range ans.Answer { @@ -80,3 +92,25 @@ func (p *ptrDiscover) lookupHostname(ip string) string { } return "" } + +// checkServer monitors if the resolver can reach its nameserver. When the nameserver +// is reachable, set p.serverDown to false, so p.lookupHostname can continue working. +func (p *ptrDiscover) checkServer() { + bo := backoff.NewBackoff("ptrDiscover", func(format string, args ...any) {}, time.Minute*5) + m := new(dns.Msg) + m.SetQuestion(".", dns.TypeNS) + ping := func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err := p.resolver.Resolve(ctx, m) + return err + } + for { + if err := ping(); err != nil { + bo.BackOff(context.Background(), err) + continue + } + break + } + p.serverDown.Store(false) +} diff --git a/internal/clientinfo/virtual_iface.go b/internal/clientinfo/virtual_iface.go new file mode 100644 index 0000000..6cb018d --- /dev/null +++ b/internal/clientinfo/virtual_iface.go @@ -0,0 +1,43 @@ +package clientinfo + +import ( + "sync" +) + +// virtualNetworkIface is the manager for clients from virtual network interface. +type virtualNetworkIface struct { + ip2name sync.Map // ip => name + mac sync.Map // ip => mac +} + +// LookupHostnameByIP returns hostname of the given VPN client ip. +func (v *virtualNetworkIface) LookupHostnameByIP(ip string) string { + val, ok := v.ip2name.Load(ip) + if !ok { + return "" + } + return val.(string) +} + +// LookupHostnameByMac always returns empty string. +func (v *virtualNetworkIface) LookupHostnameByMac(mac string) string { + return "" +} + +// String returns the string representation of virtualNetworkIface struct. +func (v *virtualNetworkIface) String() string { + return "" +} + +// List lists all known VPN clients IP. +func (v *virtualNetworkIface) List() []string { + if v == nil { + return nil + } + var ips []string + v.mac.Range(func(key, value any) bool { + ips = append(ips, key.(string)) + return true + }) + return ips +} diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index a25f564..54ba8fd 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -17,6 +17,7 @@ server={{ .IP }}#{{ .Port }} {{- end}} {{- if .SendClientInfo}} add-mac +add-subnet=32,128 {{- end}} ` @@ -39,7 +40,10 @@ if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then pc_append "server={{ .IP }}#{{ .Port }}" "$config_file" {{- end}} {{- if .SendClientInfo}} + pc_delete "add-mac" "$config_file" + pc_delete "add-subnet" "$config_file" pc_append "add-mac" "$config_file" # add client mac + pc_append "add-subnet=32,128" "$config_file" # add client ip {{- end}} pc_delete "dnssec" "$config_file" # disable DNSSEC pc_delete "trust-anchor=" "$config_file" # disable DNSSEC diff --git a/internal/router/edgeos/edgeos.go b/internal/router/edgeos/edgeos.go index 014a594..f50f610 100644 --- a/internal/router/edgeos/edgeos.go +++ b/internal/router/edgeos/edgeos.go @@ -169,9 +169,16 @@ func ContentFilteringEnabled() bool { return err == nil && !st.IsDir() } +func LeaseFileDir() string { + if checkUSG() { + return "" + } + return "/run" +} + func checkUSG() bool { - out, _ := exec.Command("mca-cli-op", "info").Output() - return bytes.Contains(out, []byte("UniFi-Gateway-")) + out, _ := os.ReadFile("/etc/version") + return bytes.HasPrefix(out, []byte("UniFiSecurityGateway.")) } func restartDNSMasq() error { diff --git a/internal/router/router.go b/internal/router/router.go index ad3c641..b8a414b 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -173,20 +173,6 @@ func CanListenLocalhost() bool { } } -// ServiceDependencies returns list of dependencies that ctrld services needs on this router. -// See https://pkg.go.dev/github.com/kardianos/service#Config for list format. -func ServiceDependencies() []string { - if Name() == edgeos.Name { - // On EdeOS, ctrld needs to start after vyatta-dhcpd, so it can read leases file. - return []string{ - "Wants=vyatta-dhcpd.service", - "After=vyatta-dhcpd.service", - "Wants=dnsmasq.service", - } - } - return nil -} - // SelfInterfaces return list of *net.Interface that will be source of requests from router itself. func SelfInterfaces() []*net.Interface { switch Name() { @@ -197,6 +183,14 @@ func SelfInterfaces() []*net.Interface { } } +// LeaseFilesDir is the directory which contains lease files. +func LeaseFilesDir() string { + if Name() == edgeos.Name { + edgeos.LeaseFileDir() + } + return "" +} + func distroName() string { switch { case bytes.HasPrefix(unameO(), []byte("DD-WRT")): diff --git a/nameservers_unix.go b/nameservers_unix.go new file mode 100644 index 0000000..39cc971 --- /dev/null +++ b/nameservers_unix.go @@ -0,0 +1,9 @@ +//go:build unix + +package ctrld + +import "github.com/Control-D-Inc/ctrld/internal/resolvconffile" + +func nameserversFromResolvconf() []string { + return resolvconffile.NameServers("") +} diff --git a/nameservers_windows.go b/nameservers_windows.go index 5cd7811..ea9b347 100644 --- a/nameservers_windows.go +++ b/nameservers_windows.go @@ -58,3 +58,7 @@ func dnsFromAdapter() []string { } return ns } + +func nameserversFromResolvconf() []string { + return nil +} diff --git a/resolver.go b/resolver.go index d2586ec..969da86 100644 --- a/resolver.go +++ b/resolver.go @@ -27,13 +27,15 @@ const ( ) var bootstrapDNS = "76.76.2.0" -var or = &osResolver{nameservers: nameservers()} -func init() { - if len(or.nameservers) == 0 { - // Add bootstrap DNS in case we did not find any. - or.nameservers = []string{net.JoinHostPort(bootstrapDNS, "53")} - } +// or is the Resolver used for ResolverTypeOS. +var or = &osResolver{nameservers: defaultNameservers()} + +// defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver. +func defaultNameservers() []string { + ns := nameservers() + ns = append(ns, net.JoinHostPort(bootstrapDNS, "53")) + return ns } // Resolver is the interface that wraps the basic DNS operations. @@ -237,13 +239,25 @@ func NewBootstrapResolver(servers ...string) Resolver { return resolver } -// NewPrivateResolver returns an OS resolver, which includes only private DNS servers. +// NewPrivateResolver returns an OS resolver, which includes only private DNS servers, +// excluding nameservers from /etc/resolv.conf file. +// // This is useful for doing PTR lookup in LAN network. func NewPrivateResolver() Resolver { nss := nameservers() + resolveConfNss := nameserversFromResolvconf() n := 0 for _, ns := range nss { host, _, _ := net.SplitHostPort(ns) + // Ignore nameserver from resolve.conf file, because the nameserver can be either: + // + // - ctrld itself. + // - Direct listener that has ctrld as an upstream (e.g: dnsmasq). + // + // causing the query always succeed. + if sliceContains(resolveConfNss, host) { + continue + } ip := net.ParseIP(host) if ip != nil && ip.IsPrivate() && !ip.IsLoopback() { nss[n] = ns @@ -269,3 +283,20 @@ func newDialer(dnsAddress string) *net.Dialer { }, } } + +// TODO(cuonglm): use slices.Contains once upgrading to go1.21 +// sliceContains reports whether v is present in s. +func sliceContains[S ~[]E, E comparable](s S, v E) bool { + return sliceIndex(s, v) >= 0 +} + +// sliceIndex returns the index of the first occurrence of v in s, +// or -1 if not present. +func sliceIndex[S ~[]E, E comparable](s S, v E) int { + for i := range s { + if v == s[i] { + return i + } + } + return -1 +} diff --git a/scripts/build.sh b/scripts/build.sh index 4eee0ce..6c96f0f 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -59,7 +59,7 @@ compress() { build() { goos=$1 goarch=$2 - ldflags="-s -w -X github.com/Windscribe/ctrld/cmd/cli.version="${CI_COMMIT_TAG:-dev}" -X github.com/Windscribe/ctrld/cmd/cli.commit=$(git rev-parse HEAD)" + ldflags="-s -w -X github.com/Control-D-Inc/ctrld/cmd/cli.version="${CI_COMMIT_TAG:-dev}" -X github.com/Control-D-Inc/ctrld/cmd/cli.commit=$(git rev-parse HEAD)" case $3 in 5 | 6 | 7)