Compare commits

...

53 Commits

Author SHA1 Message Date
Cuong Manh Le
0084e9ef26 internal/clientinfo: silent staticcheck S1008
The code is written for readability purpose.
2023-12-13 14:53:29 +07:00
Cuong Manh Le
122600bff2 cmd/cli: remove redundant return statement 2023-12-13 14:53:29 +07:00
Cuong Manh Le
41846b6d4c all: add config to enable/disable answering WAN clients 2023-12-13 14:53:29 +07:00
Cuong Manh Le
dfbcb1489d cmd/cli: improving loop guard test
We see number of failed test in Github Action, mostly on MacOS or
Windows due to the fact that goroutines are scheduled to be run
consequently.

This commit improves the test, ensuring at least 2 goroutines were
started before increasing the counting.
2023-12-13 14:53:29 +07:00
Cuong Manh Le
684019c2e3 all: force re-bootstrapping with timeout error 2023-12-11 22:55:16 +07:00
Cuong Manh Le
e92619620d cmd/cli: doing router setup based on "--iface" flag
Solving downgrading issue from newer version to v1.3.1, and also easier
to explain the logic: either doing "magic stuff" or do nothing.
2023-12-11 22:55:16 +07:00
Cuong Manh Le
cebfd12d5c internal/clientinfo: ensure RFC1918 address is chosen over others 2023-12-07 00:04:17 +07:00
Cuong Manh Le
874ff01ab8 cmd/cli: ensure log time field is formated with ms 2023-12-06 22:31:35 +07:00
Alex Paguis
0bb8703f78 Update document for new client_id_preference param 2023-12-06 15:33:05 +07:00
Cuong Manh Le
0bb51aa71d cmd/cli: add loop guard for LAN/PTR queries 2023-12-06 15:33:05 +07:00
Cuong Manh Le
af2c1c87e0 cmd/cli: improve logging for new LAN/PTR flow 2023-12-06 15:33:05 +07:00
Cuong Manh Le
8939debbc0 cmd/cli: do not send test query to external upstreams 2023-12-06 15:33:05 +07:00
Cuong Manh Le
7591a0ccc6 all: add client id preference config param
So client can chose how client id is generated.
2023-12-06 15:33:05 +07:00
Cuong Manh Le
c3ff8182af all: ignoring local interfaces RFC1918 IP for private resolver
Otherwises, the discovery may make a looping with new PTR query flow.
2023-12-06 15:33:05 +07:00
Cuong Manh Le
5897c174d3 all: fix LAN hostname checking condition
The LAN hostname in question is FQDN, "." suffix must be trimmed before
checking.

While at it, also add tests for LAN/PTR query checking functions.
2023-12-06 15:33:05 +07:00
Cuong Manh Le
f9a3f4c045 Implement new flow for LAN and private PTR resolution
- Use client info table.
 - If no sufficient data, use gateway/os/defined local upstreams.
 - If no data is returned, use remote upstream
2023-11-30 18:28:51 +07:00
Cuong Manh Le
a2cb895cdc cmd/cli: watch changes to /etc/resolv.conf
On some routers, change to network may trigger re-rendering
/etc/resolv.conf file, causing requests from router itself stop using
ctrld.

Fixing this by watching changes to /etc/resolv.conf, then revert them.
2023-11-27 22:19:16 +07:00
Cuong Manh Le
2bebe93e47 internal/router: do not disable cache on EdgeOS
The dnsmasq cache-size setting on EdgeOS could be re-generated anytime
by vyatta router/dhcp components. This conflicts with setting generated
by ctrld, causing dnsmasq fails to start.

It's better to keep dnsmasq cache enabled on EdgeOS, we can turn it off
again once we find a reliable way to control cache-size setting.
2023-11-27 22:19:16 +07:00
Cuong Manh Le
28ec1869fc internal/router/merlin: hardening pre-run condition
The postconf script added by ctrld requires all of these conditions to
work correctly:

 - /proc, /tmp were mounted.
 - dnsmasq is running.

Currently, ctrld is only waiting for NTP ready, which may not ensure
both of those conditions are true. Explicitly checking those conditions
is a safer approach.
2023-11-27 22:19:16 +07:00
Cuong Manh Le
17f6d7a77b cmd/cli: notice writing default config in local mode 2023-11-27 22:19:16 +07:00
Cuong Manh Le
9e6e647ff8 Use discover_ptr_endpoints for PTR resolver 2023-11-27 22:19:16 +07:00
Cuong Manh Le
a2116e5eb5 cmd/cli: do not substitute MAC if empty
Using IPv4 as hostname is enough to distinguish clients.
2023-11-27 22:19:16 +07:00
Cuong Manh Le
564c9ef712 cmd/cli: use IP as hostname for ipv4 clients only
For Android devices, when it joins the network, it uses ctrld to resolve
its private DNS once and never reaches ctrld again. For each time, it uses
a different IPv6 address, which causes hundreds/thousands different client
IDs created for the same device, which is pointless.
2023-11-27 22:19:16 +07:00
Cuong Manh Le
856abb71b7 cmd/cli: only notice reading config with "ctrld start"
While at it, also updating the documentation of related functions.
2023-11-27 22:19:16 +07:00
Cuong Manh Le
0a30fdea69 Add listener policy to default generated config
So technical user can figure thing out based on self-documented
commands, without referring to actual documentation.
2023-11-16 20:59:31 +07:00
Cuong Manh Le
4f125cf107 cmd/cli: notice users where config file is written/read 2023-11-16 20:59:12 +07:00
Cuong Manh Le
494d8be777 cmd/cli: skip router setup with "ctrld service start"
Either do magic stuff and make things work automatically (normal users),
or don't do any of it and just run ctrld as a service (power users).
2023-11-16 20:58:41 +07:00
Cuong Manh Le
cd9c750884 cmd/cli: do not run pre run on reload 2023-11-16 20:58:26 +07:00
Cuong Manh Le
91d319804b cmd/cli: only use failover rcodes if defined 2023-11-16 20:58:10 +07:00
Cuong Manh Le
180eae60f2 all: allowing config defined discover ptr endpoints
The default gateway is usually the DNS server in normal home network
setup for most users. However, there's case that it is not, causing
discover ptr failed.

This commit add discover_ptr_endpoints config parameter, so users can
define what DNS nameservers will be used.
2023-11-16 20:57:52 +07:00
Cuong Manh Le
d01f5c2777 cmd/cli: do not stop listener when reloading
We could not do a reload if the listener config changes, so do not turn
them off to try updating new listener config.
2023-11-16 20:56:57 +07:00
Cuong Manh Le
294a90a807 internal/router/openwrt: ensure dnsmasq cache is disabled
Users may have their own dnsmasq cache set via LUCI web, thus ctrld
needs to delete the cache-size setting to ensure its dnsmasq config
works.
2023-11-16 20:56:42 +07:00
Ginder Singh
c3b4ae9c79 Older android missing certificate 2023-11-16 20:56:24 +07:00
Cuong Manh Le
09188bedf7 cmd/cli: fix wrong generated config for nextdns resolver
Generating nextdns config must happen after stopping current ctrld
process. Otherwise, config processing may pick wrong IP+Port.

While at it, also making logging better when updating listener config:

 - Change warn to info, prevent confusing that "something is wrong".
 - Do not emit info when generating working default config, which may
   cause duplicated messages printed.
2023-11-16 20:55:39 +07:00
Cuong Manh Le
4614b98e94 internal/clientinfo: emit error once if ptr discovery failed
So it won't spam ctrld log unnecessary, prevent confusion. While at it,
also change the log level from Warn to Info, since this error is not
actionable by the user.
2023-11-09 00:30:56 +07:00
Cuong Manh Le
990bc620f7 cmd/cli: strip EDNS0_SUBNET for RFC 1918 and loopback address
Since passing them to upstream is pointless, these cannot be used by
anything on the WAN.
2023-11-09 00:23:38 +07:00
Cuong Manh Le
efb5a92571 Using time interval for probing ipv6
A backoff with small max time will flood requests to Control D server,
causing false positive for abuse mitiation system. While a big max time
will cause ctrld not realize network change as fast as possible.

While at it, also sync DoH3 code with DoH code, ensuring no others place
can trigger requests flooding for ipv6 probing.
2023-11-08 23:51:18 +07:00
Cuong Manh Le
8e0a96a44c Fix panic dues to quic-go changes
quic.DialEarly requires separate UDP connection for each
quic.EarlyConnection instead of re-using the same one.
2023-11-08 23:51:18 +07:00
Cuong Manh Le
43ff2f648c internal/router/dnsmasq: disable cache
So multiple upstreams config could work properly.
2023-11-08 23:51:18 +07:00
Cuong Manh Le
4816a09e3a all: use private resolver for private IP address
These queries could not be resolved by Control D upstreams, so it's
useless and less performance to send them to servers.
2023-11-08 23:51:18 +07:00
Cuong Manh Le
3fea92c8b1 Bump golang.org/x/net to v0.17.0 2023-11-08 23:51:08 +07:00
Cuong Manh Le
63f959c951 all: spoof loopback ranges in client info
Sending them are useless, so using RFC1918 address instead.
2023-11-06 20:01:57 +07:00
Cuong Manh Le
44ba6aadd9 internal/clientinfo: do not complain about net.ErrClosed
The probeLoop may have closed the connection before readLoop return, and
we don't care about this error. So prevent it from annoying the log.
2023-11-06 20:01:42 +07:00
Cuong Manh Le
d88cf52b4e cmd/cli: always rebootstrap when check upstream
Otherwise, network changes may not be seen on some platforms, causing
ctrld failed to recover and failing all requests.

While at it, also doing the check DNS in separate goroutine, prevent it
from blocking ctrld from notifying others that it "started". The issue
was seen when ctrld is configured as direct listener, requests are
flooded before ctrld started, causing the healtch process failed.
2023-11-06 20:01:25 +07:00
Cuong Manh Le
58a00ea24a all: implement reload command
This commit adds reload command to ctrld for re-fetch new config from
ContorlD API or re-read the current config on disk.
2023-11-06 20:01:03 +07:00
Cuong Manh Le
712b23a4bb cmd/cli: initialize upstream proto for mobile 2023-11-06 20:00:28 +07:00
Cuong Manh Le
baf836557c cmd/cli: fix wrong checking condition in removeProvTokenFromArgs
The provision token is only used once, then do not have any effect after
Control D uid is fetched. So making it appears in "ctrld run" command is
useless.
2023-11-06 20:00:10 +07:00
Cuong Manh Le
904b23eeac cmd/cli: add --proto flag to set upstream type in cd mode 2023-11-06 19:59:52 +07:00
Cuong Manh Le
6aafe445f5 cmd/cli: add nextdns mode
Adding --nextdns flag to "ctrld start" command for generating ctrld
config with nextdns resolver id, then use nextdns as an upstream.
2023-11-06 19:59:31 +07:00
Ginder Singh
ebd516855b added safe return if error happens during resolver fetch. 2023-11-06 19:58:53 +07:00
Cuong Manh Le
df4e04719e cmd/cli: relax service dependency on systemd-networkd-wait-online
ctrld wants systemd-networkd-wait-online starts before starting itself,
but ctrld should not be blocked waiting for it started.
2023-11-06 19:58:32 +07:00
Cuong Manh Le
2440d922c6 all: add MAC address base policy
While at it, also update the config doc to clarify the order of matching
preference, and the matter of rules order within each policy.
2023-11-06 19:57:50 +07:00
Yegor S
f1b8d1c4ad Merge pull request #93 from Control-D-Inc/release-branch-v1.3.1
Release branch v1.3.1
2023-10-10 22:29:43 -04:00
42 changed files with 2123 additions and 402 deletions

View File

@@ -5,10 +5,11 @@ type ClientInfoCtxKey struct{}
// ClientInfo represents ctrld's clients information.
type ClientInfo struct {
Mac string
IP string
Hostname string
Self bool
Mac string
IP string
Hostname string
Self bool
ClientIDPref string
}
// LeaseFileFormat specifies the format of DHCP lease file.

View File

@@ -144,6 +144,7 @@ func initCLI() {
_ = runCmd.Flags().MarkHidden("homedir")
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
_ = runCmd.Flags().MarkHidden("iface")
runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
rootCmd.AddCommand(runCmd)
@@ -158,6 +159,7 @@ func initCLI() {
Run: func(cmd *cobra.Command, args []string) {
checkStrFlagEmpty(cmd, cdUidFlagName)
checkStrFlagEmpty(cmd, cdOrgFlagName)
validateCdAndNextDNSFlags()
sc := &service.Config{}
*sc = *svcConfig
osArgs := os.Args[2:]
@@ -176,6 +178,9 @@ func initCLI() {
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
sc.Arguments = append(sc.Arguments, "--cd="+cdUID)
}
if cdUID != "" {
validateCdUpstreamProtocol()
}
p := &prog{
router: router.New(&cfg, cdUID != ""),
@@ -223,7 +228,7 @@ func initCLI() {
}()
}
tryReadingConfig(writeDefaultConfig)
tryReadingConfigWithNotice(writeDefaultConfig, true)
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
@@ -231,6 +236,10 @@ func initCLI() {
initLogging()
if nextdns != "" {
removeNextDNSFromArgs(sc)
}
// Explicitly passing config, so on system where home directory could not be obtained,
// or sub-process env is different with the parent, we still behave correctly and use
// the expected config file.
@@ -244,16 +253,20 @@ func initCLI() {
return
}
if router.Name() != "" {
if router.Name() != "" && iface != "" {
mainLog.Load().Debug().Msg("cleaning up router before installing")
_ = p.router.Cleanup()
}
tasks := []task{
{s.Stop, false},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true},
{s.Uninstall, false},
{s.Install, false},
{s.Start, true},
// Note that startCmd do not actually write ControlD config, but the config file was
// generated after s.Start, so we notice users here for consistent with nextdns mode.
{noticeWritingControlDConfig, false},
}
mainLog.Load().Notice().Msg("Starting service")
if doTasks(tasks) {
@@ -281,7 +294,7 @@ func initCLI() {
}
},
}
// Keep these flags in sync with runCmd above, except for "-d".
// Keep these flags in sync with runCmd above, except for "-d"/"--nextdns".
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
@@ -295,6 +308,8 @@ func initCLI() {
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = startCmd.Flags().MarkHidden("dev")
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
routerCmd := &cobra.Command{
Use: "setup",
@@ -367,6 +382,10 @@ func initCLI() {
mainLog.Load().Error().Msg(err.Error())
return
}
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("service not installed")
return
}
initLogging()
tasks := []task{
@@ -388,6 +407,50 @@ func initCLI() {
},
}
reloadCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "reload",
Short: "Reload the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
dir, err := userHomeDir()
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir")
}
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
resp, err := cc.post(reloadPath, nil)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld")
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
mainLog.Load().Notice().Msg("Service reloaded")
case http.StatusCreated:
s, err := newService(&prog{}, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
}
mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.")
mainLog.Load().Warn().Msg("Restarting service")
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("Service not installed")
return
}
restartCmd.Run(cmd, args)
default:
buf, err := io.ReadAll(resp.Body)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("could not read response from control server")
}
mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf))
}
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show status of the ctrld service",
@@ -503,9 +566,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
Short: "Manage ctrld service",
Args: cobra.OnlyValidArgs,
ValidArgs: []string{
statusCmd.Use,
startCmd.Use,
stopCmd.Use,
restartCmd.Use,
reloadCmd.Use,
statusCmd.Use,
uninstallCmd.Use,
interfacesCmd.Use,
@@ -514,6 +578,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
serviceCmd.AddCommand(startCmd)
serviceCmd.AddCommand(stopCmd)
serviceCmd.AddCommand(restartCmd)
serviceCmd.AddCommand(reloadCmd)
serviceCmd.AddCommand(statusCmd)
serviceCmd.AddCommand(uninstallCmd)
serviceCmd.AddCommand(interfacesCmd)
@@ -568,6 +633,19 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
}
rootCmd.AddCommand(restartCmdAlias)
reloadCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "reload",
Short: "Reload the ctrld service",
Run: func(cmd *cobra.Command, args []string) {
reloadCmd.Run(cmd, args)
},
}
rootCmd.AddCommand(reloadCmdAlias)
statusCmdAlias := &cobra.Command{
Use: "status",
Short: "Show status of the ctrld service",
@@ -688,6 +766,7 @@ func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struc
homedir = appConfig.HomeDir
verbose = appConfig.Verbose
cdUID = appConfig.CdUID
cdUpstreamProto = ctrld.ResolverTypeDOH
logPath = appConfig.LogPath
run(appCallback, stopCh)
}
@@ -699,10 +778,12 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
}
waitCh := make(chan struct{})
p := &prog{
waitCh: waitCh,
stopCh: stopCh,
cfg: &cfg,
appCallback: appCallback,
waitCh: waitCh,
stopCh: stopCh,
reloadCh: make(chan struct{}),
reloadDoneCh: make(chan struct{}),
cfg: &cfg,
appCallback: appCallback,
}
if homedir == "" {
if dir, err := userHomeDir(); err == nil {
@@ -740,9 +821,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
readBase64Config(configBase64)
processNoConfigFlags(noConfigStart)
p.mu.Lock()
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
}
p.mu.Unlock()
processLogAndCacheFlags()
@@ -777,14 +860,47 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
cdUID = uid
}
if cdUID != "" {
err := processCDFlags()
if err != nil {
appCallback.Exit(err.Error())
return
validateCdUpstreamProtocol()
if err := processCDFlags(&cfg); err != nil {
if isMobile() {
appCallback.Exit(err.Error())
return
}
uninstallIfInvalidCdUID := func() {
cdLogger := mainLog.Load().With().Str("mode", "cd").Logger()
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
s, err := newService(&prog{}, svcConfig)
if err != nil {
cdLogger.Warn().Err(err).Msg("failed to create new service")
return
}
if netIface, _ := netInterface(iface); netIface != nil {
if err := restoreNetworkManager(); err != nil {
cdLogger.Error().Err(err).Msg("could not restore NetworkManager")
return
}
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
if err := resetDNS(netIface); err != nil {
cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS")
} else {
cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
}
}
tasks := []task{{s.Uninstall, true}}
if doTasks(tasks) {
cdLogger.Info().Msg("uninstalled service")
}
cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config")
return
}
}
uninstallIfInvalidCdUID()
}
}
updated := updateListenerConfig()
updated := updateListenerConfig(&cfg)
if cdUID != "" {
processLogAndCacheFlags()
@@ -812,7 +928,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
initLoggingWithBackup(false)
}
validateConfig(&cfg)
if err := validateConfig(&cfg); err != nil {
os.Exit(1)
}
initCache()
if daemon {
@@ -859,19 +977,21 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
if cp := router.CertPool(); cp != nil {
rootCertPool = cp
}
p.onStarted = append(p.onStarted, func() {
mainLog.Load().Debug().Msg("router setup on start")
if err := p.router.Setup(); err != nil {
mainLog.Load().Error().Err(err).Msg("could not configure router")
}
})
p.onStopped = append(p.onStopped, func() {
mainLog.Load().Debug().Msg("router cleanup on stop")
if err := p.router.Cleanup(); err != nil {
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
}
p.resetDNS()
})
if iface != "" {
p.onStarted = append(p.onStarted, func() {
mainLog.Load().Debug().Msg("router setup on start")
if err := p.router.Setup(); err != nil {
mainLog.Load().Error().Err(err).Msg("could not configure router")
}
})
p.onStopped = append(p.onStopped, func() {
mainLog.Load().Debug().Msg("router cleanup on stop")
if err := p.router.Cleanup(); err != nil {
mainLog.Load().Error().Err(err).Msg("could not cleanup router")
}
p.resetDNS()
})
}
}
close(waitCh)
@@ -907,10 +1027,17 @@ func writeConfigFile() error {
return nil
}
func readConfigFile(writeDefaultConfig bool) bool {
// readConfigFile reads in config file.
//
// - It writes default config file if config file not found if writeDefaultConfig is true.
// - It emits notice message to user if notice is true.
func readConfigFile(writeDefaultConfig, notice bool) bool {
// If err == nil, there's a config supplied via `--config`, no default config written.
err := v.ReadInConfig()
if err == nil {
if notice {
mainLog.Load().Notice().Msg("Reading config: " + v.ConfigFileUsed())
}
mainLog.Load().Info().Msg("loading config file from: " + v.ConfigFileUsed())
defaultConfigFile = v.ConfigFileUsed()
return true
@@ -925,7 +1052,8 @@ func readConfigFile(writeDefaultConfig bool) bool {
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
}
_ = updateListenerConfig()
nop := zerolog.Nop()
_, _ = tryUpdateListenerConfig(&cfg, &nop, true)
if err := writeConfigFile(); err != nil {
mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err)
} else {
@@ -933,6 +1061,9 @@ func readConfigFile(writeDefaultConfig bool) bool {
if err != nil {
mainLog.Load().Fatal().Msgf("failed to get default config file path: %v", err)
}
if cdUID == "" && nextdns == "" {
mainLog.Load().Notice().Msg("Generating controld default config: " + fp)
}
mainLog.Load().Info().Msg("writing default config file to: " + fp)
}
return false
@@ -1015,7 +1146,7 @@ func processNoConfigFlags(noConfigStart bool) {
v.Set("upstream", upstream)
}
func processCDFlags() error {
func processCDFlags(cfg *ctrld.Config) error {
logger := mainLog.Load().With().Str("mode", "cd").Logger()
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second)
@@ -1031,44 +1162,17 @@ func processCDFlags() error {
}
break
}
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
s, err := newService(&prog{}, svcConfig)
if err != nil {
logger.Warn().Err(err).Msg("failed to create new service")
return nil
}
if netIface, _ := netInterface(iface); netIface != nil {
if err := restoreNetworkManager(); err != nil {
logger.Error().Err(err).Msg("could not restore NetworkManager")
return nil
}
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
if err := resetDNS(netIface); err != nil {
logger.Warn().Err(err).Msg("something went wrong while restoring DNS")
} else {
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
}
}
tasks := []task{{s.Uninstall, true}}
if doTasks(tasks) {
logger.Info().Msg("uninstalled service")
}
event := logger.Fatal()
if isMobile() {
event = logger.Warn()
}
event.Err(uer).Msg("failed to fetch resolver config")
return uer
}
if err != nil {
if isMobile() {
return err
}
logger.Warn().Err(err).Msg("could not fetch resolver config")
return nil
return err
}
logger.Info().Msg("generating ctrld config from Control-D configuration")
cfg = ctrld.Config{}
*cfg = ctrld.Config{}
// Fetch config, unmarshal to cfg.
if resolverConfig.Ctrld.CustomConfig != "" {
logger.Info().Msg("using defined custom config of Control-D resolver")
@@ -1085,7 +1189,7 @@ func processCDFlags() error {
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
Endpoint: resolverConfig.DOH,
Type: ctrld.ResolverTypeDOH,
Type: cdUpstreamProto,
Timeout: 5000,
}
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
@@ -1213,6 +1317,11 @@ func selfCheckStatus(s service.Service) service.Status {
return service.StatusUnknown
}
// Not a ctrld upstream, return status as-is.
if cfg.FirstUpstream().VerifyDomain() == "" {
return status
}
mainLog.Load().Debug().Msg("ctrld listener is ready")
mainLog.Load().Debug().Msg("performing self-check")
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
@@ -1338,21 +1447,35 @@ func userHomeDir() (string, error) {
return dir, nil
}
// tryReadingConfig is like tryReadingConfigWithNotice, with notice set to false.
func tryReadingConfig(writeDefaultConfig bool) {
tryReadingConfigWithNotice(writeDefaultConfig, false)
}
// tryReadingConfigWithNotice tries reading in config files, either specified by user or from default
// locations. If notice is true, emitting a notice message to user which config file was read.
func tryReadingConfigWithNotice(writeDefaultConfig, notice bool) {
// --config is specified.
if configPath != "" {
v.SetConfigFile(configPath)
readConfigFile(false)
readConfigFile(false, notice)
return
}
// no config start or base64 config mode.
if !writeDefaultConfig {
return
}
readConfig(writeDefaultConfig)
readConfigWithNotice(writeDefaultConfig, notice)
}
// readConfig calls readConfigWithNotice with notice set to false.
func readConfig(writeDefaultConfig bool) {
readConfigWithNotice(writeDefaultConfig, false)
}
// readConfigWithNotice calls readConfigFile with config file set to ctrld.toml
// or config.toml for compatible with earlier versions of ctrld.
func readConfigWithNotice(writeDefaultConfig, notice bool) {
configs := []struct {
name string
written bool
@@ -1369,7 +1492,7 @@ func readConfig(writeDefaultConfig bool) {
for _, config := range configs {
ctrld.SetConfigNameWithPath(v, config.name, dir)
v.SetConfigFile(configPath)
if readConfigFile(config.written) {
if readConfigFile(config.written, notice) {
break
}
}
@@ -1401,18 +1524,17 @@ func uninstall(p *prog, s service.Service) {
}
}
func validateConfig(cfg *ctrld.Config) {
err := ctrld.ValidateConfig(validator.New(), cfg)
if err == nil {
return
}
var ve validator.ValidationErrors
if errors.As(err, &ve) {
for _, fe := range ve {
mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe))
func validateConfig(cfg *ctrld.Config) error {
if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil {
var ve validator.ValidationErrors
if errors.As(err, &ve) {
for _, fe := range ve {
mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe))
}
}
return err
}
os.Exit(1)
return nil
}
// NOTE: Add more case here once new validation tag is used in ctrld.Config struct.
@@ -1483,9 +1605,19 @@ func mobileListenerPort() int {
// updateListenerConfig updates the config for listeners if not defined,
// or defined but invalid to be used, e.g: using loopback address other
// than 127.0.0.1 with systemd-resolved.
func updateListenerConfig() (updated bool) {
func updateListenerConfig(cfg *ctrld.Config) bool {
updated, _ := tryUpdateListenerConfig(cfg, nil, true)
return updated
}
// tryUpdateListenerConfig tries updating listener config with a working one.
// If fatal is true, and there's listen address conflicted, the function do
// fatal error.
func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fatal bool) (updated, ok bool) {
ok = true
lcc := make(map[string]*listenerConfigCheck)
cdMode := cdUID != ""
nextdnsMode := nextdns != ""
for n, listener := range cfg.Listener {
lcc[n] = &listenerConfigCheck{}
if listener.IP == "" {
@@ -1497,12 +1629,17 @@ func updateListenerConfig() (updated bool) {
lcc[n].Port = true
}
// In cd mode, we always try to pick an ip:port pair to work.
if cdMode {
// Same if nextdns resolver is used.
if cdMode || nextdnsMode {
lcc[n].IP = true
lcc[n].Port = true
}
updated = updated || lcc[n].IP || lcc[n].Port
}
il := mainLog.Load()
if infoLogger != nil {
il = infoLogger
}
if isMobile() {
// On Mobile, only use first listener, ignore others.
firstLn := cfg.FirstListener()
@@ -1594,7 +1731,11 @@ func updateListenerConfig() (updated bool) {
break
}
if !check.IP && !check.Port {
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
if fatal {
logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err)
}
ok = false
break
}
if tryAllPort53 {
tryAllPort53 = false
@@ -1605,7 +1746,7 @@ func updateListenerConfig() (updated bool) {
listener.Port = 53
}
if check.IP {
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)))
logMsg(il.Info(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)))
}
continue
}
@@ -1618,7 +1759,7 @@ func updateListenerConfig() (updated bool) {
listener.Port = 53
}
if check.IP {
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)))
logMsg(il.Info(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)))
}
continue
}
@@ -1630,7 +1771,7 @@ func updateListenerConfig() (updated bool) {
if check.Port {
listener.Port = 5354
}
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying current ip with port 5354", addr)
logMsg(il.Info(), n, "could not listen on address: %s, trying current ip with port 5354", addr)
continue
}
if tryPort5354 {
@@ -1641,7 +1782,7 @@ func updateListenerConfig() (updated bool) {
if check.Port {
listener.Port = 5354
}
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr)
logMsg(il.Info(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr)
continue
}
if check.IP && !isZeroIP { // for "0.0.0.0" or "::", we only need to try new port.
@@ -1655,12 +1796,19 @@ func updateListenerConfig() (updated bool) {
listener.Port = oldPort
}
if listener.IP == oldIP && listener.Port == oldPort {
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
if fatal {
logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err)
}
ok = false
break
}
logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, pick a random ip+port", addr)
logMsg(il.Info(), n, "could not listen on address: %s, pick a random ip+port", addr)
attempts++
}
}
if !ok {
return
}
// Specific case for systemd-resolved.
if useSystemdResolved {
@@ -1670,7 +1818,7 @@ func updateListenerConfig() (updated bool) {
// ip address, other than "127.0.0.1", so trying to listen on default route interface
// address instead.
if ip := net.ParseIP(listener.IP); ip != nil && ip.IsLoopback() && ip.String() != "127.0.0.1" {
logMsg(mainLog.Load().Warn(), n, "using loopback interface do not work with systemd-resolved")
logMsg(il.Info(), n, "using loopback interface do not work with systemd-resolved")
found := false
if netIface, _ := net.InterfaceByName(defaultIfaceName()); netIface != nil {
addrs, _ := netIface.Addrs()
@@ -1680,7 +1828,7 @@ func updateListenerConfig() (updated bool) {
if err := tryListen(addr); err == nil {
found = true
listener.IP = netIP.IP.String()
logMsg(mainLog.Load().Warn(), n, "use %s as listener address", listener.IP)
logMsg(il.Info(), n, "use %s as listener address", listener.IP)
break
}
}
@@ -1742,12 +1890,12 @@ func removeProvTokenFromArgs(sc *service.Config) {
continue
}
// For "--cd-org XXX", skip it and mark next arg skipped.
if x == cdOrgFlagName {
if x == "--"+cdOrgFlagName {
skip = true
continue
}
// For "--cd-org=XXX", just skip it.
if strings.HasPrefix(x, cdOrgFlagName+"=") {
if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") {
continue
}
a = append(a, x)
@@ -1795,6 +1943,64 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) {
return
}
if fl.Value.String() == "" {
mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name)
mainLog.Load().Fatal().Msgf(`flag "--%s" value must be non-empty`, fl.Name)
}
}
func validateCdUpstreamProtocol() {
if cdUID == "" {
return
}
switch cdUpstreamProto {
case ctrld.ResolverTypeDOH, ctrld.ResolverTypeDOH3:
default:
mainLog.Load().Fatal().Msg(`flag "--protocol" must be "doh" or "doh3"`)
}
}
func validateCdAndNextDNSFlags() {
if (cdUID != "" || cdOrg != "") && nextdns != "" {
mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName)
}
}
// removeNextDNSFromArgs removes the --nextdns from command line arguments.
func removeNextDNSFromArgs(sc *service.Config) {
a := sc.Arguments[:0]
skip := false
for _, x := range sc.Arguments {
if skip {
skip = false
continue
}
// For "--nextdns XXX", skip it and mark next arg skipped.
if x == "--"+nextdnsFlagName {
skip = true
continue
}
// For "--nextdns=XXX", just skip it.
if strings.HasPrefix(x, "--"+nextdnsFlagName+"=") {
continue
}
a = append(a, x)
}
sc.Arguments = a
}
// doGenerateNextDNSConfig generates a working config with nextdns resolver.
func doGenerateNextDNSConfig(uid string) error {
if uid == "" {
return nil
}
mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile)
generateNextDNSConfig(uid)
updateListenerConfig(&cfg)
return writeConfigFile()
}
func noticeWritingControlDConfig() error {
if cdUID != "" {
mainLog.Load().Notice().Msgf("Generating controld config: %s", defaultConfigFile)
}
return nil
}

View File

@@ -6,14 +6,18 @@ import (
"net"
"net/http"
"os"
"reflect"
"sort"
"time"
"github.com/Control-D-Inc/ctrld"
)
const (
contentTypeJson = "application/json"
listClientsPath = "/clients"
startedPath = "/started"
reloadPath = "/reload"
)
type controlServer struct {
@@ -75,6 +79,52 @@ func (p *prog) registerControlServerHandler() {
w.WriteHeader(http.StatusRequestTimeout)
}
}))
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
listeners := make(map[string]*ctrld.ListenerConfig)
p.mu.Lock()
for k, v := range p.cfg.Listener {
listeners[k] = &ctrld.ListenerConfig{
IP: v.IP,
Port: v.Port,
}
}
oldSvc := p.cfg.Service
p.mu.Unlock()
if err := p.sendReloadSignal(); err != nil {
mainLog.Load().Err(err).Msg("could not send reload signal")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
select {
case <-p.reloadDoneCh:
case <-time.After(5 * time.Second):
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
return
}
p.mu.Lock()
defer p.mu.Unlock()
// Checking for cases that we could not do a reload.
// 1. Listener config ip or port changes.
for k, v := range p.cfg.Listener {
l := listeners[k]
if l == nil || l.IP != v.IP || l.Port != v.Port {
w.WriteHeader(http.StatusCreated)
return
}
}
// 2. Service config changes.
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
w.WriteHeader(http.StatusCreated)
return
}
// Otherwise, reload is done.
w.WriteHeader(http.StatusOK)
}))
}
func jsonResponse(next http.Handler) http.Handler {

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
@@ -17,6 +18,7 @@ import (
"golang.org/x/sync/errgroup"
"tailscale.com/net/interfaces"
"tailscale.com/net/netaddr"
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
@@ -25,6 +27,7 @@ import (
const (
staleTTL = 60 * time.Second
localTTL = 3600 * time.Second
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
@@ -37,6 +40,29 @@ var osUpstreamConfig = &ctrld.UpstreamConfig{
Timeout: 2000,
}
var privateUpstreamConfig = &ctrld.UpstreamConfig{
Name: "Private resolver",
Type: ctrld.ResolverTypePrivate,
Timeout: 2000,
}
// proxyRequest contains data for proxying a DNS query to upstream.
type proxyRequest struct {
msg *dns.Msg
ci *ctrld.ClientInfo
failoverRcodes []int
ufr *upstreamForResult
}
// upstreamForResult represents the result of processing rules for a request.
type upstreamForResult struct {
upstreams []string
matchedPolicy string
matchedNetwork string
matchedRule string
matched bool
}
func (p *prog) serveDNS(listenerNum string) error {
listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated
@@ -44,36 +70,58 @@ func (p *prog) serveDNS(listenerNum string) error {
mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
return allocErr
}
var failoverRcodes []int
if listenerConfig.Policy != nil {
failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers
}
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
p.sema.acquire()
defer p.sema.release()
if len(m.Question) == 0 {
answer := new(dns.Msg)
answer.SetRcode(m, dns.RcodeFormatError)
_ = w.WriteMsg(answer)
return
}
reqId := requestID()
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
answer := new(dns.Msg)
answer.SetRcode(m, dns.RcodeRefused)
_ = w.WriteMsg(answer)
return
}
go p.detectLoop(m)
q := m.Question[0]
domain := canonicalName(q.Name)
reqId := requestID()
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
ci := p.getClientInfo(remoteIP, m)
ci.ClientIDPref = p.cfg.Service.ClientIDPref
stripClientSubnet(m)
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci)
fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String())
t := time.Now()
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
ctrld.Log(ctx, mainLog.Load().Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, domain)
res := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
var answer *dns.Msg
if !matched && listenerConfig.Restricted {
if !res.matched && listenerConfig.Restricted {
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String())
answer = new(dns.Msg)
answer.SetRcode(m, dns.RcodeRefused)
} else {
answer = p.proxy(ctx, upstreams, failoverRcodes, m, ci)
var failoverRcode []int
if listenerConfig.Policy != nil {
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
}
answer = p.proxy(ctx, &proxyRequest{
msg: m,
ci: ci,
failoverRcodes: failoverRcode,
ufr: res,
})
rtt := time.Since(t)
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
}
if err := w.WriteMsg(answer); err != nil {
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveUDP: failed to send DNS response to client")
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client")
}
})
@@ -99,7 +147,7 @@ func (p *prog) serveDNS(listenerNum string) error {
// addresses of the machine. So ctrld could receive queries from LAN clients.
if needRFC1918Listeners(listenerConfig) {
g.Go(func() error {
for _, addr := range rfc1918Addresses() {
for _, addr := range ctrld.Rfc1918Addresses() {
func() {
listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port))
s, errCh := runDNSServer(listenAddr, proto, handler)
@@ -146,27 +194,24 @@ func (p *prog) serveDNS(listenerNum string) error {
// Though domain policy has higher priority than network policy, it is still
// processed later, because policy logging want to know whether a network rule
// is disregarded in favor of the domain level rule.
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) {
upstreams := []string{upstreamPrefix + defaultUpstreamNum}
matchedPolicy := "no policy"
matchedNetwork := "no network"
matchedRule := "no rule"
matched := false
res = &upstreamForResult{}
defer func() {
if !matched && lc.Restricted {
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", addr.String())
return
}
if matched {
ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams)
} else {
ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams)
}
res.upstreams = upstreams
res.matched = matched
res.matchedPolicy = matchedPolicy
res.matchedNetwork = matchedNetwork
res.matchedRule = matchedRule
}()
if lc.Policy == nil {
return upstreams, false
return
}
do := func(policyUpstreams []string) {
@@ -202,6 +247,19 @@ networkRules:
}
}
macRules:
for _, rule := range lc.Policy.Macs {
for source, targets := range rule {
if source != "" && strings.EqualFold(source, srcMac) {
matchedPolicy = lc.Policy.Name
matchedNetwork = source
networkTargets = targets
matched = true
break macRules
}
}
}
for _, rule := range lc.Policy.Rules {
// There's only one entry per rule, config validation ensures this.
for source, targets := range rule {
@@ -213,7 +271,7 @@ networkRules:
matchedRule = source
do(targets)
matched = true
return upstreams, matched
return
}
}
}
@@ -222,26 +280,134 @@ networkRules:
do(networkTargets)
}
return upstreams, matched
return
}
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo) *dns.Msg {
func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg {
cDomainName := msg.Question[0].Name
locked := p.ptrLoopGuard.TryLock(cDomainName)
defer p.ptrLoopGuard.Unlock(cDomainName)
if !locked {
return nil
}
ip := ipFromARPA(cDomainName)
if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
answer := new(dns.Msg)
answer.SetReply(msg)
answer.Compress = true
answer.Answer = []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
},
Ptr: dns.Fqdn(name),
}}
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
Mac: p.ciTable.LookupMac(ip.String()),
IP: ip.String(),
Hostname: name,
})
return answer
}
return nil
}
func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg {
q := msg.Question[0]
hostname := strings.TrimSuffix(q.Name, ".")
locked := p.lanLoopGuard.TryLock(hostname)
defer p.lanLoopGuard.Unlock(hostname)
if !locked {
return nil
}
if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
answer := new(dns.Msg)
answer.SetReply(msg)
answer.Compress = true
switch {
case ip.Is4():
answer.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: uint32(localTTL.Seconds()),
},
A: ip.AsSlice(),
}}
case ip.Is6():
answer.Answer = []dns.RR{&dns.AAAA{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: uint32(localTTL.Seconds()),
},
AAAA: ip.AsSlice(),
}}
}
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
Mac: p.ciTable.LookupMac(ip.String()),
IP: ip.String(),
Hostname: hostname,
})
return answer
}
return nil
}
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *dns.Msg {
var staleAnswer *dns.Msg
upstreams := req.ufr.upstreams
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
upstreams = []string{upstreamOS}
}
// LAN/PTR lookup flow:
//
// 1. If there's matching rule, follow it.
// 2. Try from client info table.
// 3. Try private resolver.
// 4. Try remote upstream.
isLanOrPtrQuery := false
if req.ufr.matched {
ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
} else {
switch {
case isPrivatePtrLookup(req.msg):
isLanOrPtrQuery = true
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
return answer
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams)
case isLanHostnameQuery(req.msg):
isLanOrPtrQuery = true
if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil {
return answer
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams)
default:
ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams)
}
}
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
for _, upstream := range upstreams {
cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream))
cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
if cachedValue == nil {
continue
}
answer := cachedValue.Msg.Copy()
answer.SetRcode(msg, answer.Rcode)
answer.SetRcode(req.msg, answer.Rcode)
now := time.Now()
if cachedValue.Expire.After(now) {
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
@@ -268,9 +434,9 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
return dnsResolver.Resolve(resolveCtx, msg)
}
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
if upstreamConfig.UpstreamSendClientInfo() && ci != nil {
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci)
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
}
answer, err := resolve1(n, upstreamConfig, msg)
if err != nil {
@@ -281,6 +447,11 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
go p.um.checkUpstream(upstreams[n], upstreamConfig)
}
}
// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
var e net.Error
if errors.As(err, &e) && e.Timeout() {
upstreamConfig.ReBootstrap()
}
return nil
}
return answer
@@ -297,7 +468,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
continue
}
answer := resolve(n, upstreamConfig, msg)
answer := resolve(n, upstreamConfig, req.msg)
if answer == nil {
if serveStaleCache && staleAnswer != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
@@ -307,7 +478,13 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
}
continue
}
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
// We are doing LAN/PTR lookup using private resolver, so always process next one.
// Except for the last, we want to send response instead of saying all upstream failed.
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n])
continue
}
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) {
ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream")
continue
}
@@ -315,7 +492,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
// set compression, as it is not set by default when unpacking
answer.Compress = true
if p.cache != nil {
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
ttl := ttlFromMsg(answer)
now := time.Now()
expired := now.Add(time.Duration(ttl) * time.Second)
@@ -323,17 +500,27 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
expired = now.Add(time.Duration(cachedTTL) * time.Second)
}
setCachedAnswerTTL(answer, now, expired)
p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired))
p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired))
ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response")
}
return answer
}
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
answer := new(dns.Msg)
answer.SetRcode(msg, dns.RcodeServerFailure)
answer.SetRcode(req.msg, dns.RcodeServerFailure)
return answer
}
func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
if len(p.localUpstreams) > 0 {
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
tmp = append(tmp, p.localUpstreams...)
tmp = append(tmp, upstreams...)
return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp)
}
return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
}
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
for _, upstream := range upstreams {
@@ -454,6 +641,23 @@ func ipAndMacFromMsg(msg *dns.Msg) (string, string) {
return ip, mac
}
// stripClientSubnet removes EDNS0_SUBNET from DNS message if the IP is RFC1918 or loopback address,
// passing them to upstream is pointless, these cannot be used by anything on the WAN.
func stripClientSubnet(msg *dns.Msg) {
if opt := msg.IsEdns0(); opt != nil {
opts := make([]dns.EDNS0, 0, len(opt.Option))
for _, s := range opt.Option {
if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) {
continue
}
opts = append(opts, s)
}
if len(opts) != len(opt.Option) {
opt.Option = opts
}
}
}
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
if ci != nil && ci.IP != "" {
switch addr := addr.(type) {
@@ -531,23 +735,44 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
}
// If MAC is still empty here, that mean the requests are made from virtual interface,
// like VPN/Wireguard clients, so we use whatever MAC address associated with remoteIP
// (most likely 127.0.0.1), and ci.IP as hostname, so we can distinguish those clients.
// like VPN/Wireguard clients, so we use ci.IP as hostname to distinguish those clients.
if ci.Mac == "" {
ci.Mac = p.ciTable.LookupMac(remoteIP)
if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" {
ci.Hostname = hostname
} else {
ci.Hostname = ci.IP
p.ciTable.StoreVPNClient(ci)
// Only use IP as hostname for IPv4 clients.
// For Android devices, when it joins the network, it uses ctrld to resolve
// its private DNS once and never reaches ctrld again. For each time, it uses
// a different IPv6 address, which causes hundreds/thousands different client
// IDs created for the same device, which is pointless.
//
// TODO(cuonglm): investigate whether this can be a false positive for other clients?
if !ctrldnet.IsIPv6(ci.IP) {
ci.Hostname = ci.IP
p.ciTable.StoreVPNClient(ci)
}
}
} else {
ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
}
ci.Self = queryFromSelf(ci.IP)
p.spoofLoopbackIpInClientInfo(ci)
return ci
}
// spoofLoopbackIpInClientInfo replaces loopback IPs in client info.
//
// - Preference IPv4.
// - Preference RFC1918.
func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
if ip := net.ParseIP(ci.IP); ip == nil || !ip.IsLoopback() {
return
}
if ip := p.ciTable.LookupRFC1918IPv4(ci.Mac); ip != "" {
ci.IP = ip
}
}
// queryFromSelf reports whether the input IP is from device running ctrld.
func queryFromSelf(ip string) bool {
netIP := netip.MustParseAddr(ip)
@@ -578,17 +803,86 @@ func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
return lc.IP == "127.0.0.1" && lc.Port == 53
}
func rfc1918Addresses() []string {
var res []string
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
addrs, _ := i.Addrs()
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok || !ipNet.IP.IsPrivate() {
continue
}
res = append(res, ipNet.IP.String())
// ipFromARPA parses a FQDN arpa domain and return the IP address if valid.
func ipFromARPA(arpa string) net.IP {
if arpa, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok {
if ptrIP := net.ParseIP(arpa); ptrIP != nil {
return net.IP{ptrIP[15], ptrIP[14], ptrIP[13], ptrIP[12]}
}
})
return res
}
if arpa, ok := strings.CutSuffix(arpa, ".ip6.arpa."); ok {
l := net.IPv6len * 2
base := 16
ip := make(net.IP, net.IPv6len)
for i := 0; i < l && arpa != ""; i++ {
idx := strings.LastIndexByte(arpa, '.')
off := idx + 1
if idx == -1 {
idx = 0
off = 0
} else if idx == len(arpa)-1 {
return nil
}
n, err := strconv.ParseUint(arpa[off:], base, 8)
if err != nil {
return nil
}
b := byte(n)
ii := i / 2
if i&1 == 1 {
b |= ip[ii] << 4
}
ip[ii] = b
arpa = arpa[:idx]
}
return ip
}
return nil
}
// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network.
func isPrivatePtrLookup(m *dns.Msg) bool {
if m == nil || len(m.Question) == 0 {
return false
}
q := m.Question[0]
if ip := ipFromARPA(q.Name); ip != nil {
if addr, ok := netip.AddrFromSlice(ip); ok {
return addr.IsPrivate() ||
addr.IsLoopback() ||
addr.IsLinkLocalUnicast() ||
tsaddr.CGNATRange().Contains(addr)
}
}
return false
}
// isLanHostnameQuery reports whether DNS message is an A/AAAA query with LAN hostname.
func isLanHostnameQuery(m *dns.Msg) bool {
if m == nil || len(m.Question) == 0 {
return false
}
q := m.Question[0]
switch q.Qtype {
case dns.TypeA, dns.TypeAAAA:
default:
return false
}
name := strings.TrimSuffix(q.Name, ".")
return !strings.Contains(name, ".") ||
strings.HasSuffix(name, ".domain") ||
strings.HasSuffix(name, ".lan")
}
// isWanClient reports whether the input is a WAN address.
func isWanClient(na net.Addr) bool {
var ip netip.Addr
if ap, err := netip.ParseAddrPort(na.String()); err == nil {
ip = ap.Addr()
}
return !ip.IsLoopback() &&
!ip.IsPrivate() &&
!ip.IsLinkLocalUnicast() &&
!ip.IsLinkLocalMulticast() &&
!tsaddr.CGNATRange().Contains(ip)
}

View File

@@ -67,8 +67,11 @@ func Test_canonicalName(t *testing.T) {
func Test_prog_upstreamFor(t *testing.T) {
cfg := testhelper.SampleConfig(t)
prog := &prog{cfg: cfg}
for _, nc := range prog.cfg.Network {
p := &prog{cfg: cfg}
p.um = newUpstreamMonitor(p.cfg)
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
for _, nc := range p.cfg.Network {
for _, cidr := range nc.Cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
@@ -81,6 +84,7 @@ func Test_prog_upstreamFor(t *testing.T) {
tests := []struct {
name string
ip string
mac string
defaultUpstreamNum string
lc *ctrld.ListenerConfig
domain string
@@ -88,11 +92,14 @@ func Test_prog_upstreamFor(t *testing.T) {
matched bool
testLogMsg string
}{
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
}
for _, tc := range tests {
@@ -111,9 +118,13 @@ func Test_prog_upstreamFor(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, addr)
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
assert.Equal(t, tc.matched, matched)
assert.Equal(t, tc.upstreams, upstreams)
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
p.proxy(ctx, &proxyRequest{
msg: newDnsMsgWithHostname("foo", dns.TypeA),
ufr: ufr,
})
assert.Equal(t, tc.matched, ufr.matched)
assert.Equal(t, tc.upstreams, ufr.upstreams)
if tc.testLogMsg != "" {
assert.Contains(t, logOutput.String(), tc.testLogMsg)
}
@@ -149,8 +160,32 @@ func TestCache(t *testing.T) {
answer2.SetRcode(msg, dns.RcodeRefused)
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil)
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil)
req1 := &proxyRequest{
msg: msg,
ci: nil,
failoverRcodes: nil,
ufr: &upstreamForResult{
upstreams: []string{"upstream.1"},
matchedPolicy: "",
matchedNetwork: "",
matchedRule: "",
matched: false,
},
}
req2 := &proxyRequest{
msg: msg,
ci: nil,
failoverRcodes: nil,
ufr: &upstreamForResult{
upstreams: []string{"upstream.0"},
matchedPolicy: "",
matchedNetwork: "",
matchedRule: "",
matched: false,
},
}
got1 := prog.proxy(context.Background(), req1)
got2 := prog.proxy(context.Background(), req2)
assert.NotSame(t, got1, got2)
assert.Equal(t, answer1.Rcode, got1.Rcode)
assert.Equal(t, answer2.Rcode, got2.Rcode)
@@ -234,3 +269,165 @@ func Test_remoteAddrFromMsg(t *testing.T) {
})
}
}
func Test_ipFromARPA(t *testing.T) {
tests := []struct {
IP string
ARPA string
}{
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
{"", "asd.in-addr.arpa."},
{"", "asd.ip6.arpa."},
}
for _, tc := range tests {
tc := tc
t.Run(tc.IP, func(t *testing.T) {
t.Parallel()
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
}
})
}
}
func newDnsMsgWithClientIP(ip string) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA)
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
m.Extra = append(m.Extra, o)
return m
}
func Test_stripClientSubnet(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
wantSubnet bool
}{
{"no edns0", new(dns.Msg), false},
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
{"invalid IP", newDnsMsgWithClientIP(""), true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
stripClientSubnet(tc.msg)
hasSubnet := false
if opt := tc.msg.IsEdns0(); opt != nil {
for _, s := range opt.Option {
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
hasSubnet = true
}
}
}
if tc.wantSubnet != hasSubnet {
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
}
})
}
}
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(hostname, typ)
return m
}
func Test_isLanHostnameQuery(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
isLanHostnameQuery bool
}{
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
}
})
}
}
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
t.Helper()
m := new(dns.Msg)
ptr, err := dns.ReverseAddr(ip)
if err != nil {
t.Fatal(err)
}
m.SetQuestion(ptr, dns.TypePTR)
return m
}
func Test_isPrivatePtrLookup(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
isPrivatePtrLookup bool
}{
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
}
})
}
}
func Test_isWanClient(t *testing.T) {
tests := []struct {
name string
addr net.Addr
isWanClient bool
}{
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isWanClient(tc.addr); tc.isWanClient != got {
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
}
})
}
}

View File

@@ -3,6 +3,7 @@ package cli
import (
"context"
"strings"
"sync"
"time"
"github.com/miekg/dns"
@@ -15,6 +16,36 @@ const (
loopTestQtype = dns.TypeTXT
)
// newLoopGuard returns new loopGuard.
func newLoopGuard() *loopGuard {
return &loopGuard{inflight: make(map[string]struct{})}
}
// loopGuard guards against DNS loop, ensuring only one query
// for a given domain is processed at a time.
type loopGuard struct {
mu sync.Mutex
inflight map[string]struct{}
}
// TryLock marks the domain as being processed.
func (lg *loopGuard) TryLock(domain string) bool {
lg.mu.Lock()
defer lg.mu.Unlock()
if _, inflight := lg.inflight[domain]; !inflight {
lg.inflight[domain] = struct{}{}
return true
}
return false
}
// Unlock marks the domain as being done.
func (lg *loopGuard) Unlock(domain string) {
lg.mu.Lock()
defer lg.mu.Unlock()
delete(lg.inflight, domain)
}
// isLoop reports whether the given upstream config is detected as having DNS loop.
func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool {
p.loopMu.Lock()
@@ -56,7 +87,15 @@ func (p *prog) checkDnsLoop() {
mainLog.Load().Debug().Msg("start checking DNS loop")
upstream := make(map[string]*ctrld.UpstreamConfig)
p.loopMu.Lock()
for _, uc := range p.cfg.Upstream {
for n, uc := range p.cfg.Upstream {
if p.um.isDown("upstream." + n) {
continue
}
// Do not send test query to external upstream.
if !canBeLocalUpstream(uc.Domain) {
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
continue
}
uid := uc.UID()
p.loop[uid] = false
upstream[uid] = uc
@@ -79,13 +118,15 @@ func (p *prog) checkDnsLoop() {
}
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
func (p *prog) checkDnsLoopTicker() {
func (p *prog) checkDnsLoopTicker(ctx context.Context) {
timer := time.NewTicker(time.Minute)
defer timer.Stop()
for {
select {
case <-p.stopCh:
return
case <-ctx.Done():
return
case <-timer.C:
p.checkDnsLoop()
}

42
cmd/cli/loop_test.go Normal file
View File

@@ -0,0 +1,42 @@
package cli
import (
"sync"
"sync/atomic"
"testing"
)
func Test_loopGuard(t *testing.T) {
lg := newLoopGuard()
key := "foo"
var i atomic.Int64
var started atomic.Int64
n := 1000
do := func() {
locked := lg.TryLock(key)
defer lg.Unlock(key)
started.Add(1)
for started.Load() < 2 {
// Wait until at least 2 goroutines started, otherwise, on system with heavy load,
// or having only 1 CPU, all goroutines can be scheduled to run consequently.
}
if locked {
i.Add(1)
}
}
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
do()
}()
}
wg.Wait()
if i.Load() == int64(n) {
t.Fatalf("i must not be increased %d times", n)
}
}

View File

@@ -32,6 +32,8 @@ var (
cdDev bool
iface string
ifaceStartStop string
nextdns string
cdUpstreamProto string
mainLog atomic.Pointer[zerolog.Logger]
consoleWriter zerolog.ConsoleWriter
@@ -39,8 +41,9 @@ var (
)
const (
cdUidFlagName = "cd"
cdOrgFlagName = "cd-org"
cdUidFlagName = "cd"
cdOrgFlagName = "cd-org"
nextdnsFlagName = "nextdns"
)
func init() {
@@ -93,6 +96,7 @@ func initConsoleLogging() {
// initLogging initializes global logging setup.
func initLogging() {
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
initLoggingWithBackup(true)
}
@@ -131,7 +135,7 @@ func initLoggingWithBackup(doBackup bool) {
}
writers = append(writers, consoleWriter)
multi := zerolog.MultiLevelWriter(writers...)
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
l := mainLog.Load().Output(multi).With().Logger()
mainLog.Store(&l)
// TODO: find a better way.
ctrld.ProxyLogger.Store(&l)

View File

@@ -1,11 +1,13 @@
package cli
import (
"context"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
func (p *prog) watchLinkState() {
func (p *prog) watchLinkState(ctx context.Context) {
ch := make(chan netlink.LinkUpdate)
done := make(chan struct{})
defer close(done)
@@ -13,14 +15,19 @@ func (p *prog) watchLinkState() {
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
return
}
for lu := range ch {
if lu.Change == 0xFFFFFFFF {
continue
}
if lu.Change&unix.IFF_UP != 0 {
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
for _, uc := range p.cfg.Upstream {
uc.ReBootstrap()
for {
select {
case <-ctx.Done():
return
case lu := <-ch:
if lu.Change == 0xFFFFFFFF {
continue
}
if lu.Change&unix.IFF_UP != 0 {
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
for _, uc := range p.cfg.Upstream {
uc.ReBootstrap()
}
}
}
}

View File

@@ -2,4 +2,6 @@
package cli
func (p *prog) watchLinkState() {}
import "context"
func (p *prog) watchLinkState(ctx context.Context) {}

31
cmd/cli/nextdns.go Normal file
View File

@@ -0,0 +1,31 @@
package cli
import (
"fmt"
"github.com/Control-D-Inc/ctrld"
)
const nextdnsURL = "https://dns.nextdns.io"
func generateNextDNSConfig(uid string) {
if uid == "" {
return
}
mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver")
cfg = ctrld.Config{
Listener: map[string]*ctrld.ListenerConfig{
"0": {
IP: "0.0.0.0",
Port: 53,
},
},
Upstream: map[string]*ctrld.UpstreamConfig{
"0": {
Type: ctrld.ResolverTypeDOH3,
Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, uid),
Timeout: 5000,
},
},
}
}

View File

@@ -9,10 +9,12 @@ import (
"net"
"net/netip"
"os/exec"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/fsnotify/fsnotify"
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/dhcpv6/client6"
@@ -23,7 +25,10 @@ import (
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
const (
resolvConfPath = "/etc/resolv.conf"
resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
)
// allocate loopback ip
// sudo ip a add 127.0.0.2/24 dev lo
@@ -64,6 +69,11 @@ func setDNS(iface *net.Interface, nameservers []string) error {
Nameservers: ns,
SearchDomains: []dnsname.FQDN{},
}
defer func() {
if r.Mode() == "direct" {
go watchResolveConf(osConfig)
}
}()
trySystemdResolve := false
for i := 0; i < maxSetDNSAttempts; i++ {
@@ -299,3 +309,59 @@ func sliceIndex[S ~[]E, E comparable](s S, v E) int {
}
return -1
}
// watchResolveConf watches any changes to /etc/resolv.conf file,
// and reverting to the original config set by ctrld.
func watchResolveConf(oc dns.OSConfig) {
mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file")
watcher, err := fsnotify.NewWatcher()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
return
}
// We watch /etc instead of /etc/resolv.conf directly,
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
watchDir := filepath.Dir(resolvConfPath)
if err := watcher.Add(watchDir); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list")
return
}
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter.
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
return
}
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Name != resolvConfPath { // skip if not /etc/resolv.conf changes.
continue
}
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
if err := watcher.Remove(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
continue
}
if err := r.SetDNS(oc); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
}
if err := watcher.Add(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
return
}
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
}
}
}

View File

@@ -2,6 +2,7 @@ package cli
import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
@@ -16,7 +17,9 @@ import (
"syscall"
"github.com/kardianos/service"
"github.com/spf13/viper"
"tailscale.com/net/interfaces"
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
@@ -30,6 +33,7 @@ const (
ctrldControlUnixSock = "ctrld_control.sock"
upstreamPrefix = "upstream."
upstreamOS = upstreamPrefix + "os"
upstreamPrivate = upstreamPrefix + "private"
)
var logf = func(format string, args ...any) {
@@ -45,19 +49,25 @@ var svcConfig = &service.Config{
var useSystemdResolved = false
type prog struct {
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
logConn net.Conn
cs *controlServer
mu sync.Mutex
waitCh chan struct{}
stopCh chan struct{}
reloadCh chan struct{} // For Windows.
reloadDoneCh chan struct{}
logConn net.Conn
cs *controlServer
cfg *ctrld.Config
appCallback *AppCallback
cache dnscache.Cacher
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
cfg *ctrld.Config
localUpstreams []string
ptrNameservers []string
appCallback *AppCallback
cache dnscache.Cacher
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
ptrLoopGuard *loopGuard
lanLoopGuard *loopGuard
loopMu sync.Mutex
loop map[string]bool
@@ -69,11 +79,106 @@ type prog struct {
}
func (p *prog) Start(s service.Service) error {
p.cfg = &cfg
go p.run()
go p.runWait()
return nil
}
// runWait runs ctrld components, waiting for signal to reload.
func (p *prog) runWait() {
p.mu.Lock()
p.cfg = &cfg
p.mu.Unlock()
reloadSigCh := make(chan os.Signal, 1)
notifyReloadSigCh(reloadSigCh)
reload := false
logger := mainLog.Load()
for {
reloadCh := make(chan struct{})
done := make(chan struct{})
go func() {
defer close(done)
p.run(reload, reloadCh)
reload = true
}()
select {
case sig := <-reloadSigCh:
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
case <-p.reloadCh:
logger.Notice().Msg("reloading...")
case <-p.stopCh:
close(reloadCh)
return
}
waitOldRunDone := func() {
close(reloadCh)
<-done
}
newCfg := &ctrld.Config{}
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
ctrld.InitConfig(v, "ctrld")
if configPath != "" {
v.SetConfigFile(configPath)
}
if err := v.ReadInConfig(); err != nil {
logger.Err(err).Msg("could not read new config")
waitOldRunDone()
continue
}
if err := v.Unmarshal(&newCfg); err != nil {
logger.Err(err).Msg("could not unmarshal new config")
waitOldRunDone()
continue
}
if cdUID != "" {
if err := processCDFlags(newCfg); err != nil {
logger.Err(err).Msg("could not fetch ControlD config")
waitOldRunDone()
continue
}
}
waitOldRunDone()
p.mu.Lock()
curListener := p.cfg.Listener
p.mu.Unlock()
for n, lc := range newCfg.Listener {
curLc := curListener[n]
if curLc == nil {
continue
}
if lc.IP == "" {
lc.IP = curLc.IP
}
if lc.Port == 0 {
lc.Port = curLc.Port
}
}
if err := validateConfig(newCfg); err != nil {
logger.Err(err).Msg("invalid config")
continue
}
// This needs to be done here, otherwise, the DNS handler may observe an invalid
// upstream config because its initialization function have not been called yet.
mainLog.Load().Debug().Msg("setup upstream with new config")
p.setupUpstream(newCfg)
p.mu.Lock()
*p.cfg = *newCfg
p.mu.Unlock()
logger.Notice().Msg("reloading config successfully")
select {
case p.reloadDoneCh <- struct{}{}:
default:
}
}
}
func (p *prog) preRun() {
if !service.Interactive() {
p.setDNS()
@@ -87,14 +192,54 @@ func (p *prog) preRun() {
}
}
func (p *prog) run() {
func (p *prog) setupUpstream(cfg *ctrld.Config) {
localUpstreams := make([]string, 0, len(cfg.Upstream))
ptrNameservers := make([]string, 0, len(cfg.Upstream))
for n := range cfg.Upstream {
uc := cfg.Upstream[n]
uc.Init()
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
} else {
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
}
uc.SetCertPool(rootCertPool)
go uc.Ping()
if canBeLocalUpstream(uc.Domain) {
localUpstreams = append(localUpstreams, upstreamPrefix+n)
}
if uc.IsDiscoverable() {
ptrNameservers = append(ptrNameservers, uc.Endpoint)
}
}
p.localUpstreams = localUpstreams
p.ptrNameservers = ptrNameservers
}
// run runs the ctrld main components.
//
// The reload boolean indicates that the function is run when ctrld first start
// or when ctrld receive reloading signal. Platform specifics setup is only done
// on started, mean reload is "false".
//
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
// so all listeners could be terminated and re-spawned again.
func (p *prog) run(reload bool, reloadCh chan struct{}) {
// Wait the caller to signal that we can do our logic.
<-p.waitCh
p.preRun()
if !reload {
p.preRun()
}
numListeners := len(p.cfg.Listener)
p.started = make(chan struct{}, numListeners)
if !reload {
p.started = make(chan struct{}, numListeners)
}
p.onStartedDone = make(chan struct{})
p.loop = make(map[string]bool)
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
if p.cfg.Service.CacheEnable {
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
if err != nil {
@@ -103,15 +248,7 @@ func (p *prog) run() {
p.cache = cacher
}
}
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
n := *mcr
if n == 0 {
p.sema = &noopSemaphore{}
} else {
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
}
}
var wg sync.WaitGroup
wg.Add(len(p.cfg.Listener))
@@ -127,74 +264,102 @@ func (p *prog) run() {
}
p.um = newUpstreamMonitor(p.cfg)
for n := range p.cfg.Upstream {
uc := p.cfg.Upstream[n]
uc.Init()
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
} else {
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
if !reload {
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
n := *mcr
if n == 0 {
p.sema = &noopSemaphore{}
} else {
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
}
}
p.setupUpstream(p.cfg)
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
uc.SetCertPool(rootCertPool)
go uc.Ping()
}
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
// context for managing spawn goroutines.
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
// Newer versions of android and iOS denies permission which breaks connectivity.
if !isMobile() {
if !isMobile() && !reload {
wg.Add(1)
go func() {
defer wg.Done()
p.ciTable.Init()
p.ciTable.RefreshLoop(p.stopCh)
p.ciTable.RefreshLoop(ctx)
}()
go p.watchLinkState()
go p.watchLinkState(ctx)
}
for listenerNum := range p.cfg.Listener {
p.cfg.Listener[listenerNum].Init()
go func(listenerNum string) {
defer wg.Done()
listenerConfig := p.cfg.Listener[listenerNum]
upstreamConfig := p.cfg.Upstream[listenerNum]
if upstreamConfig == nil {
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
if !reload {
go func(listenerNum string) {
listenerConfig := p.cfg.Listener[listenerNum]
upstreamConfig := p.cfg.Upstream[listenerNum]
if upstreamConfig == nil {
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
if err := p.serveDNS(listenerNum); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
}
}(listenerNum)
}
go func() {
defer func() {
cancelFunc()
wg.Done()
}()
select {
case <-p.stopCh:
case <-ctx.Done():
case <-reloadCh:
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
if err := p.serveDNS(listenerNum); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
}
}(listenerNum)
}()
}
for i := 0; i < numListeners; i++ {
<-p.started
if !reload {
for i := 0; i < numListeners; i++ {
<-p.started
}
for _, f := range p.onStarted {
f()
}
}
for _, f := range p.onStarted {
f()
}
// Check for possible DNS loop.
p.checkDnsLoop()
close(p.onStartedDone)
// Start check DNS loop ticker.
go p.checkDnsLoopTicker()
wg.Add(1)
go func() {
defer wg.Done()
// Check for possible DNS loop.
p.checkDnsLoop()
// Start check DNS loop ticker.
p.checkDnsLoopTicker(ctx)
}()
// Stop writing log to unix socket.
consoleWriter.Out = os.Stdout
initLoggingWithBackup(false)
if p.logConn != nil {
_ = p.logConn.Close()
}
if p.cs != nil {
p.registerControlServerHandler()
if err := p.cs.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not start control server")
if !reload {
// Stop writing log to unix socket.
consoleWriter.Out = os.Stdout
initLoggingWithBackup(false)
if p.logConn != nil {
_ = p.logConn.Close()
}
if p.cs != nil {
p.registerControlServerHandler()
if err := p.cs.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not start control server")
}
}
}
wg.Wait()
@@ -276,7 +441,7 @@ func (p *prog) setDNS() {
nameservers := []string{ns}
if needRFC1918Listeners(lc) {
nameservers = append(nameservers, rfc1918Addresses()...)
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
}
if err := setDNS(netIface, nameservers); err != nil {
logger.Error().Err(err).Msgf("could not set DNS for interface")
@@ -462,3 +627,11 @@ func defaultRouteIP() string {
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
return ip
}
// canBeLocalUpstream reports whether the IP address can be used as a local upstream.
func canBeLocalUpstream(addr string) bool {
if ip, err := netip.ParseAddr(addr); err == nil {
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
}
return false
}

View File

@@ -19,7 +19,6 @@ func setDependencies(svc *service.Config) {
"Wants=NetworkManager-wait-online.service",
"After=NetworkManager-wait-online.service",
"Wants=systemd-networkd-wait-online.service",
"After=systemd-networkd-wait-online.service",
"Wants=nss-lookup.target",
"After=nss-lookup.target",
}

17
cmd/cli/reload_others.go Normal file
View File

@@ -0,0 +1,17 @@
//go:build !windows
package cli
import (
"os"
"os/signal"
"syscall"
)
func notifyReloadSigCh(ch chan os.Signal) {
signal.Notify(ch, syscall.SIGUSR1)
}
func (p *prog) sendReloadSignal() error {
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
}

18
cmd/cli/reload_windows.go Normal file
View File

@@ -0,0 +1,18 @@
package cli
import (
"errors"
"os"
"time"
)
func notifyReloadSigCh(ch chan os.Signal) {}
func (p *prog) sendReloadSignal() error {
select {
case p.reloadCh <- struct{}{}:
return nil
case <-time.After(5 * time.Second):
}
return errors.New("timeout while sending reload signal")
}

View File

@@ -7,7 +7,6 @@ import (
"time"
"github.com/miekg/dns"
"tailscale.com/logtail/backoff"
"github.com/Control-D-Inc/ctrld"
)
@@ -15,8 +14,8 @@ import (
const (
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
maxFailureRequest = 100
// checkUpstreamMaxBackoff is the max backoff time when checking upstream status.
checkUpstreamMaxBackoff = 2 * time.Minute
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
checkUpstreamBackoffSleep = 2 * time.Second
)
// upstreamMonitor performs monitoring upstreams health.
@@ -76,7 +75,6 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf
um.checking[upstream] = true
um.mu.Unlock()
bo := backoff.NewBackoff("checkUpstream", logf, checkUpstreamMaxBackoff)
resolver, err := ctrld.NewResolver(uc)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
@@ -84,15 +82,20 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf
}
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
ctx := context.Background()
for {
check := func() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
uc.ReBootstrap()
_, err := resolver.Resolve(ctx, msg)
if err == nil {
return err
}
for {
if err := check(); err == nil {
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
um.reset(upstream)
return
}
bo.BackOff(ctx, err)
time.Sleep(checkUpstreamBackoffSleep)
}
}

View File

@@ -11,6 +11,7 @@ import (
"math/rand"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"runtime"
@@ -26,6 +27,7 @@ import (
"github.com/spf13/viper"
"golang.org/x/sync/singleflight"
"tailscale.com/logtail/backoff"
"tailscale.com/net/tsaddr"
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
@@ -82,6 +84,16 @@ func InitConfig(v *viper.Viper, name string) {
"0": {
IP: "",
Port: 0,
Policy: &ListenerPolicyConfig{
Name: "Main Policy",
Networks: []Rule{
{"network.0": []string{"upstream.0"}},
},
Rules: []Rule{
{"example.com": []string{"upstream.0"}},
{"*.ads.com": []string{"upstream.1"}},
},
},
},
})
v.SetDefault("network", map[string]*NetworkConfig{
@@ -181,6 +193,7 @@ type ServiceConfig struct {
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
Daemon bool `mapstructure:"-" toml:"-"`
AllocateIP bool `mapstructure:"-" toml:"-"`
}
@@ -204,6 +217,9 @@ type UpstreamConfig struct {
// The caller should not access this field directly.
// Use UpstreamSendClientInfo instead.
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
// The caller should not access this field directly.
// Use IsDiscoverable instead.
Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"`
g singleflight.Group
rebootstrap atomic.Bool
@@ -224,10 +240,11 @@ type UpstreamConfig struct {
// ListenerConfig specifies the networks configuration that ctrld will run on.
type ListenerConfig struct {
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
AllowWanClients bool `mapstructure:"allow_wan_clients" toml:"allow_wan_clients,omitempty"`
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
}
// IsDirectDnsListener reports whether ctrld can be a direct listener on port 53.
@@ -253,6 +270,7 @@ type ListenerPolicyConfig struct {
Name string `mapstructure:"name" toml:"name,omitempty"`
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
}
@@ -322,13 +340,28 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
}
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
if uc.isControlD() {
if uc.isControlD() || uc.isNextDNS() {
return true
}
}
return false
}
// IsDiscoverable reports whether the upstream can be used for PTR discovery.
// The caller must ensure uc.Init() was called before calling this.
func (uc *UpstreamConfig) IsDiscoverable() bool {
if uc.Discoverable != nil {
return *uc.Discoverable
}
switch uc.Type {
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate:
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
}
}
return false
}
// BootstrapIPs returns the bootstrap IPs list of upstreams.
func (uc *UpstreamConfig) BootstrapIPs() []string {
return uc.bootstrapIPs
@@ -394,8 +427,9 @@ func (uc *UpstreamConfig) ReBootstrap() {
return
}
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
uc.rebootstrap.Store(true)
if uc.rebootstrap.CompareAndSwap(false, true) {
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
}
return true, nil
})
}
@@ -519,6 +553,16 @@ func (uc *UpstreamConfig) isControlD() bool {
return false
}
func (uc *UpstreamConfig) isNextDNS() bool {
domain := uc.Domain
if domain == "" {
if u, err := url.Parse(uc.Endpoint); err == nil {
domain = u.Hostname()
}
}
return domain == "dns.nextdns.io"
}
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
uc.transportOnce.Do(func() {
uc.SetupTransport()

View File

@@ -279,6 +279,61 @@ func TestUpstreamConfig_UpstreamSendClientInfo(t *testing.T) {
}
}
func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
tests := []struct {
name string
uc *UpstreamConfig
discoverable bool
}{
{
"loopback",
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy},
true,
},
{
"rfc1918",
&UpstreamConfig{Endpoint: "192.168.1.1", Type: ResolverTypeLegacy},
true,
},
{
"CGNAT",
&UpstreamConfig{Endpoint: "100.66.67.68", Type: ResolverTypeLegacy},
true,
},
{
"Public IP",
&UpstreamConfig{Endpoint: "8.8.8.8", Type: ResolverTypeLegacy},
false,
},
{
"override discoverable",
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(false)},
false,
},
{
"override non-public",
&UpstreamConfig{Endpoint: "1.1.1.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(true)},
true,
},
{
"non-legacy upstream",
&UpstreamConfig{Endpoint: "https://192.168.1.1/custom-doh", Type: ResolverTypeDOH},
false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init()
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
}
})
}
}
func ptrBool(b bool) *bool {
return &b
}

View File

@@ -10,13 +10,10 @@ import (
"net/http"
"runtime"
"sync"
"time"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
func (uc *UpstreamConfig) setupDOH3Transport() {
@@ -29,9 +26,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
case IpStackSplit:
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if ctrldnet.IPv6Available(ctx) {
if hasIPv6() {
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
} else {
uc.http3RoundTripper6 = uc.http3RoundTripper4
@@ -127,11 +122,6 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
close(ch)
}()
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
for _, addr := range addrs {
go func(addr string) {
defer wg.Done()
@@ -140,6 +130,11 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
ch <- &parallelDialerResult{conn: nil, err: err}
return
}
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
ch <- &parallelDialerResult{conn: nil, err: err}
return
}
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
select {
case ch <- &parallelDialerResult{conn: conn, err: err}:
@@ -147,6 +142,9 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t
if conn != nil {
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
}
if udpConn != nil {
udpConn.Close()
}
}
}(addr)
}

View File

@@ -54,7 +54,12 @@ func TestLoadDefaultConfig(t *testing.T) {
cfg := defaultConfig(t)
validate := validator.New()
require.NoError(t, ctrld.ValidateConfig(validate, cfg))
assert.Len(t, cfg.Listener, 1)
if assert.Len(t, cfg.Listener, 1) {
l0 := cfg.Listener["0"]
require.NotNil(t, l0.Policy)
assert.Len(t, l0.Policy.Networks, 1)
assert.Len(t, l0.Policy.Rules, 2)
}
assert.Len(t, cfg.Upstream, 2)
}
@@ -96,6 +101,7 @@ func TestConfigValidation(t *testing.T) {
{"lease file format required if lease file exist", configWithExistedLeaseFile(t), true},
{"invalid lease file format", configWithInvalidLeaseFileFormat(t), true},
{"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true},
{"invalid client id pref", configWithInvalidClientIDPref(t), true},
}
for _, tc := range tests {
@@ -233,3 +239,9 @@ func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH
return cfg
}
func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Service.ClientIDPref = "foo"
return cfg
}

View File

@@ -215,6 +215,17 @@ DHCP leases file format.
- Valid values: `dnsmasq`, `isc-dhcp`
- Default: ""
### client_id_preference
Decide how the client ID is generated
If `host` -> client id will only use the hostname i.e.`hash(hostname)`.
If `mac` -> client id will only use the MAC address `hash(mac)`.
Else -> client ID will use both Mac and Hostname i.e. `hash(mac + host)
- Type: string
- Required: no
- Valid values: `mac`, `host`
- Default: ""
## Upstream
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
@@ -319,6 +330,24 @@ If `ip_stack` is empty, or undefined:
- Default value is `both` for non-Control D resolvers.
- Default value is `split` for Control D resolvers.
### send_client_info
Specifying whether to include client info when sending query to upstream.
- Type: boolean
- Required: no
- Default:
- `true` for ControlD upstreams.
- `false` for other upstreams.
### discoverable
Specifying whether the upstream can be used for PTR discovery.
- Type: boolean
- Required: no
- Default:
- `true` for loopback/RFC1918/CGNAT IP address.
- `false` for public IP address.
## Network
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.
@@ -376,7 +405,14 @@ Port number that the listener will listen on for incoming requests. If `port` is
- Default: 0 or 53 or 5354 (depending on platform)
### restricted
If set to `true` makes the listener `REFUSE` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
If set to `true`, makes the listener `REFUSED` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
- Type: bool
- Required: no
- Default: false
### allow_wan_clients
The listener `REFUSED` DNS queries from WAN clients by default. If set to `true`, makes the listener replies to them.
- Type: bool
- Required: no
@@ -386,7 +422,15 @@ If set to `true` makes the listener `REFUSE` DNS queries from all source IP addr
Allows `ctrld` to set policy rules to determine which upstreams the requests will be forwarded to.
If no `policy` is defined or the requests do not match any policy rules, it will be forwarded to corresponding upstream of the listener. For example, the request to `listener.0` will be forwarded to `upstream.0`.
The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either the `network` or a domain. Value is the list of the upstreams. For example:
The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either:
- Network.
- Domain.
- Mac Address.
Value is the list of the upstreams.
For example:
```toml
[listener.0.policy]
@@ -400,12 +444,18 @@ rules = [
{"*.local" = ["upstream.1"]},
{"test.com" = ["upstream.2", "upstream.1"]},
]
macs = [
{"14:54:4a:8e:08:2d" = ["upstream.3"]},
]
```
Above policy will:
- Forward requests on `listener.0` from `network.0` to `upstream.1`.
- Forward requests on `listener.0` for `.local` suffixed domains to `upstream.1`.
- Forward requests on `listener.0` for `test.com` to `upstream.2`. If timeout is reached, retry on `upstream.1`.
- Forward requests on `listener.0` from client with Mac `14:54:4a:8e:08:2d` to `upstream.3`.
- Forward requests on `listener.0` from `network.0` to `upstream.1`.
- All other requests on `listener.0` that do not match above conditions will be forwarded to `upstream.0`.
An empty upstream would not route the request to any defined upstreams, and use the OS default resolver.
@@ -419,6 +469,18 @@ rules = [
]
```
---
Note that the order of matching preference:
```
rules => macs => networks
```
And within each policy, the rules are processed from top to bottom.
---
#### name
`name` is the name for the policy.
@@ -440,6 +502,13 @@ rules = [
- Required: no
- Default: []
### macs:
`macs` is the list of mac rules within the policy. Mac address value is case-insensitive.
- Type: array of macs
- Required: no
- Default: []
### failover_rcodes
For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`.

75
doh.go
View File

@@ -18,11 +18,12 @@ import (
)
const (
dohMacHeader = "x-cd-mac"
dohIPHeader = "x-cd-ip"
dohHostHeader = "x-cd-host"
dohOsHeader = "x-cd-os"
headerApplicationDNS = "application/dns-message"
dohMacHeader = "x-cd-mac"
dohIPHeader = "x-cd-ip"
dohHostHeader = "x-cd-host"
dohOsHeader = "x-cd-os"
dohClientIDPrefHeader = "x-cd-cpref"
headerApplicationDNS = "application/dns-message"
)
// EncodeOsNameMap provides mapping from OS name to a shorter string, used for encoding x-cd-os value.
@@ -76,7 +77,6 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver {
endpoint: uc.u,
isDoH3: uc.Type == ResolverTypeDOH3,
http3RoundTripper: uc.http3RoundTripper,
sendClientInfo: uc.UpstreamSendClientInfo(),
uc: uc,
}
return r
@@ -87,9 +87,9 @@ type dohResolver struct {
endpoint *url.URL
isDoH3 bool
http3RoundTripper http.RoundTripper
sendClientInfo bool
}
// Resolve performs DNS query with given DNS message using DOH protocol.
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
data, err := msg.Pack()
if err != nil {
@@ -106,7 +106,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
if err != nil {
return nil, fmt.Errorf("could not create request: %w", err)
}
addHeader(ctx, req, r.sendClientInfo)
addHeader(ctx, req, r.uc)
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
@@ -146,26 +146,19 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return answer, nil
}
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
req.Header.Set("Content-Type", headerApplicationDNS)
req.Header.Set("Accept", headerApplicationDNS)
req.Header.Set(dohOsHeader, dohOsHeaderValue())
printed := false
if sendClientInfo {
if uc.UpstreamSendClientInfo() {
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
if ci.Mac != "" {
req.Header.Set(dohMacHeader, ci.Mac)
}
if ci.IP != "" {
req.Header.Set(dohIPHeader, ci.IP)
}
if ci.Hostname != "" {
req.Header.Set(dohHostHeader, ci.Hostname)
}
if ci.Self {
req.Header.Set(dohOsHeader, dohOsHeaderValue())
switch {
case uc.isControlD():
addControlDHeaders(req, ci)
case uc.isNextDNS():
addNextDNSHeaders(req, ci)
}
}
}
@@ -173,3 +166,41 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
}
}
// addControlDHeaders set DoH/Doh3 HTTP request headers for ControlD upstream.
func addControlDHeaders(req *http.Request, ci *ClientInfo) {
req.Header.Set(dohOsHeader, dohOsHeaderValue())
if ci.Mac != "" {
req.Header.Set(dohMacHeader, ci.Mac)
}
if ci.IP != "" {
req.Header.Set(dohIPHeader, ci.IP)
}
if ci.Hostname != "" {
req.Header.Set(dohHostHeader, ci.Hostname)
}
if ci.Self {
req.Header.Set(dohOsHeader, dohOsHeaderValue())
}
switch ci.ClientIDPref {
case "mac":
req.Header.Set(dohClientIDPrefHeader, "1")
case "host":
req.Header.Set(dohClientIDPrefHeader, "2")
}
}
// addNextDNSHeaders set DoH/Doh3 HTTP request headers for nextdns upstream.
// https://github.com/nextdns/nextdns/blob/v1.41.0/resolver/doh.go#L100
func addNextDNSHeaders(req *http.Request, ci *ClientInfo) {
if ci.Mac != "" {
// https: //github.com/nextdns/nextdns/blob/v1.41.0/run.go#L543
req.Header.Set("X-Device-Model", "mac:"+ci.Mac[:8])
}
if ci.IP != "" {
req.Header.Set("X-Device-Ip", ci.IP)
}
if ci.Hostname != "" {
req.Header.Set("X-Device-Name", ci.Hostname)
}
}

9
go.mod
View File

@@ -25,9 +25,9 @@ require (
github.com/spf13/viper v1.16.0
github.com/stretchr/testify v1.8.3
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/net v0.10.0
golang.org/x/net v0.17.0
golang.org/x/sync v0.2.0
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a
golang.org/x/sys v0.13.0
golang.zx2c4.com/wireguard/windows v0.5.3
tailscale.com v1.44.0
)
@@ -70,11 +70,10 @@ require (
github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda // indirect
golang.org/x/mod v0.10.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/tools v0.9.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

20
go.sum
View File

@@ -55,8 +55,6 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs=
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf h1:40DHYsri+d1bnroFDU2FQAeq68f3kAlOzlQ93kCf26Q=
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -302,8 +300,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -331,8 +329,6 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI
golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda h1:O+EUvnBNPwI4eLthn8W5K+cS8zQZfgTABPLNm6Bna34=
golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda/go.mod h1:aAjjkJNdrh3PMckS4B10TGS2nag27cbKR1y2BpUxsiY=
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
@@ -378,8 +374,8 @@ golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -452,8 +448,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a h1:qMsju+PNttu/NMbq8bQ9waDdxgJMu9QNoUDuhnBaYt0=
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -463,8 +459,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@@ -1,8 +1,11 @@
package clientinfo
import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
@@ -68,25 +71,27 @@ type Table struct {
refreshers []refresher
initOnce sync.Once
dhcp *dhcp
merlin *merlinDiscover
arp *arpDiscover
ptr *ptrDiscover
mdns *mdns
hf *hostsFile
vni *virtualNetworkIface
cfg *ctrld.Config
quitCh chan struct{}
selfIP string
cdUID string
dhcp *dhcp
merlin *merlinDiscover
arp *arpDiscover
ptr *ptrDiscover
mdns *mdns
hf *hostsFile
vni *virtualNetworkIface
svcCfg ctrld.ServiceConfig
quitCh chan struct{}
selfIP string
cdUID string
ptrNameservers []string
}
func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table {
func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
return &Table{
cfg: cfg,
quitCh: make(chan struct{}),
selfIP: selfIP,
cdUID: cdUID,
svcCfg: cfg.Service,
quitCh: make(chan struct{}),
selfIP: selfIP,
cdUID: cdUID,
ptrNameservers: ns,
}
}
@@ -97,7 +102,7 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) {
clientInfoFiles[name] = format
}
func (t *Table) RefreshLoop(stopCh chan struct{}) {
func (t *Table) RefreshLoop(ctx context.Context) {
timer := time.NewTicker(time.Minute * 5)
defer timer.Stop()
for {
@@ -106,7 +111,7 @@ func (t *Table) RefreshLoop(stopCh chan struct{}) {
for _, r := range t.refreshers {
_ = r.refresh()
}
case <-stopCh:
case <-ctx.Done():
close(t.quitCh)
return
}
@@ -182,6 +187,26 @@ func (t *Table) init() {
// PTR lookup.
if t.discoverPTR() {
t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()}
if len(t.ptrNameservers) > 0 {
nss := make([]string, 0, len(t.ptrNameservers))
for _, ns := range t.ptrNameservers {
host, port := ns, "53"
if h, p, err := net.SplitHostPort(ns); err == nil {
host, port = h, p
}
// Only use valid ip:port pair.
if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil {
nss = append(nss, net.JoinHostPort(host, port))
} else {
ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns)
}
}
if len(nss) > 0 {
t.ptr.resolver = ctrld.NewResolverWithNameserver(nss)
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss)
}
}
ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery")
if err := t.ptr.refresh(); err != nil {
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover")
@@ -240,6 +265,21 @@ func (t *Table) LookupHostname(ip, mac string) string {
return ""
}
// LookupRFC1918IPv4 returns the RFC1918 IPv4 address for the given MAC address, if any.
func (t *Table) LookupRFC1918IPv4(mac string) string {
t.initOnce.Do(t.init)
for _, r := range t.ipResolvers {
ip, err := netip.ParseAddr(r.LookupIP(mac))
if err != nil || ip.Is6() {
continue
}
if ip.IsPrivate() {
return ip.String()
}
}
return ""
}
type macEntry struct {
mac string
src string
@@ -338,39 +378,60 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) {
t.vni.ip2name.Store(ci.IP, ci.Hostname)
}
// ipFinder is the interface for retrieving IP address from hostname.
type ipFinder interface {
lookupIPByHostname(name string, v6 bool) string
}
// LookupIPByHostname returns the ip address of given hostname.
// If v6 is true, return IPv6 instead of default IPv4.
func (t *Table) LookupIPByHostname(hostname string, v6 bool) *netip.Addr {
if t == nil {
return nil
}
for _, finder := range []ipFinder{t.hf, t.ptr, t.mdns, t.dhcp} {
if addr := finder.lookupIPByHostname(hostname, v6); addr != "" {
if ip, err := netip.ParseAddr(addr); err == nil {
return &ip
}
}
}
return nil
}
func (t *Table) discoverDHCP() bool {
if t.cfg.Service.DiscoverDHCP == nil {
if t.svcCfg.DiscoverDHCP == nil {
return true
}
return *t.cfg.Service.DiscoverDHCP
return *t.svcCfg.DiscoverDHCP
}
func (t *Table) discoverARP() bool {
if t.cfg.Service.DiscoverARP == nil {
if t.svcCfg.DiscoverARP == nil {
return true
}
return *t.cfg.Service.DiscoverARP
return *t.svcCfg.DiscoverARP
}
func (t *Table) discoverMDNS() bool {
if t.cfg.Service.DiscoverMDNS == nil {
if t.svcCfg.DiscoverMDNS == nil {
return true
}
return *t.cfg.Service.DiscoverMDNS
return *t.svcCfg.DiscoverMDNS
}
func (t *Table) discoverPTR() bool {
if t.cfg.Service.DiscoverPtr == nil {
if t.svcCfg.DiscoverPtr == nil {
return true
}
return *t.cfg.Service.DiscoverPtr
return *t.svcCfg.DiscoverPtr
}
func (t *Table) discoverHosts() bool {
if t.cfg.Service.DiscoverHosts == nil {
if t.svcCfg.DiscoverHosts == nil {
return true
}
return *t.cfg.Service.DiscoverHosts
return *t.svcCfg.DiscoverHosts
}
// normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file.

View File

@@ -25,3 +25,22 @@ func Test_normalizeIP(t *testing.T) {
})
}
}
func TestTable_LookupRFC1918IPv4(t *testing.T) {
table := &Table{
dhcp: &dhcp{},
arp: &arpDiscover{},
}
table.ipResolvers = append(table.ipResolvers, table.dhcp)
table.ipResolvers = append(table.ipResolvers, table.arp)
macAddress := "cc:19:f9:8a:49:e6"
rfc1918IPv4 := "10.0.10.245"
table.dhcp.ip.Store(macAddress, "127.0.0.1")
table.arp.ip.Store(macAddress, rfc1918IPv4)
if got := table.LookupRFC1918IPv4(macAddress); got != rfc1918IPv4 {
t.Fatalf("unexpected result, want: %s, got: %s", rfc1918IPv4, got)
}
}

View File

@@ -8,6 +8,7 @@ import (
"net"
"net/netip"
"os"
"sort"
"strings"
"sync"
@@ -134,6 +135,39 @@ func (d *dhcp) List() []string {
return ips
}
func (d *dhcp) lookupIPByHostname(name string, v6 bool) string {
if d == nil {
return ""
}
var (
rfc1918Addrs []netip.Addr
others []netip.Addr
)
d.ip2name.Range(func(key, value any) bool {
if value != name {
return true
}
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
if addr.IsPrivate() {
rfc1918Addrs = append(rfc1918Addrs, addr)
} else {
others = append(others, addr)
}
}
return true
})
result := [][]netip.Addr{rfc1918Addrs, others}
for _, addrs := range result {
if len(addrs) > 0 {
sort.Slice(addrs, func(i, j int) bool {
return addrs[i].Less(addrs[j])
})
return addrs[0].String()
}
}
return ""
}
// AddLeaseFile adds given lease file for reading/watching clients info.
func (d *dhcp) addLeaseFile(name string, format ctrld.LeaseFileFormat) error {
if d.watcher == nil {

View File

@@ -86,3 +86,15 @@ lease 192.168.1.2 {
})
}
}
func Test_dhcp_lookupIPByHostname(t *testing.T) {
d := &dhcp{}
want := "192.168.1.123"
d.ip2name.Store(want, "foo")
d.ip2name.Store("127.0.0.1", "foo")
d.ip2name.Store("169.254.123.123", "foo")
if got := d.lookupIPByHostname("foo", false); got != want {
t.Fatalf("unexpected result, want: %s, got: %s", want, got)
}
}

View File

@@ -1,6 +1,7 @@
package clientinfo
import (
"net/netip"
"os"
"sync"
@@ -109,6 +110,24 @@ func (hf *hostsFile) String() string {
return "hosts"
}
func (hf *hostsFile) lookupIPByHostname(name string, v6 bool) string {
if hf == nil {
return ""
}
hf.mu.Lock()
defer hf.mu.Unlock()
for addr, names := range hf.m {
if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLoopback() {
for _, n := range names {
if n == name && ip.Is6() == v6 {
return ip.String()
}
}
}
}
return ""
}
// isLocalhostName reports whether the given hostname represents localhost.
func isLocalhostName(hostname string) bool {
switch hostname {

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"net/netip"
"os"
"sync"
"syscall"
@@ -59,6 +60,27 @@ func (m *mdns) List() []string {
return ips
}
func (m *mdns) lookupIPByHostname(name string, v6 bool) string {
if m == nil {
return ""
}
var ip string
m.name.Range(func(key, value any) bool {
if value == name {
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
ip = addr.String()
//lint:ignore S1008 This is used for readable.
if addr.IsLoopback() { // Continue searching if this is loopback address.
return true
}
return false
}
}
return true
})
return ip
}
func (m *mdns) init(quitCh chan struct{}) error {
ifaces, err := multicastInterfaces()
if err != nil {
@@ -123,6 +145,10 @@ func (m *mdns) readLoop(conn *net.UDPConn) {
if err, ok := err.(*net.OpError); ok && (err.Timeout() || err.Temporary()) {
continue
}
// Do not complain about use of closed network connection.
if errors.Is(err, net.ErrClosed) {
return
}
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error")
return
}

View File

@@ -2,6 +2,7 @@ package clientinfo
import (
"context"
"net/netip"
"sync"
"sync/atomic"
"time"
@@ -72,15 +73,16 @@ func (p *ptrDiscover) lookupHostname(ip string) string {
msg := new(dns.Msg)
addr, err := dns.ReverseAddr(ip)
if err != nil {
ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
return ""
}
msg.SetQuestion(addr, dns.TypePTR)
ans, err := p.resolver.Resolve(ctx, msg)
if err != nil {
ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
p.serverDown.Store(true)
go p.checkServer()
if p.serverDown.CompareAndSwap(false, true) {
ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
go p.checkServer()
}
return ""
}
for _, rr := range ans.Answer {
@@ -93,6 +95,27 @@ func (p *ptrDiscover) lookupHostname(ip string) string {
return ""
}
func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string {
if p == nil {
return ""
}
var ip string
p.hostname.Range(func(key, value any) bool {
if value == name {
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
ip = addr.String()
//lint:ignore S1008 This is used for readable.
if addr.IsLoopback() { // Continue searching if this is loopback address.
return true
}
return false
}
}
return true
})
return ip
}
// checkServer monitors if the resolver can reach its nameserver. When the nameserver
// is reachable, set p.serverDown to false, so p.lookupHostname can continue working.
func (p *ptrDiscover) checkServer() {

View File

@@ -10,6 +10,7 @@ import (
"net"
"net/http"
"os"
"runtime"
"strings"
"time"
@@ -119,7 +120,7 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig
return d.DialContext(ctx, network, addrs)
}
if router.Name() == ddwrt.Name {
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
}
client := http.Client{

View File

@@ -1,9 +1,12 @@
package dnsmasq
import (
"bytes"
"errors"
"fmt"
"html/template"
"net"
"os"
"path/filepath"
"strings"
@@ -19,6 +22,9 @@ server={{ .IP }}#{{ .Port }}
add-mac
add-subnet=32,128
{{- end}}
{{- if .CacheDisabled}}
cache-size=0
{{- end}}
`
const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf"
@@ -47,6 +53,8 @@ if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then
{{- end}}
pc_delete "dnssec" "$config_file" # disable DNSSEC
pc_delete "trust-anchor=" "$config_file" # disable DNSSEC
pc_delete "cache-size=" "$config_file"
pc_append "cache-size=0" "$config_file" # disable cache
# For John fork
pc_delete "resolv-file" "$config_file" # no WAN DNS settings
@@ -65,6 +73,10 @@ type Upstream struct {
}
func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
return ConfTmplWitchCacheDisabled(tmplText, cfg, true)
}
func ConfTmplWitchCacheDisabled(tmplText string, cfg *ctrld.Config, cacheDisabled bool) (string, error) {
listener := cfg.FirstListener()
if listener == nil {
return "", errors.New("missing listener")
@@ -74,24 +86,26 @@ func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
ip = "127.0.0.1"
}
upstreams := []Upstream{{IP: ip, Port: listener.Port}}
return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo())
return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo(), cacheDisabled)
}
func FirewallaConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
if lc := cfg.FirstListener(); lc != nil && (lc.IP == "0.0.0.0" || lc.IP == "") {
return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo())
return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo(), true)
}
return ConfTmpl(tmplText, cfg)
}
func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo bool) (string, error) {
func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo, cacheDisabled bool) (string, error) {
tmpl := template.Must(template.New("").Parse(tmplText))
var to = &struct {
SendClientInfo bool
Upstreams []Upstream
CacheDisabled bool
}{
SendClientInfo: sendClientInfo,
Upstreams: upstreams,
CacheDisabled: cacheDisabled,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, to); err != nil {
@@ -117,9 +131,28 @@ func firewallaUpstreams(port int) []Upstream {
return upstreams
}
// firewallaDnsmasqConfFiles returns dnsmasq config files of all firewalla interfaces.
func firewallaDnsmasqConfFiles() ([]string, error) {
return filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf")
}
// firewallUpdateConf updates all firewall config files using given function.
func firewallUpdateConf(update func(conf string) error) error {
confFiles, err := firewallaDnsmasqConfFiles()
if err != nil {
return err
}
for _, conf := range confFiles {
if err := update(conf); err != nil {
return fmt.Errorf("%s: %w", conf, err)
}
}
return nil
}
// FirewallaSelfInterfaces returns list of interfaces that will be configured with default dnsmasq setup on Firewalla.
func FirewallaSelfInterfaces() []*net.Interface {
matches, err := filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf")
matches, err := firewallaDnsmasqConfFiles()
if err != nil {
return nil
}
@@ -133,3 +166,32 @@ func FirewallaSelfInterfaces() []*net.Interface {
}
return ifaces
}
// FirewallaDisableCache comments out "cache-size" line in all firewalla dnsmasq config files.
func FirewallaDisableCache() error {
return firewallUpdateConf(DisableCache)
}
// FirewallaEnableCache un-comments out "cache-size" line in all firewalla dnsmasq config files.
func FirewallaEnableCache() error {
return firewallUpdateConf(EnableCache)
}
// DisableCache comments out "cache-size" line in dnsmasq config file.
func DisableCache(conf string) error {
return replaceFileContent(conf, "\ncache-size=", "\n#cache-size=")
}
// EnableCache un-comments "cache-size" line in dnsmasq config file.
func EnableCache(conf string) error {
return replaceFileContent(conf, "\n#cache-size=", "\ncache-size=")
}
func replaceFileContent(filename, old, new string) error {
content, err := os.ReadFile(filename)
if err != nil {
return err
}
content = bytes.ReplaceAll(content, []byte(old), []byte(new))
return os.WriteFile(filename, content, 0644)
}

View File

@@ -8,10 +8,10 @@ import (
"os/exec"
"strings"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
)
const (
@@ -95,7 +95,7 @@ func (e *EdgeOS) setupUSG() error {
return fmt.Errorf("setupUSG: backup current config: %w", err)
}
// Removing all configured upstreams.
// Removing all configured upstreams and cache config.
var sb strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(buf))
for scanner.Scan() {
@@ -109,7 +109,7 @@ func (e *EdgeOS) setupUSG() error {
sb.WriteString(line)
}
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg)
data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false)
if err != nil {
return err
}
@@ -127,7 +127,7 @@ func (e *EdgeOS) setupUSG() error {
}
func (e *EdgeOS) setupUDM() error {
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg)
data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false)
if err != nil {
return err
}

View File

@@ -65,6 +65,11 @@ func (f *Firewalla) Setup() error {
return fmt.Errorf("writing ctrld config: %w", err)
}
// Disable dnsmasq cache.
if err := dnsmasq.FirewallaDisableCache(); err != nil {
return err
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
return fmt.Errorf("restartDNSMasq: %w", err)
@@ -82,6 +87,11 @@ func (f *Firewalla) Cleanup() error {
return fmt.Errorf("removing ctrld config: %w", err)
}
// Enable dnsmasq cache.
if err := dnsmasq.FirewallaEnableCache(); err != nil {
return err
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
return fmt.Errorf("restartDNSMasq: %w", err)

View File

@@ -6,6 +6,7 @@ import (
"os"
"os/exec"
"strings"
"time"
"unicode"
"github.com/kardianos/service"
@@ -44,8 +45,24 @@ func (m *Merlin) Uninstall(_ *service.Config) error {
}
func (m *Merlin) PreRun() error {
// Wait NTP ready.
_ = m.Cleanup()
return ntp.WaitNvram()
if err := ntp.WaitNvram(); err != nil {
return err
}
// Wait until directories mounted.
for _, dir := range []string{"/tmp", "/proc"} {
waitDirExists(dir)
}
// Wait dnsmasq started.
for {
out, _ := exec.Command("pidof", "dnsmasq").CombinedOutput()
if len(bytes.TrimSpace(out)) > 0 {
break
}
time.Sleep(time.Second)
}
return nil
}
func (m *Merlin) Setup() error {
@@ -56,9 +73,6 @@ func (m *Merlin) Setup() error {
if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" {
return nil
}
if _, err := nvram.Run("set", nvram.CtrldSetupKey+"=1"); err != nil {
return err
}
buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath)
// Already setup.
if bytes.Contains(buf, []byte(dnsmasq.MerlinPostConfMarker)) {
@@ -140,3 +154,12 @@ func merlinParsePostConf(buf []byte) []byte {
}
return buf
}
func waitDirExists(dir string) {
for {
if _, err := os.Stat(dir); !os.IsNotExist(err) {
return
}
time.Sleep(time.Second)
}
}

View File

@@ -8,11 +8,10 @@ import (
"os/exec"
"strings"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
)
const (
@@ -20,10 +19,9 @@ const (
openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
)
var errUCIEntryNotFound = errors.New("uci: Entry not found")
type Openwrt struct {
cfg *ctrld.Config
cfg *ctrld.Config
dnsmasqCacheSize string
}
// New returns a router.Router for configuring/setup/run ctrld on Openwrt routers.
@@ -52,6 +50,19 @@ func (o *Openwrt) Setup() error {
if o.cfg.FirstListener().IsDirectDnsListener() {
return nil
}
// Save current dnsmasq config cache size if present.
if cs, err := uci("get", "dhcp.@dnsmasq[0].cachesize"); err == nil {
o.dnsmasqCacheSize = cs
if _, err := uci("delete", "dhcp.@dnsmasq[0].cachesize"); err != nil {
return err
}
// Commit.
if _, err := uci("commit", "dhcp"); err != nil {
return err
}
}
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, o.cfg)
if err != nil {
return err
@@ -59,10 +70,6 @@ func (o *Openwrt) Setup() error {
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil {
return err
}
// Commit.
if _, err := uci("commit"); err != nil {
return err
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
return err
@@ -78,6 +85,18 @@ func (o *Openwrt) Cleanup() error {
if err := os.Remove(openwrtDNSMasqConfigPath); err != nil {
return err
}
// Restore original value if present.
if o.dnsmasqCacheSize != "" {
if _, err := uci("set", fmt.Sprintf("dhcp.@dnsmasq[0].cachesize=%s", o.dnsmasqCacheSize)); err != nil {
return err
}
// Commit.
if _, err := uci("commit", "dhcp"); err != nil {
return err
}
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
return err
@@ -92,6 +111,8 @@ func restartDNSMasq() error {
return nil
}
var errUCIEntryNotFound = errors.New("uci: Entry not found")
func uci(args ...string) (string, error) {
cmd := exec.Command("uci", args...)
var stdout, stderr bytes.Buffer

View File

@@ -5,16 +5,17 @@ import (
"os"
"strconv"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
"github.com/Control-D-Inc/ctrld/internal/router/edgeos"
"github.com/kardianos/service"
)
const (
Name = "ubios"
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
Name = "ubios"
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf"
)
type Ubios struct {
@@ -57,6 +58,10 @@ func (u *Ubios) Setup() error {
if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil {
return err
}
// Disable dnsmasq cache.
if err := dnsmasq.DisableCache(ubiosDNSMasqDnsConfigPath); err != nil {
return err
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
return err
@@ -72,6 +77,10 @@ func (u *Ubios) Cleanup() error {
if err := os.Remove(ubiosDNSMasqConfigPath); err != nil {
return err
}
// Enable dnsmasq cache.
if err := dnsmasq.EnableCache(ubiosDNSMasqDnsConfigPath); err != nil {
return err
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
return err

35
net.go
View File

@@ -2,13 +2,10 @@ package ctrld
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
"tailscale.com/logtail/backoff"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
@@ -17,30 +14,36 @@ var (
ipv6Available atomic.Bool
)
const ipv6ProbingInterval = 10 * time.Second
func hasIPv6() bool {
hasIPv6Once.Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
val := ctrldnet.IPv6Available(ctx)
ipv6Available.Store(val)
go probingIPv6(val)
go probingIPv6(context.TODO(), val)
})
return ipv6Available.Load()
}
// TODO(cuonglm): doing poll check natively for supported platforms.
func probingIPv6(old bool) {
b := backoff.NewBackoff("probingIPv6", func(format string, args ...any) {}, 30*time.Second)
bCtx := context.Background()
func probingIPv6(ctx context.Context, old bool) {
ticker := time.NewTicker(ipv6ProbingInterval)
defer ticker.Stop()
for {
func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
cur := ctrldnet.IPv6Available(ctx)
if ipv6Available.CompareAndSwap(old, cur) {
old = cur
}
}()
b.BackOff(bCtx, errors.New("no change"))
select {
case <-ctx.Done():
return
case <-ticker.C:
func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
cur := ctrldnet.IPv6Available(ctx)
if ipv6Available.CompareAndSwap(old, cur) {
old = cur
}
}()
}
}
}

View File

@@ -5,10 +5,12 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/miekg/dns"
"tailscale.com/net/interfaces"
)
const (
@@ -24,6 +26,8 @@ const (
ResolverTypeOS = "os"
// ResolverTypeLegacy specifies legacy resolver.
ResolverTypeLegacy = "legacy"
// ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only.
ResolverTypePrivate = "private"
)
var bootstrapDNS = "76.76.2.0"
@@ -61,6 +65,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
return or, nil
case ResolverTypeLegacy:
return &legacyResolver{uc: uc}, nil
case ResolverTypePrivate:
return NewPrivateResolver(), nil
}
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
}
@@ -74,8 +80,9 @@ type osResolverResult struct {
err error
}
// Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from
// available nameservers with a roundrobin algorithm.
// Resolve resolves DNS queries using pre-configured nameservers.
// Query is sent to all nameservers concurrently, and the first
// success response will be returned.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
numServers := len(o.nameservers)
if numServers == 0 {
@@ -240,12 +247,16 @@ func NewBootstrapResolver(servers ...string) Resolver {
}
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers,
// excluding nameservers from /etc/resolv.conf file.
// excluding:
//
// - Nameservers from /etc/resolv.conf file.
// - Nameservers which is local RFC1918 addresses.
//
// This is useful for doing PTR lookup in LAN network.
func NewPrivateResolver() Resolver {
nss := nameservers()
resolveConfNss := nameserversFromResolvconf()
localRfc1918Addrs := Rfc1918Addresses()
n := 0
for _, ns := range nss {
host, _, _ := net.SplitHostPort(ns)
@@ -258,6 +269,10 @@ func NewPrivateResolver() Resolver {
if sliceContains(resolveConfNss, host) {
continue
}
// Ignoring local RFC 1918 addresses.
if sliceContains(localRfc1918Addrs, host) {
continue
}
ip := net.ParseIP(host)
if ip != nil && ip.IsPrivate() && !ip.IsLoopback() {
nss[n] = ns
@@ -265,11 +280,35 @@ func NewPrivateResolver() Resolver {
}
}
nss = nss[:n]
if len(nss) == 0 {
return NewResolverWithNameserver(nss)
}
// NewResolverWithNameserver returns an OS resolver which uses the given nameservers
// for resolving DNS queries. If nameservers is empty, a dummy resolver will be returned.
//
// Each nameserver must be form "host:port". It's the caller responsibility to ensure all
// nameservers are well formatted by using net.JoinHostPort function.
func NewResolverWithNameserver(nameservers []string) Resolver {
if len(nameservers) == 0 {
return &dummyResolver{}
}
resolver := &osResolver{nameservers: nss}
return resolver
return &osResolver{nameservers: nameservers}
}
// Rfc1918Addresses returns the list of local interfaces private IP addresses
func Rfc1918Addresses() []string {
var res []string
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
addrs, _ := i.Addrs()
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok || !ipNet.IP.IsPrivate() {
continue
}
res = append(res, ipNet.IP.String())
}
})
return res
}
func newDialer(dnsAddress string) *net.Dialer {

View File

@@ -82,4 +82,8 @@ rules = [
{"*.ru" = ["upstream.1"]},
{"*.local.host" = ["upstream.2", "upstream.0"]},
]
macs = [
{"14:45:A0:67:83:0A" = ["upstream.2"]},
{"14:54:4a:8e:08:2d" = ["upstream.2"]},
]
`