From 2440d922c64c2de537468c6520c5877eaea8f960 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 12 Oct 2023 21:48:16 +0700 Subject: [PATCH 01/52] all: add MAC address base policy While at it, also update the config doc to clarify the order of matching preference, and the matter of rules order within each policy. --- cmd/cli/dns_proxy.go | 17 +++++++++++++++-- cmd/cli/dns_proxy_test.go | 16 ++++++++++------ config.go | 1 + docs/config.md | 37 +++++++++++++++++++++++++++++++++++-- testhelper/config.go | 4 ++++ 5 files changed, 65 insertions(+), 10 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 12cf781..c5271a9 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -62,7 +62,7 @@ func (p *prog) serveDNS(listenerNum string) error { t := time.Now() ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) ctrld.Log(ctx, mainLog.Load().Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) - upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, domain) + upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) var answer *dns.Msg if !matched && listenerConfig.Restricted { answer = new(dns.Msg) @@ -146,7 +146,7 @@ func (p *prog) serveDNS(listenerNum string) error { // Though domain policy has higher priority than network policy, it is still // processed later, because policy logging want to know whether a network rule // is disregarded in favor of the domain level rule. 
-func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) { +func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) ([]string, bool) { upstreams := []string{upstreamPrefix + defaultUpstreamNum} matchedPolicy := "no policy" matchedNetwork := "no network" @@ -202,6 +202,19 @@ networkRules: } } +macRules: + for _, rule := range lc.Policy.Macs { + for source, targets := range rule { + if source != "" && strings.EqualFold(source, srcMac) { + matchedPolicy = lc.Policy.Name + matchedNetwork = source + networkTargets = targets + matched = true + break macRules + } + } + } + for _, rule := range lc.Policy.Rules { // There's only one entry per rule, config validation ensures this. for source, targets := range rule { diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 674d486..c3b6c96 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -81,6 +81,7 @@ func Test_prog_upstreamFor(t *testing.T) { tests := []struct { name string ip string + mac string defaultUpstreamNum string lc *ctrld.ListenerConfig domain string @@ -88,11 +89,14 @@ func Test_prog_upstreamFor(t *testing.T) { matched bool testLogMsg string }{ - {"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""}, - {"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""}, - {"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""}, - {"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""}, - {"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> 
[upstream.1]"}, + {"Policy map matches", "192.168.0.1:0", "", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""}, + {"Policy split matches", "192.168.0.1:0", "", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""}, + {"Policy map for other network matches", "192.168.1.2:0", "", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""}, + {"No policy map for listener", "192.168.1.2:0", "", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""}, + {"unenforced loging", "192.168.1.2:0", "", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"}, + {"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"}, + {"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"}, + {"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"}, } for _, tc := range tests { @@ -111,7 +115,7 @@ func Test_prog_upstreamFor(t *testing.T) { require.NoError(t, err) require.NotNil(t, addr) ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID()) - upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain) + upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain) assert.Equal(t, tc.matched, matched) assert.Equal(t, tc.upstreams, upstreams) if tc.testLogMsg != "" { diff --git a/config.go b/config.go index 21d636c..c9c1acc 100644 --- a/config.go +++ b/config.go @@ -253,6 +253,7 @@ type ListenerPolicyConfig struct { Name string `mapstructure:"name" toml:"name,omitempty"` Networks []Rule `mapstructure:"networks" 
toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"` Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"` + Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"` FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"` FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"` } diff --git a/docs/config.md b/docs/config.md index 35fbda5..29dff8d 100644 --- a/docs/config.md +++ b/docs/config.md @@ -386,7 +386,15 @@ If set to `true` makes the listener `REFUSE` DNS queries from all source IP addr Allows `ctrld` to set policy rules to determine which upstreams the requests will be forwarded to. If no `policy` is defined or the requests do not match any policy rules, it will be forwarded to corresponding upstream of the listener. For example, the request to `listener.0` will be forwarded to `upstream.0`. -The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either the `network` or a domain. Value is the list of the upstreams. For example: +The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either: + + - Network. + - Domain. + - Mac Address. + +Value is the list of the upstreams. + +For example: ```toml [listener.0.policy] @@ -400,12 +408,18 @@ rules = [ {"*.local" = ["upstream.1"]}, {"test.com" = ["upstream.2", "upstream.1"]}, ] + +macs = [ + {"14:54:4a:8e:08:2d" = ["upstream.3"]}, +] ``` Above policy will: -- Forward requests on `listener.0` from `network.0` to `upstream.1`. + - Forward requests on `listener.0` for `.local` suffixed domains to `upstream.1`. - Forward requests on `listener.0` for `test.com` to `upstream.2`. If timeout is reached, retry on `upstream.1`. +- Forward requests on `listener.0` from client with Mac `14:54:4a:8e:08:2d` to `upstream.3`. 
+- Forward requests on `listener.0` from `network.0` to `upstream.1`. - All other requests on `listener.0` that do not match above conditions will be forwarded to `upstream.0`. An empty upstream would not route the request to any defined upstreams, and use the OS default resolver. @@ -419,6 +433,18 @@ rules = [ ] ``` +--- + +Note that the order of matching preference is: + +``` +rules => macs => networks +``` + +And within each policy, the rules are processed from top to bottom. + +--- + #### name `name` is the name for the policy. @@ -440,6 +466,13 @@ rules = [ - Required: no - Default: [] +### macs +`macs` is the list of MAC address rules within the policy. MAC address values are case-insensitive. + +- Type: array of macs +- Required: no +- Default: [] + ### failover_rcodes For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`. diff --git a/testhelper/config.go b/testhelper/config.go index 5c2e5f4..6199424 100644 --- a/testhelper/config.go +++ b/testhelper/config.go @@ -82,4 +82,8 @@ rules = [ {"*.ru" = ["upstream.1"]}, {"*.local.host" = ["upstream.2", "upstream.0"]}, ] +macs = [ + {"14:45:A0:67:83:0A" = ["upstream.2"]}, + {"14:54:4a:8e:08:2d" = ["upstream.2"]}, +] ` From df4e04719e9f1352f9f79d8a10ca3ca6c4d89a4d Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 13 Oct 2023 20:36:53 +0700 Subject: [PATCH 02/52] cmd/cli: relax service dependency on systemd-networkd-wait-online ctrld wants systemd-networkd-wait-online to start before it starts itself, but ctrld should not be blocked waiting for it to start.
--- cmd/cli/prog_linux.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/cli/prog_linux.go b/cmd/cli/prog_linux.go index ed28561..2b9c69d 100644 --- a/cmd/cli/prog_linux.go +++ b/cmd/cli/prog_linux.go @@ -19,7 +19,6 @@ func setDependencies(svc *service.Config) { "Wants=NetworkManager-wait-online.service", "After=NetworkManager-wait-online.service", "Wants=systemd-networkd-wait-online.service", - "After=systemd-networkd-wait-online.service", "Wants=nss-lookup.target", "After=nss-lookup.target", } From ebd516855bedf0b0c9cec2328d78b216df6ebdf6 Mon Sep 17 00:00:00 2001 From: Ginder Singh Date: Sat, 14 Oct 2023 10:57:14 -0400 Subject: [PATCH 03/52] added safe return if error happens during resolver fetch. --- cmd/cli/cli.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index b67504c..bf7b9a6 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1062,6 +1062,9 @@ func processCDFlags() error { return uer } if err != nil { + if isMobile() { + return errors.New("could not fetch resolver config") + } logger.Warn().Err(err).Msg("could not fetch resolver config") return nil } From 6aafe445f54b7fd3c36b338b122157ca41af521e Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 13 Oct 2023 23:39:37 +0700 Subject: [PATCH 04/52] cmd/cli: add nextdns mode Adding --nextdns flag to "ctrld start" command for generating ctrld config with nextdns resolver id, then use nextdns as an upstream. 
--- cmd/cli/cli.go | 51 ++++++++++++++++++++++++++++++++++++++-- cmd/cli/main.go | 6 +++-- cmd/cli/nextdns.go | 31 +++++++++++++++++++++++++ config.go | 12 +++++++++- doh.go | 58 ++++++++++++++++++++++++++++++++-------------- 5 files changed, 136 insertions(+), 22 deletions(-) create mode 100644 cmd/cli/nextdns.go diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index bf7b9a6..8d733b3 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -158,6 +158,7 @@ func initCLI() { Run: func(cmd *cobra.Command, args []string) { checkStrFlagEmpty(cmd, cdUidFlagName) checkStrFlagEmpty(cmd, cdOrgFlagName) + validateCdAndNextDNSFlags() sc := &service.Config{} *sc = *svcConfig osArgs := os.Args[2:] @@ -231,6 +232,15 @@ func initCLI() { initLogging() + if nextdns != "" { + removeNextDNSFromArgs(sc) + generateNextDNSConfig() + updateListenerConfig() + if err := writeConfigFile(); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to write config with NextDNS resolver") + } + } + // Explicitly passing config, so on system where home directory could not be obtained, // or sub-process env is different with the parent, we still behave correctly and use // the expected config file. @@ -281,7 +291,7 @@ func initCLI() { } }, } - // Keep these flags in sync with runCmd above, except for "-d". + // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". 
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") @@ -295,6 +305,7 @@ func initCLI() { startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") routerCmd := &cobra.Command{ Use: "setup", @@ -1216,6 +1227,11 @@ func selfCheckStatus(s service.Service) service.Status { return service.StatusUnknown } + // Not a ctrld upstream, return status as-is. + if cfg.FirstUpstream().VerifyDomain() == "" { + return status + } + mainLog.Load().Debug().Msg("ctrld listener is ready") mainLog.Load().Debug().Msg("performing self-check") bo := backoff.NewBackoff("self-check", logf, 10*time.Second) @@ -1489,6 +1505,7 @@ func mobileListenerPort() int { func updateListenerConfig() (updated bool) { lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" + nextdnsMode := nextdns != "" for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { @@ -1500,7 +1517,8 @@ func updateListenerConfig() (updated bool) { lcc[n].Port = true } // In cd mode, we always try to pick an ip:port pair to work. - if cdMode { + // Same if nextdns resolver is used. 
+ if cdMode || nextdnsMode { lcc[n].IP = true lcc[n].Port = true } @@ -1801,3 +1819,32 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name) } } + +func validateCdAndNextDNSFlags() { + if (cdUID != "" || cdOrg != "") && nextdns != "" { + mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName) + } +} + +// removeNextDNSFromArgs removes the --nextdns from command line arguments. +func removeNextDNSFromArgs(sc *service.Config) { + a := sc.Arguments[:0] + skip := false + for _, x := range sc.Arguments { + if skip { + skip = false + continue + } + // For "--nextdns XXX", skip it and mark next arg skipped. + if x == "--"+nextdnsFlagName { + skip = true + continue + } + // For "--nextdns=XXX", just skip it. + if strings.HasPrefix(x, "--"+nextdnsFlagName+"=") { + continue + } + a = append(a, x) + } + sc.Arguments = a +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index f4439a5..f79b15f 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -32,6 +32,7 @@ var ( cdDev bool iface string ifaceStartStop string + nextdns string mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter @@ -39,8 +40,9 @@ var ( ) const ( - cdUidFlagName = "cd" - cdOrgFlagName = "cd-org" + cdUidFlagName = "cd" + cdOrgFlagName = "cd-org" + nextdnsFlagName = "nextdns" ) func init() { diff --git a/cmd/cli/nextdns.go b/cmd/cli/nextdns.go new file mode 100644 index 0000000..8aebfc8 --- /dev/null +++ b/cmd/cli/nextdns.go @@ -0,0 +1,31 @@ +package cli + +import ( + "fmt" + + "github.com/Control-D-Inc/ctrld" +) + +const nextdnsURL = "https://dns.nextdns.io" + +func generateNextDNSConfig() { + if nextdns == "" { + return + } + mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver") + cfg = ctrld.Config{ + Listener: map[string]*ctrld.ListenerConfig{ + "0": { + IP: "0.0.0.0", + Port: 53, + }, + }, + Upstream: 
map[string]*ctrld.UpstreamConfig{ + "0": { + Type: ctrld.ResolverTypeDOH3, + Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, nextdns), + Timeout: 5000, + }, + }, + } +} diff --git a/config.go b/config.go index c9c1acc..489a7fd 100644 --- a/config.go +++ b/config.go @@ -323,7 +323,7 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool { } switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: - if uc.isControlD() { + if uc.isControlD() || uc.isNextDNS() { return true } } @@ -520,6 +520,16 @@ func (uc *UpstreamConfig) isControlD() bool { return false } +func (uc *UpstreamConfig) isNextDNS() bool { + domain := uc.Domain + if domain == "" { + if u, err := url.Parse(uc.Endpoint); err == nil { + domain = u.Hostname() + } + } + return domain == "dns.nextdns.io" +} + func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { uc.SetupTransport() diff --git a/doh.go b/doh.go index e0aa363..96f8051 100644 --- a/doh.go +++ b/doh.go @@ -76,7 +76,6 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver { endpoint: uc.u, isDoH3: uc.Type == ResolverTypeDOH3, http3RoundTripper: uc.http3RoundTripper, - sendClientInfo: uc.UpstreamSendClientInfo(), uc: uc, } return r @@ -87,9 +86,9 @@ type dohResolver struct { endpoint *url.URL isDoH3 bool http3RoundTripper http.RoundTripper - sendClientInfo bool } +// Resolve performs DNS query with given DNS message using DOH protocol. 
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { data, err := msg.Pack() if err != nil { @@ -106,7 +105,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if err != nil { return nil, fmt.Errorf("could not create request: %w", err) } - addHeader(ctx, req, r.sendClientInfo) + addHeader(ctx, req, r.uc) dnsTyp := uint16(0) if len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype @@ -146,26 +145,19 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return answer, nil } -func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { +func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { req.Header.Set("Content-Type", headerApplicationDNS) req.Header.Set("Accept", headerApplicationDNS) - req.Header.Set(dohOsHeader, dohOsHeaderValue()) printed := false - if sendClientInfo { + if uc.UpstreamSendClientInfo() { if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil { printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != "" - if ci.Mac != "" { - req.Header.Set(dohMacHeader, ci.Mac) - } - if ci.IP != "" { - req.Header.Set(dohIPHeader, ci.IP) - } - if ci.Hostname != "" { - req.Header.Set(dohHostHeader, ci.Hostname) - } - if ci.Self { - req.Header.Set(dohOsHeader, dohOsHeaderValue()) + switch { + case uc.isControlD(): + addControlDHeaders(req, ci) + case uc.isNextDNS(): + addNextDNSHeaders(req, ci) } } } @@ -173,3 +165,35 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header") } } + +// addControlDHeaders set DoH/Doh3 HTTP request headers for ControlD upstream. 
+func addControlDHeaders(req *http.Request, ci *ClientInfo) { + req.Header.Set(dohOsHeader, dohOsHeaderValue()) + if ci.Mac != "" { + req.Header.Set(dohMacHeader, ci.Mac) + } + if ci.IP != "" { + req.Header.Set(dohIPHeader, ci.IP) + } + if ci.Hostname != "" { + req.Header.Set(dohHostHeader, ci.Hostname) + } + if ci.Self { + req.Header.Set(dohOsHeader, dohOsHeaderValue()) + } +} + +// addNextDNSHeaders set DoH/Doh3 HTTP request headers for nextdns upstream. +// https://github.com/nextdns/nextdns/blob/v1.41.0/resolver/doh.go#L100 +func addNextDNSHeaders(req *http.Request, ci *ClientInfo) { + if ci.Mac != "" { + // https: //github.com/nextdns/nextdns/blob/v1.41.0/run.go#L543 + req.Header.Set("X-Device-Model", "mac:"+ci.Mac[:8]) + } + if ci.IP != "" { + req.Header.Set("X-Device-Ip", ci.IP) + } + if ci.Hostname != "" { + req.Header.Set("X-Device-Name", ci.Hostname) + } +} From 904b23eeac7408d402c54e643c28d61ace1043f4 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 17 Oct 2023 00:56:52 +0700 Subject: [PATCH 05/52] cmd/cli: add --proto flag to set upstream type in cd mode --- cmd/cli/cli.go | 21 +++++++++++++++++++-- cmd/cli/main.go | 1 + 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 8d733b3..d0d64c9 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -144,6 +144,7 @@ func initCLI() { _ = runCmd.Flags().MarkHidden("homedir") runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) _ = runCmd.Flags().MarkHidden("iface") + runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) rootCmd.AddCommand(runCmd) @@ -177,6 +178,9 @@ func initCLI() { // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. 
sc.Arguments = append(sc.Arguments, "--cd="+cdUID) } + if cdUID != "" { + validateCdUpstreamProtocol() + } p := &prog{ router: router.New(&cfg, cdUID != ""), @@ -306,6 +310,7 @@ func initCLI() { _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") + startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) routerCmd := &cobra.Command{ Use: "setup", @@ -788,6 +793,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cdUID = uid } if cdUID != "" { + validateCdUpstreamProtocol() err := processCDFlags() if err != nil { appCallback.Exit(err.Error()) @@ -1099,7 +1105,7 @@ func processCDFlags() error { cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) cfg.Upstream["0"] = &ctrld.UpstreamConfig{ Endpoint: resolverConfig.DOH, - Type: ctrld.ResolverTypeDOH, + Type: cdUpstreamProto, Timeout: 5000, } rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) @@ -1816,7 +1822,18 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { return } if fl.Value.String() == "" { - mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name) + mainLog.Load().Fatal().Msgf(`flag "--%s" value must be non-empty`, fl.Name) + } +} + +func validateCdUpstreamProtocol() { + if cdUID == "" { + return + } + switch cdUpstreamProto { + case ctrld.ResolverTypeDOH, ctrld.ResolverTypeDOH3: + default: + mainLog.Load().Fatal().Msg(`flag "--protocol" must be "doh" or "doh3"`) } } diff --git a/cmd/cli/main.go b/cmd/cli/main.go index f79b15f..bf65044 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -33,6 +33,7 @@ var ( iface string ifaceStartStop string nextdns string + cdUpstreamProto string mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter From 
baf836557c8005ab1e963e22b37c48165da13c1c Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 19 Oct 2023 00:37:27 +0700 Subject: [PATCH 06/52] cmd/cli: fix wrong checking condition in removeProvTokenFromArgs The provision token is only used once and has no effect after the Control D uid is fetched, so making it appear in the "ctrld run" command is useless. --- cmd/cli/cli.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index d0d64c9..7a3da52 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1769,12 +1769,12 @@ func removeProvTokenFromArgs(sc *service.Config) { continue } // For "--cd-org XXX", skip it and mark next arg skipped. - if x == cdOrgFlagName { + if x == "--"+cdOrgFlagName { skip = true continue } // For "--cd-org=XXX", just skip it. - if strings.HasPrefix(x, cdOrgFlagName+"=") { + if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") { continue } a = append(a, x) From 712b23a4bb5a141da4fcb3c37effb36d17c9c190 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 19 Oct 2023 23:03:11 +0700 Subject: [PATCH 07/52] cmd/cli: initialize upstream proto for mobile --- cmd/cli/cli.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 7a3da52..f2b9906 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -704,6 +704,7 @@ func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struc homedir = appConfig.HomeDir verbose = appConfig.Verbose cdUID = appConfig.CdUID + cdUpstreamProto = ctrld.ResolverTypeDOH logPath = appConfig.LogPath run(appCallback, stopCh) } From 58a00ea24a55b2c9f10588fdbd664c16bb4b4ccc Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 24 Oct 2023 00:22:03 +0700 Subject: [PATCH 08/52] all: implement reload command This commit adds a reload command to ctrld for re-fetching new config from the Control D API or re-reading the current config on disk.
--- cmd/cli/cli.go | 208 +++++++++++++++++++++-------- cmd/cli/control_server.go | 36 +++++ cmd/cli/dns_proxy.go | 29 +++- cmd/cli/loop.go | 4 +- cmd/cli/netlink_linux.go | 25 ++-- cmd/cli/netlink_others.go | 4 +- cmd/cli/prog.go | 166 +++++++++++++++++++---- cmd/cli/reload_others.go | 17 +++ cmd/cli/reload_windows.go | 18 +++ internal/clientinfo/client_info.go | 29 ++-- 10 files changed, 417 insertions(+), 119 deletions(-) create mode 100644 cmd/cli/reload_others.go create mode 100644 cmd/cli/reload_windows.go diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index f2b9906..bf6803f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -239,7 +239,7 @@ func initCLI() { if nextdns != "" { removeNextDNSFromArgs(sc) generateNextDNSConfig() - updateListenerConfig() + updateListenerConfig(&cfg) if err := writeConfigFile(); err != nil { mainLog.Load().Error().Err(err).Msg("failed to write config with NextDNS resolver") } @@ -383,6 +383,10 @@ func initCLI() { mainLog.Load().Error().Msg(err.Error()) return } + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } initLogging() tasks := []task{ @@ -404,6 +408,50 @@ func initCLI() { }, } + reloadCmd := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Args: cobra.NoArgs, + Run: func(cmd *cobra.Command, args []string) { + dir, err := userHomeDir() + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir") + } + cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock)) + resp, err := cc.post(reloadPath, nil) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to send reload signal to ctrld") + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + mainLog.Load().Notice().Msg("Service reloaded") + case http.StatusCreated: + s, err := newService(&prog{}, 
svcConfig) + if err != nil { + mainLog.Load().Error().Msg(err.Error()) + return + } + mainLog.Load().Warn().Msg("Service was reloaded, but new config requires service restart.") + mainLog.Load().Warn().Msg("Restarting service") + if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("Service not installed") + return + } + restartCmd.Run(cmd, args) + default: + buf, err := io.ReadAll(resp.Body) + if err != nil { + mainLog.Load().Fatal().Err(err).Msg("could not read response from control server") + } + mainLog.Load().Error().Err(err).Msgf("failed to reload ctrld: %s", string(buf)) + } + }, + } statusCmd := &cobra.Command{ Use: "status", Short: "Show status of the ctrld service", @@ -519,9 +567,10 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, Short: "Manage ctrld service", Args: cobra.OnlyValidArgs, ValidArgs: []string{ - statusCmd.Use, + startCmd.Use, stopCmd.Use, restartCmd.Use, + reloadCmd.Use, statusCmd.Use, uninstallCmd.Use, interfacesCmd.Use, @@ -530,6 +579,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, serviceCmd.AddCommand(startCmd) serviceCmd.AddCommand(stopCmd) serviceCmd.AddCommand(restartCmd) + serviceCmd.AddCommand(reloadCmd) serviceCmd.AddCommand(statusCmd) serviceCmd.AddCommand(uninstallCmd) serviceCmd.AddCommand(interfacesCmd) @@ -584,6 +634,19 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, } rootCmd.AddCommand(restartCmdAlias) + reloadCmdAlias := &cobra.Command{ + PreRun: func(cmd *cobra.Command, args []string) { + initConsoleLogging() + checkHasElevatedPrivilege() + }, + Use: "reload", + Short: "Reload the ctrld service", + Run: func(cmd *cobra.Command, args []string) { + reloadCmd.Run(cmd, args) + }, + } + rootCmd.AddCommand(reloadCmdAlias) + statusCmdAlias := &cobra.Command{ Use: "status", Short: "Show status of the ctrld service", @@ -716,10 +779,12 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } waitCh := make(chan struct{}) p := 
&prog{ - waitCh: waitCh, - stopCh: stopCh, - cfg: &cfg, - appCallback: appCallback, + waitCh: waitCh, + stopCh: stopCh, + reloadCh: make(chan struct{}), + reloadDoneCh: make(chan struct{}), + cfg: &cfg, + appCallback: appCallback, } if homedir == "" { if dir, err := userHomeDir(); err == nil { @@ -757,9 +822,11 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { readBase64Config(configBase64) processNoConfigFlags(noConfigStart) + p.mu.Lock() if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) } + p.mu.Unlock() processLogAndCacheFlags() @@ -795,14 +862,46 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { } if cdUID != "" { validateCdUpstreamProtocol() - err := processCDFlags() - if err != nil { - appCallback.Exit(err.Error()) - return + if err := processCDFlags(&cfg); err != nil { + if isMobile() { + appCallback.Exit(err.Error()) + return + } + + uninstallIfInvalidCdUID := func() { + cdLogger := mainLog.Load().With().Str("mode", "cd").Logger() + if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { + s, err := newService(&prog{}, svcConfig) + if err != nil { + cdLogger.Warn().Err(err).Msg("failed to create new service") + return + } + if netIface, _ := netInterface(iface); netIface != nil { + if err := restoreNetworkManager(); err != nil { + cdLogger.Error().Err(err).Msg("could not restore NetworkManager") + return + } + cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface") + if err := resetDNS(netIface); err != nil { + cdLogger.Warn().Err(err).Msg("something went wrong while restoring DNS") + } else { + cdLogger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully") + } + } + + tasks := []task{{s.Uninstall, true}} + if doTasks(tasks) { + cdLogger.Info().Msg("uninstalled service") + } + cdLogger.Fatal().Err(uer).Msg("failed to fetch resolver config") + return + } + } + 
uninstallIfInvalidCdUID() } } - updated := updateListenerConfig() + updated := updateListenerConfig(&cfg) if cdUID != "" { processLogAndCacheFlags() @@ -830,7 +929,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { initLoggingWithBackup(false) } - validateConfig(&cfg) + if err := validateConfig(&cfg); err != nil { + os.Exit(1) + } initCache() if daemon { @@ -943,7 +1044,7 @@ func readConfigFile(writeDefaultConfig bool) bool { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } - _ = updateListenerConfig() + _ = updateListenerConfig(&cfg) if err := writeConfigFile(); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) } else { @@ -1033,7 +1134,7 @@ func processNoConfigFlags(noConfigStart bool) { v.Set("upstream", upstream) } -func processCDFlags() error { +func processCDFlags(cfg *ctrld.Config) error { logger := mainLog.Load().With().Str("mode", "cd").Logger() logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) bo := backoff.NewBackoff("processCDFlags", logf, 30*time.Second) @@ -1049,47 +1150,17 @@ func processCDFlags() error { } break } - if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { - s, err := newService(&prog{}, svcConfig) - if err != nil { - logger.Warn().Err(err).Msg("failed to create new service") - return nil - } - if netIface, _ := netInterface(iface); netIface != nil { - if err := restoreNetworkManager(); err != nil { - logger.Error().Err(err).Msg("could not restore NetworkManager") - return nil - } - logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface") - if err := resetDNS(netIface); err != nil { - logger.Warn().Err(err).Msg("something went wrong while restoring DNS") - } else { - logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully") - } - } - - tasks := []task{{s.Uninstall, true}} - if doTasks(tasks) { - 
logger.Info().Msg("uninstalled service") - } - event := logger.Fatal() - if isMobile() { - event = logger.Warn() - } - event.Err(uer).Msg("failed to fetch resolver config") - return uer - } if err != nil { if isMobile() { return errors.New("could not fetch resolver config") } logger.Warn().Err(err).Msg("could not fetch resolver config") - return nil + return err } logger.Info().Msg("generating ctrld config from Control-D configuration") - cfg = ctrld.Config{} + *cfg = ctrld.Config{} // Fetch config, unmarshal to cfg. if resolverConfig.Ctrld.CustomConfig != "" { logger.Info().Msg("using defined custom config of Control-D resolver") @@ -1427,18 +1498,17 @@ func uninstall(p *prog, s service.Service) { } } -func validateConfig(cfg *ctrld.Config) { - err := ctrld.ValidateConfig(validator.New(), cfg) - if err == nil { - return - } - var ve validator.ValidationErrors - if errors.As(err, &ve) { - for _, fe := range ve { - mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) +func validateConfig(cfg *ctrld.Config) error { + if err := ctrld.ValidateConfig(validator.New(), cfg); err != nil { + var ve validator.ValidationErrors + if errors.As(err, &ve) { + for _, fe := range ve { + mainLog.Load().Error().Msgf("invalid config: %s: %s", fe.Namespace(), fieldErrorMsg(fe)) + } } + return err } - os.Exit(1) + return nil } // NOTE: Add more case here once new validation tag is used in ctrld.Config struct. @@ -1509,7 +1579,16 @@ func mobileListenerPort() int { // updateListenerConfig updates the config for listeners if not defined, // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. -func updateListenerConfig() (updated bool) { +func updateListenerConfig(cfg *ctrld.Config) bool { + updated, _ := tryUpdateListenerConfig(cfg, true) + return updated +} + +// tryUpdateListenerConfig tries updating listener config with a working one. 
+// If fatal is true, and there's listen address conflicted, the function do +// fatal error. +func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { + ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" nextdnsMode := nextdns != "" @@ -1622,7 +1701,11 @@ func updateListenerConfig() (updated bool) { break } if !check.IP && !check.Port { - logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err) + if fatal { + logMsg(mainLog.Load().Fatal(), n, "failed to listen: %v", err) + } + ok = false + break } if tryAllPort53 { tryAllPort53 = false @@ -1683,12 +1766,19 @@ func updateListenerConfig() (updated bool) { listener.Port = oldPort } if listener.IP == oldIP && listener.Port == oldPort { - logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) + if fatal { + logMsg(mainLog.Load().Fatal(), n, "could not listener on %s: %v", net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port)), err) + } + ok = false + break } logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, pick a random ip+port", addr) attempts++ } } + if !ok { + return + } // Specific case for systemd-resolved. 
if useSystemdResolved { diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 5f5ac51..80bc1ab 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -8,12 +8,15 @@ import ( "os" "sort" "time" + + "github.com/Control-D-Inc/ctrld" ) const ( contentTypeJson = "application/json" listClientsPath = "/clients" startedPath = "/started" + reloadPath = "/reload" ) type controlServer struct { @@ -75,6 +78,39 @@ func (p *prog) registerControlServerHandler() { w.WriteHeader(http.StatusRequestTimeout) } })) + p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + listeners := make(map[string]*ctrld.ListenerConfig) + p.mu.Lock() + for k, v := range p.cfg.Listener { + listeners[k] = &ctrld.ListenerConfig{ + IP: v.IP, + Port: v.Port, + } + } + p.mu.Unlock() + if err := p.sendReloadSignal(); err != nil { + mainLog.Load().Err(err).Msg("could not send reload signal") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + select { + case <-p.reloadDoneCh: + case <-time.After(5 * time.Second): + http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError) + return + } + + p.mu.Lock() + defer p.mu.Unlock() + for k, v := range p.cfg.Listener { + l := listeners[k] + if l == nil || l.IP != v.IP || l.Port != v.Port { + w.WriteHeader(http.StatusCreated) + return + } + } + w.WriteHeader(http.StatusOK) + })) } func jsonResponse(next http.Handler) http.Handler { diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index c5271a9..be0b731 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/hex" + "errors" "fmt" "net" "net/netip" @@ -37,7 +38,9 @@ var osUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } -func (p *prog) serveDNS(listenerNum string) error { +var errReload = errors.New("reload") + +func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) error { listenerConfig := 
p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { @@ -78,6 +81,12 @@ func (p *prog) serveDNS(listenerNum string) error { }) g, ctx := errgroup.WithContext(context.Background()) + // When receiving reload signal, return a non-nil error so other + // goroutines in errgroup.Group could be terminated. + g.Go(func() error { + <-reloadCh + return errReload + }) for _, proto := range []string{"udp", "tcp"} { proto := proto if needLocalIPv6Listener() { @@ -121,11 +130,13 @@ func (p *prog) serveDNS(listenerNum string) error { addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) s, errCh := runDNSServer(addr, proto, handler) defer s.Shutdown() - select { - case err := <-errCh: - return err - case <-time.After(5 * time.Second): - p.started <- struct{}{} + if !reload { + select { + case err := <-errCh: + return err + case <-time.After(5 * time.Second): + p.started <- struct{}{} + } } select { case <-p.stopCh: @@ -136,7 +147,11 @@ func (p *prog) serveDNS(listenerNum string) error { return nil }) } - return g.Wait() + err := g.Wait() + if errors.Is(err, errReload) { // This is an error for trigger reload, not a real error. + return nil + } + return err } // upstreamFor returns the list of upstreams for resolving the given domain, diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 87dabf8..5e6d911 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -79,13 +79,15 @@ func (p *prog) checkDnsLoop() { } // checkDnsLoopTicker performs p.checkDnsLoop every minute. 
-func (p *prog) checkDnsLoopTicker() { +func (p *prog) checkDnsLoopTicker(ctx context.Context) { timer := time.NewTicker(time.Minute) defer timer.Stop() for { select { case <-p.stopCh: return + case <-ctx.Done(): + return case <-timer.C: p.checkDnsLoop() } diff --git a/cmd/cli/netlink_linux.go b/cmd/cli/netlink_linux.go index 0faae84..d757f8b 100644 --- a/cmd/cli/netlink_linux.go +++ b/cmd/cli/netlink_linux.go @@ -1,11 +1,13 @@ package cli import ( + "context" + "github.com/vishvananda/netlink" "golang.org/x/sys/unix" ) -func (p *prog) watchLinkState() { +func (p *prog) watchLinkState(ctx context.Context) { ch := make(chan netlink.LinkUpdate) done := make(chan struct{}) defer close(done) @@ -13,14 +15,19 @@ func (p *prog) watchLinkState() { mainLog.Load().Warn().Err(err).Msg("could not subscribe link") return } - for lu := range ch { - if lu.Change == 0xFFFFFFFF { - continue - } - if lu.Change&unix.IFF_UP != 0 { - mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") - for _, uc := range p.cfg.Upstream { - uc.ReBootstrap() + for { + select { + case <-ctx.Done(): + return + case lu := <-ch: + if lu.Change == 0xFFFFFFFF { + continue + } + if lu.Change&unix.IFF_UP != 0 { + mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping") + for _, uc := range p.cfg.Upstream { + uc.ReBootstrap() + } } } } diff --git a/cmd/cli/netlink_others.go b/cmd/cli/netlink_others.go index f0afd21..5a298b9 100644 --- a/cmd/cli/netlink_others.go +++ b/cmd/cli/netlink_others.go @@ -2,4 +2,6 @@ package cli -func (p *prog) watchLinkState() {} +import "context" + +func (p *prog) watchLinkState(ctx context.Context) {} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index e30a03d..a475a77 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -2,6 +2,7 @@ package cli import ( "bytes" + "context" "errors" "fmt" "math/rand" @@ -16,6 +17,7 @@ import ( "syscall" "github.com/kardianos/service" + "github.com/spf13/viper" "tailscale.com/net/interfaces" 
"github.com/Control-D-Inc/ctrld" @@ -45,11 +47,13 @@ var svcConfig = &service.Config{ var useSystemdResolved = false type prog struct { - mu sync.Mutex - waitCh chan struct{} - stopCh chan struct{} - logConn net.Conn - cs *controlServer + mu sync.Mutex + waitCh chan struct{} + stopCh chan struct{} + reloadCh chan struct{} // For Windows. + reloadDoneCh chan struct{} + logConn net.Conn + cs *controlServer cfg *ctrld.Config appCallback *AppCallback @@ -69,11 +73,90 @@ type prog struct { } func (p *prog) Start(s service.Service) error { - p.cfg = &cfg - go p.run() + go p.runWait() return nil } +// runWait runs ctrld components, waiting for signal to reload. +func (p *prog) runWait() { + p.mu.Lock() + p.cfg = &cfg + p.mu.Unlock() + reloadSigCh := make(chan os.Signal, 1) + notifyReloadSigCh(reloadSigCh) + + reload := false + logger := mainLog.Load() + for { + reloadCh := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + p.run(reload, reloadCh) + reload = true + }() + select { + case sig := <-reloadSigCh: + logger.Notice().Msgf("got signal: %s, reloading...", sig.String()) + case <-p.reloadCh: + logger.Notice().Msg("reloading...") + case <-p.stopCh: + close(reloadCh) + return + } + + waitOldRunDone := func() { + close(reloadCh) + <-done + } + newCfg := &ctrld.Config{} + v := viper.NewWithOptions(viper.KeyDelimiter("::")) + ctrld.InitConfig(v, "ctrld") + if configPath != "" { + v.SetConfigFile(configPath) + } + if err := v.ReadInConfig(); err != nil { + logger.Err(err).Msg("could not read new config") + waitOldRunDone() + continue + } + if err := v.Unmarshal(&newCfg); err != nil { + logger.Err(err).Msg("could not unmarshal new config") + waitOldRunDone() + continue + } + if cdUID != "" { + if err := processCDFlags(newCfg); err != nil { + logger.Err(err).Msg("could not fetch ControlD config") + waitOldRunDone() + continue + } + } + + waitOldRunDone() + + _, ok := tryUpdateListenerConfig(newCfg, false) + if !ok { + 
logger.Error().Msg("could not update listener config") + continue + } + if err := validateConfig(newCfg); err != nil { + logger.Err(err).Msg("invalid config") + continue + } + + p.mu.Lock() + *p.cfg = *newCfg + p.mu.Unlock() + + logger.Notice().Msg("reloading config successfully") + select { + case p.reloadDoneCh <- struct{}{}: + default: + } + } +} + func (p *prog) preRun() { if !service.Interactive() { p.setDNS() @@ -87,7 +170,15 @@ func (p *prog) preRun() { } } -func (p *prog) run() { +// run runs the ctrld main components. +// +// The reload boolean indicates that the function is run when ctrld first start +// or when ctrld receive reloading signal. Platform specifics setup is only done +// on started, mean reload is "false". +// +// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded, +// so all listeners could be terminated and re-spawned again. +func (p *prog) run(reload bool, reloadCh chan struct{}) { // Wait the caller to signal that we can do our logic. <-p.waitCh p.preRun() @@ -146,19 +237,29 @@ func (p *prog) run() { format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) p.ciTable.AddLeaseFile(leaseFile, format) } + + // context for managing spawn goroutines. + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + // Newer versions of android and iOS denies permission which breaks connectivity. 
if !isMobile() { + wg.Add(1) go func() { + defer wg.Done() p.ciTable.Init() - p.ciTable.RefreshLoop(p.stopCh) + p.ciTable.RefreshLoop(ctx) }() - go p.watchLinkState() + go p.watchLinkState(ctx) } for listenerNum := range p.cfg.Listener { p.cfg.Listener[listenerNum].Init() go func(listenerNum string) { - defer wg.Done() + defer func() { + cancelFunc() + wg.Done() + }() listenerConfig := p.cfg.Listener[listenerNum] upstreamConfig := p.cfg.Upstream[listenerNum] if upstreamConfig == nil { @@ -166,35 +267,44 @@ func (p *prog) run() { } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum); err != nil { + if err := p.serveDNS(listenerNum, reload, reloadCh); err != nil { mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) } }(listenerNum) } - for i := 0; i < numListeners; i++ { - <-p.started - } - for _, f := range p.onStarted { - f() + if !reload { + for i := 0; i < numListeners; i++ { + <-p.started + } + for _, f := range p.onStarted { + f() + } } + // Check for possible DNS loop. p.checkDnsLoop() close(p.onStartedDone) // Start check DNS loop ticker. - go p.checkDnsLoopTicker() + wg.Add(1) + go func() { + defer wg.Done() + p.checkDnsLoopTicker(ctx) + }() - // Stop writing log to unix socket. - consoleWriter.Out = os.Stdout - initLoggingWithBackup(false) - if p.logConn != nil { - _ = p.logConn.Close() - } - if p.cs != nil { - p.registerControlServerHandler() - if err := p.cs.start(); err != nil { - mainLog.Load().Warn().Err(err).Msg("could not start control server") + if !reload { + // Stop writing log to unix socket. 
+ consoleWriter.Out = os.Stdout + initLoggingWithBackup(false) + if p.logConn != nil { + _ = p.logConn.Close() + } + if p.cs != nil { + p.registerControlServerHandler() + if err := p.cs.start(); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not start control server") + } } } wg.Wait() diff --git a/cmd/cli/reload_others.go b/cmd/cli/reload_others.go new file mode 100644 index 0000000..0977af9 --- /dev/null +++ b/cmd/cli/reload_others.go @@ -0,0 +1,17 @@ +//go:build !windows + +package cli + +import ( + "os" + "os/signal" + "syscall" +) + +func notifyReloadSigCh(ch chan os.Signal) { + signal.Notify(ch, syscall.SIGUSR1) +} + +func (p *prog) sendReloadSignal() error { + return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) +} diff --git a/cmd/cli/reload_windows.go b/cmd/cli/reload_windows.go new file mode 100644 index 0000000..0e817e4 --- /dev/null +++ b/cmd/cli/reload_windows.go @@ -0,0 +1,18 @@ +package cli + +import ( + "errors" + "os" + "time" +) + +func notifyReloadSigCh(ch chan os.Signal) {} + +func (p *prog) sendReloadSignal() error { + select { + case p.reloadCh <- struct{}{}: + return nil + case <-time.After(5 * time.Second): + } + return errors.New("timeout while sending reload signal") +} diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 3e92fd1..6d6cbf9 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -1,6 +1,7 @@ package clientinfo import ( + "context" "fmt" "net/netip" "strings" @@ -75,7 +76,7 @@ type Table struct { mdns *mdns hf *hostsFile vni *virtualNetworkIface - cfg *ctrld.Config + svcCfg ctrld.ServiceConfig quitCh chan struct{} selfIP string cdUID string @@ -83,7 +84,7 @@ type Table struct { func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table { return &Table{ - cfg: cfg, + svcCfg: cfg.Service, quitCh: make(chan struct{}), selfIP: selfIP, cdUID: cdUID, @@ -97,7 +98,7 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) { 
clientInfoFiles[name] = format } -func (t *Table) RefreshLoop(stopCh chan struct{}) { +func (t *Table) RefreshLoop(ctx context.Context) { timer := time.NewTicker(time.Minute * 5) defer timer.Stop() for { @@ -106,7 +107,7 @@ func (t *Table) RefreshLoop(stopCh chan struct{}) { for _, r := range t.refreshers { _ = r.refresh() } - case <-stopCh: + case <-ctx.Done(): close(t.quitCh) return } @@ -339,38 +340,38 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) { } func (t *Table) discoverDHCP() bool { - if t.cfg.Service.DiscoverDHCP == nil { + if t.svcCfg.DiscoverDHCP == nil { return true } - return *t.cfg.Service.DiscoverDHCP + return *t.svcCfg.DiscoverDHCP } func (t *Table) discoverARP() bool { - if t.cfg.Service.DiscoverARP == nil { + if t.svcCfg.DiscoverARP == nil { return true } - return *t.cfg.Service.DiscoverARP + return *t.svcCfg.DiscoverARP } func (t *Table) discoverMDNS() bool { - if t.cfg.Service.DiscoverMDNS == nil { + if t.svcCfg.DiscoverMDNS == nil { return true } - return *t.cfg.Service.DiscoverMDNS + return *t.svcCfg.DiscoverMDNS } func (t *Table) discoverPTR() bool { - if t.cfg.Service.DiscoverPtr == nil { + if t.svcCfg.DiscoverPtr == nil { return true } - return *t.cfg.Service.DiscoverPtr + return *t.svcCfg.DiscoverPtr } func (t *Table) discoverHosts() bool { - if t.cfg.Service.DiscoverHosts == nil { + if t.svcCfg.DiscoverHosts == nil { return true } - return *t.cfg.Service.DiscoverHosts + return *t.svcCfg.DiscoverHosts } // normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file. From d88cf52b4e6ed7510613bae02712b71d64a74a90 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 24 Oct 2023 22:38:50 +0700 Subject: [PATCH 09/52] cmd/cli: always rebootstrap when check upstream Otherwise, network changes may not be seen on some platforms, causing ctrld failed to recover and failing all requests. While at it, also doing the check DNS in separate goroutine, prevent it from blocking ctrld from notifying others that it "started". 
The issue was seen when ctrld is configured as direct listener, requests are flooded before ctrld started, causing the healtch process failed. --- cmd/cli/loop.go | 5 ++++- cmd/cli/prog.go | 6 +++--- cmd/cli/upstream_monitor.go | 19 +++++++++++-------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index 5e6d911..a9d3972 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -56,7 +56,10 @@ func (p *prog) checkDnsLoop() { mainLog.Load().Debug().Msg("start checking DNS loop") upstream := make(map[string]*ctrld.UpstreamConfig) p.loopMu.Lock() - for _, uc := range p.cfg.Upstream { + for n, uc := range p.cfg.Upstream { + if p.um.isDown("upstream." + n) { + continue + } uid := uc.UID() p.loop[uid] = false upstream[uid] = uc diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index a475a77..d304cce 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -282,14 +282,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } } - // Check for possible DNS loop. - p.checkDnsLoop() close(p.onStartedDone) - // Start check DNS loop ticker. wg.Add(1) go func() { defer wg.Done() + // Check for possible DNS loop. + p.checkDnsLoop() + // Start check DNS loop ticker. p.checkDnsLoopTicker(ctx) }() diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 4b3ee69..83087a4 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -7,7 +7,6 @@ import ( "time" "github.com/miekg/dns" - "tailscale.com/logtail/backoff" "github.com/Control-D-Inc/ctrld" ) @@ -15,8 +14,8 @@ import ( const ( // maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down. maxFailureRequest = 100 - // checkUpstreamMaxBackoff is the max backoff time when checking upstream status. - checkUpstreamMaxBackoff = 2 * time.Minute + // checkUpstreamBackoffSleep is the time interval between each upstream checks. 
+ checkUpstreamBackoffSleep = 2 * time.Second ) // upstreamMonitor performs monitoring upstreams health. @@ -76,7 +75,6 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf um.checking[upstream] = true um.mu.Unlock() - bo := backoff.NewBackoff("checkUpstream", logf, checkUpstreamMaxBackoff) resolver, err := ctrld.NewResolver(uc) if err != nil { mainLog.Load().Warn().Err(err).Msg("could not check upstream") @@ -84,15 +82,20 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf } msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) - ctx := context.Background() - for { + check := func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + uc.ReBootstrap() _, err := resolver.Resolve(ctx, msg) - if err == nil { + return err + } + for { + if err := check(); err == nil { mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint) um.reset(upstream) return } - bo.BackOff(ctx, err) + time.Sleep(checkUpstreamBackoffSleep) } } From 44ba6aadd9243eb1067a6cda34ab34a248a71e17 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 27 Oct 2023 22:02:19 +0700 Subject: [PATCH 10/52] internal/clientinfo: do not complain about net.ErrClosed The probeLoop may have closed the connection before readLoop return, and we don't care about this error. So prevent it from annoying the log. --- internal/clientinfo/mdns.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index 5875b69..9a5fa85 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -123,6 +123,10 @@ func (m *mdns) readLoop(conn *net.UDPConn) { if err, ok := err.(*net.OpError); ok && (err.Timeout() || err.Temporary()) { continue } + // Do not complain about use of closed network connection. 
+ if errors.Is(err, net.ErrClosed) { + return + } ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error") return } From 63f959c9515bbb4676a90ebd60a71fd6d9431db7 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 27 Oct 2023 22:53:15 +0700 Subject: [PATCH 11/52] all: spoof loopback ranges in client info Sending them are useless, so using RFC1918 address instead. --- cmd/cli/dns_proxy.go | 14 ++++++++++++++ internal/clientinfo/client_info.go | 15 +++++++++++++++ internal/clientinfo/client_info_test.go | 19 +++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index be0b731..9c486ed 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -573,9 +573,23 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo { ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac) } ci.Self = queryFromSelf(ci.IP) + p.spoofLoopbackIpInClientInfo(ci) return ci } +// spoofLoopbackIpInClientInfo replaces loopback IPs in client info. +// +// - Preference IPv4. +// - Preference RFC1918. +func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) { + if ip := net.ParseIP(ci.IP); ip == nil || !ip.IsLoopback() { + return + } + if ip := p.ciTable.LookupRFC1918IPv4(ci.Mac); ip != "" { + ci.IP = ip + } +} + // queryFromSelf reports whether the input IP is from device running ctrld. func queryFromSelf(ip string) bool { netIP := netip.MustParseAddr(ip) diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 6d6cbf9..f591174 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -241,6 +241,21 @@ func (t *Table) LookupHostname(ip, mac string) string { return "" } +// LookupRFC1918IPv4 returns the RFC1918 IPv4 address for the given MAC address, if any. 
+func (t *Table) LookupRFC1918IPv4(mac string) string { + t.initOnce.Do(t.init) + for _, r := range t.ipResolvers { + ip, err := netip.ParseAddr(r.LookupIP(mac)) + if err != nil || ip.Is6() { + continue + } + if ip.IsPrivate() { + return ip.String() + } + } + return "" +} + type macEntry struct { mac string src string diff --git a/internal/clientinfo/client_info_test.go b/internal/clientinfo/client_info_test.go index 79e5912..e6575f2 100644 --- a/internal/clientinfo/client_info_test.go +++ b/internal/clientinfo/client_info_test.go @@ -25,3 +25,22 @@ func Test_normalizeIP(t *testing.T) { }) } } + +func TestTable_LookupRFC1918IPv4(t *testing.T) { + table := &Table{ + dhcp: &dhcp{}, + arp: &arpDiscover{}, + } + + table.ipResolvers = append(table.ipResolvers, table.dhcp) + table.ipResolvers = append(table.ipResolvers, table.arp) + + macAddress := "cc:19:f9:8a:49:e6" + rfc1918IPv4 := "10.0.10.245" + table.dhcp.ip.Store(macAddress, "127.0.0.1") + table.arp.ip.Store(macAddress, rfc1918IPv4) + + if got := table.LookupRFC1918IPv4(macAddress); got != rfc1918IPv4 { + t.Fatalf("unexpected result, want: %s, got: %s", rfc1918IPv4, got) + } +} From 3fea92c8b1e63186e2c7ce1095f5de785d21e9a3 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 6 Nov 2023 20:36:06 +0700 Subject: [PATCH 12/52] Bump golang.org/x/net to v0.17.0 --- go.mod | 9 ++++----- go.sum | 20 ++++++++------------ 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 58ba1e4..fec32ef 100644 --- a/go.mod +++ b/go.mod @@ -25,9 +25,9 @@ require ( github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.8.3 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/net v0.10.0 + golang.org/x/net v0.17.0 golang.org/x/sync v0.2.0 - golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a + golang.org/x/sys v0.13.0 golang.zx2c4.com/wireguard/windows v0.5.3 tailscale.com v1.44.0 ) @@ -70,11 +70,10 @@ require ( github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect 
github.com/vishvananda/netns v0.0.4 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect - golang.org/x/crypto v0.9.0 // indirect + golang.org/x/crypto v0.14.0 // indirect golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect - golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda // indirect golang.org/x/mod v0.10.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 409133a..c792103 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,6 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs= -github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf h1:40DHYsri+d1bnroFDU2FQAeq68f3kAlOzlQ93kCf26Q= github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -302,8 +300,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 
-golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -331,8 +329,6 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= -golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda h1:O+EUvnBNPwI4eLthn8W5K+cS8zQZfgTABPLNm6Bna34= -golang.org/x/mobile v0.0.0-20230531173138-3c911d8e3eda/go.mod h1:aAjjkJNdrh3PMckS4B10TGS2nag27cbKR1y2BpUxsiY= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= @@ -378,8 +374,8 @@ golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod 
h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -452,8 +448,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a h1:qMsju+PNttu/NMbq8bQ9waDdxgJMu9QNoUDuhnBaYt0= -golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -463,8 +459,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod 
h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From 4816a09e3abd64668ebed35b1fb33a2ca3706ac4 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 2 Nov 2023 21:53:39 +0700 Subject: [PATCH 13/52] all: use private resolver for private IP address These queries could not be resolved by Control D upstreams, so it's useless and less performance to send them to servers. 
--- cmd/cli/dns_proxy.go | 65 +++++++++++++++++++++++++++++++++++++++ cmd/cli/dns_proxy_test.go | 27 ++++++++++++++++ resolver.go | 4 +++ 3 files changed, 96 insertions(+) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 9c486ed..af7628c 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -38,6 +38,12 @@ var osUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } +var privateUpstreamConfig = &ctrld.UpstreamConfig{ + Name: "Private resolver", + Type: ctrld.ResolverTypePrivate, + Timeout: 2000, +} + var errReload = errors.New("reload") func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) error { @@ -54,6 +60,11 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { p.sema.acquire() defer p.sema.release() + if len(m.Question) == 0 { + answer := new(dns.Msg) + answer.SetRcode(m, dns.RcodeFormatError) + return + } go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) @@ -261,6 +272,11 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} upstreams = []string{upstreamOS} } + if isPrivatePtrLookup(msg) { + ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup -> [%s]", upstreamOS) + upstreamConfigs = []*ctrld.UpstreamConfig{privateUpstreamConfig} + upstreams = []string{upstreamOS} + } // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { for _, upstream := range upstreams { @@ -634,3 +650,52 @@ func rfc1918Addresses() []string { }) return res } + +// ipFromARPA parses a FQDN arpa domain and return the IP address if valid. 
+func ipFromARPA(arpa string) net.IP { + if arpa, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok { + if ptrIP := net.ParseIP(arpa); ptrIP != nil { + return net.IP{ptrIP[15], ptrIP[14], ptrIP[13], ptrIP[12]} + } + } + if arpa, ok := strings.CutSuffix(arpa, ".ip6.arpa."); ok { + l := net.IPv6len * 2 + base := 16 + ip := make(net.IP, net.IPv6len) + for i := 0; i < l && arpa != ""; i++ { + idx := strings.LastIndexByte(arpa, '.') + off := idx + 1 + if idx == -1 { + idx = 0 + off = 0 + } else if idx == len(arpa)-1 { + return nil + } + n, err := strconv.ParseUint(arpa[off:], base, 8) + if err != nil { + return nil + } + b := byte(n) + ii := i / 2 + if i&1 == 1 { + b |= ip[ii] << 4 + } + ip[ii] = b + arpa = arpa[:idx] + } + return ip + } + return nil +} + +// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN network. +func isPrivatePtrLookup(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + q := m.Question[0] + if ip := ipFromARPA(q.Name); ip != nil { + return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() + } + return false +} diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index c3b6c96..8d18fa0 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -238,3 +238,30 @@ func Test_remoteAddrFromMsg(t *testing.T) { }) } } + +func Test_ipFromARPA(t *testing.T) { + tests := []struct { + IP string + ARPA string + }{ + {"1.2.3.4", "4.3.2.1.in-addr.arpa."}, + {"245.110.36.114", "114.36.110.245.in-addr.arpa."}, + {"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."}, + {"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."}, + {"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."}, + {"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."}, + {"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."}, + {"", 
"asd.in-addr.arpa."}, + {"", "asd.ip6.arpa."}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.IP, func(t *testing.T) { + t.Parallel() + if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) { + t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got) + } + }) + } +} diff --git a/resolver.go b/resolver.go index 969da86..f08263b 100644 --- a/resolver.go +++ b/resolver.go @@ -24,6 +24,8 @@ const ( ResolverTypeOS = "os" // ResolverTypeLegacy specifies legacy resolver. ResolverTypeLegacy = "legacy" + // ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only. + ResolverTypePrivate = "private" ) var bootstrapDNS = "76.76.2.0" @@ -61,6 +63,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { return or, nil case ResolverTypeLegacy: return &legacyResolver{uc: uc}, nil + case ResolverTypePrivate: + return NewPrivateResolver(), nil } return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ) } From 43ff2f648c3b5de58d246da5f9f804f43d46e762 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 12 Oct 2023 21:56:27 +0700 Subject: [PATCH 14/52] internal/router/dnsmasq: disable cache So multiple upstreams config could work properly. 
--- internal/router/dnsmasq/dnsmasq.go | 56 +++++++++++++++++++++++++- internal/router/edgeos/edgeos.go | 24 ++++++++--- internal/router/firewalla/firewalla.go | 10 +++++ internal/router/openwrt/openwrt.go | 26 +----------- internal/router/ubios/ubios.go | 17 ++++++-- 5 files changed, 97 insertions(+), 36 deletions(-) diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index 54ba8fd..52cf1a6 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -1,9 +1,12 @@ package dnsmasq import ( + "bytes" "errors" + "fmt" "html/template" "net" + "os" "path/filepath" "strings" @@ -19,6 +22,7 @@ server={{ .IP }}#{{ .Port }} add-mac add-subnet=32,128 {{- end}} +cache-size=0 ` const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf" @@ -47,6 +51,8 @@ if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then {{- end}} pc_delete "dnssec" "$config_file" # disable DNSSEC pc_delete "trust-anchor=" "$config_file" # disable DNSSEC + pc_delete "cache-size=" "$config_file" + pc_append "cache-size=0" "$config_file" # disable cache # For John fork pc_delete "resolv-file" "$config_file" # no WAN DNS settings @@ -117,9 +123,28 @@ func firewallaUpstreams(port int) []Upstream { return upstreams } +// firewallaDnsmasqConfFiles returns dnsmasq config files of all firewalla interfaces. +func firewallaDnsmasqConfFiles() ([]string, error) { + return filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf") +} + +// firewallUpdateConf updates all firewall config files using given function. +func firewallUpdateConf(update func(conf string) error) error { + confFiles, err := firewallaDnsmasqConfFiles() + if err != nil { + return err + } + for _, conf := range confFiles { + if err := update(conf); err != nil { + return fmt.Errorf("%s: %w", conf, err) + } + } + return nil +} + // FirewallaSelfInterfaces returns list of interfaces that will be configured with default dnsmasq setup on Firewalla. 
func FirewallaSelfInterfaces() []*net.Interface { - matches, err := filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf") + matches, err := firewallaDnsmasqConfFiles() if err != nil { return nil } @@ -133,3 +158,32 @@ func FirewallaSelfInterfaces() []*net.Interface { } return ifaces } + +// FirewallaDisableCache comments out "cache-size" line in all firewalla dnsmasq config files. +func FirewallaDisableCache() error { + return firewallUpdateConf(DisableCache) +} + +// FirewallaEnableCache un-comments out "cache-size" line in all firewalla dnsmasq config files. +func FirewallaEnableCache() error { + return firewallUpdateConf(EnableCache) +} + +// DisableCache comments out "cache-size" line in dnsmasq config file. +func DisableCache(conf string) error { + return replaceFileContent(conf, "\ncache-size=", "\n#cache-size=") +} + +// EnableCache un-comments "cache-size" line in dnsmasq config file. +func EnableCache(conf string) error { + return replaceFileContent(conf, "\n#cache-size=", "\ncache-size=") +} + +func replaceFileContent(filename, old, new string) error { + content, err := os.ReadFile(filename) + if err != nil { + return err + } + content = bytes.ReplaceAll(content, []byte(old), []byte(new)) + return os.WriteFile(filename, content, 0644) +} diff --git a/internal/router/edgeos/edgeos.go b/internal/router/edgeos/edgeos.go index f50f610..0552882 100644 --- a/internal/router/edgeos/edgeos.go +++ b/internal/router/edgeos/edgeos.go @@ -15,11 +15,12 @@ import ( ) const ( - Name = "edgeos" - edgeOSDNSMasqConfigPath = "/etc/dnsmasq.d/dnsmasq-zzz-ctrld.conf" - usgDNSMasqConfigPath = "/etc/dnsmasq.conf" - usgDNSMasqBackupConfigPath = "/etc/dnsmasq.conf.bak" - toggleContentFilteringLink = "https://community.ui.com/questions/UDM-Pro-disable-enable-DNS-filtering/e2cc4060-e56a-4139-b200-62d7f773ff8f" + Name = "edgeos" + edgeOSDNSMasqDefaultConfigPath = "/etc/dnsmasq.conf" + edgeOSDNSMasqConfigPath = "/etc/dnsmasq.d/dnsmasq-zzz-ctrld.conf" + usgDNSMasqConfigPath = 
"/etc/dnsmasq.conf" + usgDNSMasqBackupConfigPath = "/etc/dnsmasq.conf.bak" + toggleContentFilteringLink = "https://community.ui.com/questions/UDM-Pro-disable-enable-DNS-filtering/e2cc4060-e56a-4139-b200-62d7f773ff8f" ) var ErrContentFilteringEnabled = fmt.Errorf(`the "Content Filtering" feature" is enabled, which is conflicted with ctrld.\n @@ -95,7 +96,7 @@ func (e *EdgeOS) setupUSG() error { return fmt.Errorf("setupUSG: backup current config: %w", err) } - // Removing all configured upstreams. + // Removing all configured upstreams and cache config. var sb strings.Builder scanner := bufio.NewScanner(bytes.NewReader(buf)) for scanner.Scan() { @@ -106,6 +107,9 @@ func (e *EdgeOS) setupUSG() error { if strings.HasPrefix(line, "all-servers") { continue } + if strings.HasPrefix(line, "cache-size") { + continue + } sb.WriteString(line) } @@ -127,6 +131,10 @@ func (e *EdgeOS) setupUSG() error { } func (e *EdgeOS) setupUDM() error { + // Disable dnsmasq cache. + if err := dnsmasq.DisableCache(edgeOSDNSMasqDefaultConfigPath); err != nil { + return err + } data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg) if err != nil { return err @@ -153,6 +161,10 @@ func (e *EdgeOS) cleanupUSG() error { } func (e *EdgeOS) cleanupUDM() error { + // Enable dnsmasq cache. + if err := dnsmasq.EnableCache(edgeOSDNSMasqDefaultConfigPath); err != nil { + return err + } // Remove the custom dnsmasq config if err := os.Remove(edgeOSDNSMasqConfigPath); err != nil { return fmt.Errorf("cleanupUDM: os.Remove: %w", err) diff --git a/internal/router/firewalla/firewalla.go b/internal/router/firewalla/firewalla.go index cdf6586..66cd15e 100644 --- a/internal/router/firewalla/firewalla.go +++ b/internal/router/firewalla/firewalla.go @@ -65,6 +65,11 @@ func (f *Firewalla) Setup() error { return fmt.Errorf("writing ctrld config: %w", err) } + // Disable dnsmasq cache. + if err := dnsmasq.FirewallaDisableCache(); err != nil { + return err + } + // Restart dnsmasq service. 
if err := restartDNSMasq(); err != nil { return fmt.Errorf("restartDNSMasq: %w", err) @@ -82,6 +87,11 @@ func (f *Firewalla) Cleanup() error { return fmt.Errorf("removing ctrld config: %w", err) } + // Enable dnsmasq cache. + if err := dnsmasq.FirewallaEnableCache(); err != nil { + return err + } + // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return fmt.Errorf("restartDNSMasq: %w", err) diff --git a/internal/router/openwrt/openwrt.go b/internal/router/openwrt/openwrt.go index 83ea884..d3bc511 100644 --- a/internal/router/openwrt/openwrt.go +++ b/internal/router/openwrt/openwrt.go @@ -1,18 +1,14 @@ package openwrt import ( - "bytes" - "errors" "fmt" "os" "os/exec" - "strings" - - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) const ( @@ -20,8 +16,6 @@ const ( openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf" ) -var errUCIEntryNotFound = errors.New("uci: Entry not found") - type Openwrt struct { cfg *ctrld.Config } @@ -59,10 +53,6 @@ func (o *Openwrt) Setup() error { if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil { return err } - // Commit. - if _, err := uci("commit"); err != nil { - return err - } // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return err @@ -91,17 +81,3 @@ func restartDNSMasq() error { } return nil } - -func uci(args ...string) (string, error) { - cmd := exec.Command("uci", args...) 
- var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - if err := cmd.Run(); err != nil { - if strings.HasPrefix(stderr.String(), errUCIEntryNotFound.Error()) { - return "", errUCIEntryNotFound - } - return "", fmt.Errorf("%s:%w", stderr.String(), err) - } - return strings.TrimSpace(stdout.String()), nil -} diff --git a/internal/router/ubios/ubios.go b/internal/router/ubios/ubios.go index b0762db..32c7576 100644 --- a/internal/router/ubios/ubios.go +++ b/internal/router/ubios/ubios.go @@ -5,16 +5,17 @@ import ( "os" "strconv" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" + "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" "github.com/Control-D-Inc/ctrld/internal/router/edgeos" - "github.com/kardianos/service" ) const ( - Name = "ubios" - ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf" + Name = "ubios" + ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf" + ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf" ) type Ubios struct { @@ -57,6 +58,10 @@ func (u *Ubios) Setup() error { if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil { return err } + // Disable dnsmasq cache. + if err := dnsmasq.DisableCache(ubiosDNSMasqDnsConfigPath); err != nil { + return err + } // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return err @@ -72,6 +77,10 @@ func (u *Ubios) Cleanup() error { if err := os.Remove(ubiosDNSMasqConfigPath); err != nil { return err } + // Enable dnsmasq cache. + if err := dnsmasq.EnableCache(ubiosDNSMasqDnsConfigPath); err != nil { + return err + } // Restart dnsmasq service. 
if err := restartDNSMasq(); err != nil { return err From 8e0a96a44c9f140f038eb1ce34dd49f6d6abc674 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 8 Nov 2023 14:46:42 +0700 Subject: [PATCH 15/52] Fix panic due to quic-go changes quic.DialEarly requires separate UDP connection for each quic.EarlyConnection instead of re-using the same one. --- config_quic.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/config_quic.go b/config_quic.go index cd3eaee..3f397cb 100644 --- a/config_quic.go +++ b/config_quic.go @@ -127,11 +127,6 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t close(ch) }() - udpConn, err := net.ListenUDP("udp", nil) - if err != nil { - return nil, err - } - for _, addr := range addrs { go func(addr string) { defer wg.Done() @@ -140,6 +135,11 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t ch <- &parallelDialerResult{conn: nil, err: err} return } + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + ch <- &parallelDialerResult{conn: nil, err: err} + return + } conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg) select { case ch <- &parallelDialerResult{conn: conn, err: err}: @@ -147,6 +147,9 @@ func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *t if conn != nil { conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "") } + if udpConn != nil { + udpConn.Close() + } } }(addr) } From efb5a92571c956d53793665c9b8f1fd4e72c19cb Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 8 Nov 2023 01:11:49 +0700 Subject: [PATCH 16/52] Using time interval for probing ipv6 A backoff with small max time will flood requests to Control D server, causing false positive for abuse mitigation system. While a big max time will cause ctrld not realize network change as fast as possible. While at it, also sync DoH3 code with DoH code, ensuring no other place can trigger requests flooding for ipv6 probing.
--- config_quic.go | 7 +------ net.go | 35 +++++++++++++++++++---------------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/config_quic.go b/config_quic.go index 3f397cb..5103231 100644 --- a/config_quic.go +++ b/config_quic.go @@ -10,13 +10,10 @@ import ( "net/http" "runtime" "sync" - "time" "github.com/miekg/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" - - ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) func (uc *UpstreamConfig) setupDOH3Transport() { @@ -29,9 +26,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() { uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) case IpStackSplit: uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if ctrldnet.IPv6Available(ctx) { + if hasIPv6() { uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) } else { uc.http3RoundTripper6 = uc.http3RoundTripper4 diff --git a/net.go b/net.go index 110d67e..3ae3bb5 100644 --- a/net.go +++ b/net.go @@ -2,13 +2,10 @@ package ctrld import ( "context" - "errors" "sync" "sync/atomic" "time" - "tailscale.com/logtail/backoff" - ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) @@ -17,30 +14,36 @@ var ( ipv6Available atomic.Bool ) +const ipv6ProbingInterval = 10 * time.Second + func hasIPv6() bool { hasIPv6Once.Do(func() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() val := ctrldnet.IPv6Available(ctx) ipv6Available.Store(val) - go probingIPv6(val) + go probingIPv6(context.TODO(), val) }) return ipv6Available.Load() } // TODO(cuonglm): doing poll check natively for supported platforms. 
-func probingIPv6(old bool) { - b := backoff.NewBackoff("probingIPv6", func(format string, args ...any) {}, 30*time.Second) - bCtx := context.Background() +func probingIPv6(ctx context.Context, old bool) { + ticker := time.NewTicker(ipv6ProbingInterval) + defer ticker.Stop() for { - func() { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - cur := ctrldnet.IPv6Available(ctx) - if ipv6Available.CompareAndSwap(old, cur) { - old = cur - } - }() - b.BackOff(bCtx, errors.New("no change")) + select { + case <-ctx.Done(): + return + case <-ticker.C: + func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + cur := ctrldnet.IPv6Available(ctx) + if ipv6Available.CompareAndSwap(old, cur) { + old = cur + } + }() + } } } From 990bc620f7f9fd3ffbc957964df0f8d7c52008ea Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 8 Nov 2023 18:15:48 +0700 Subject: [PATCH 17/52] cmd/cli: strip EDNS0_SUBNET for RFC 1918 and loopback address Since passing them to upstream is pointless, these cannot be used by anything on the WAN. 
--- cmd/cli/dns_proxy.go | 18 +++++++++++++++++ cmd/cli/dns_proxy_test.go | 42 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index af7628c..666bf50 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -71,6 +71,7 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) reqId := requestID() remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) + stripClientSubnet(m) remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String()) t := time.Now() @@ -498,6 +499,23 @@ func ipAndMacFromMsg(msg *dns.Msg) (string, string) { return ip, mac } +// stripClientSubnet removes EDNS0_SUBNET from DNS message if the IP is RFC1918 or loopback address, +// passing them to upstream is pointless, these cannot be used by anything on the WAN. +func stripClientSubnet(msg *dns.Msg) { + if opt := msg.IsEdns0(); opt != nil { + opts := make([]dns.EDNS0, 0, len(opt.Option)) + for _, s := range opt.Option { + if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) { + continue + } + opts = append(opts, s) + } + if len(opts) != len(opt.Option) { + opt.Option = opts + } + } +} + func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr { if ci != nil && ci.IP != "" { switch addr := addr.(type) { diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 8d18fa0..d0e5c74 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -265,3 +265,45 @@ func Test_ipFromARPA(t *testing.T) { }) } } + +func newDnsMsgWithClientIP(ip string) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)}) + m.Extra = append(m.Extra, o) + return m +} 
+func Test_stripClientSubnet(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + wantSubnet bool + }{ + {"no edns0", new(dns.Msg), false}, + {"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false}, + {"loopback IP v6", newDnsMsgWithClientIP("::1"), false}, + {"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false}, + {"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false}, + {"public IP", newDnsMsgWithClientIP("1.1.1.1"), true}, + {"invalid IP", newDnsMsgWithClientIP(""), true}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + stripClientSubnet(tc.msg) + hasSubnet := false + if opt := tc.msg.IsEdns0(); opt != nil { + for _, s := range opt.Option { + if _, ok := s.(*dns.EDNS0_SUBNET); ok { + hasSubnet = true + } + } + } + if tc.wantSubnet != hasSubnet { + t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet) + } + }) + } +} From 4614b98e94451cbf29a3bcfe81974e11112b42c0 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 7 Nov 2023 21:35:42 +0700 Subject: [PATCH 18/52] internal/clientinfo: emit error once if ptr discovery failed So it won't spam ctrld log unnecessary, prevent confusion. While at it, also change the log level from Warn to Info, since this error is not actionable by the user. 
--- internal/clientinfo/ptr_lookup.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 6a9d99b..fea79fb 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -72,15 +72,16 @@ func (p *ptrDiscover) lookupHostname(ip string) string { msg := new(dns.Msg) addr, err := dns.ReverseAddr(ip) if err != nil { - ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("invalid ip address") + ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address") return "" } msg.SetQuestion(addr, dns.TypePTR) ans, err := p.resolver.Resolve(ctx, msg) if err != nil { - ctrld.ProxyLogger.Load().Warn().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") - p.serverDown.Store(true) - go p.checkServer() + if p.serverDown.CompareAndSwap(false, true) { + ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup") + go p.checkServer() + } return "" } for _, rr := range ans.Answer { From 09188bedf7e2161fc4d682b44b839669eb5e44cf Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 9 Nov 2023 18:20:39 +0700 Subject: [PATCH 19/52] cmd/cli: fix wrong generated config for nextdns resolver Generating nextdns config must happen after stopping current ctrld process. Otherwise, config processing may pick wrong IP+Port. While at it, also making logging better when updating listener config: - Change warn to info, prevent confusing that "something is wrong". - Do not emit info when generating working default config, which may cause duplicated messages printed. 
--- cmd/cli/cli.go | 42 +++++++++++++++++++++++++++--------------- cmd/cli/nextdns.go | 6 +++--- cmd/cli/prog.go | 2 +- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index bf6803f..13e366d 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -238,11 +238,6 @@ func initCLI() { if nextdns != "" { removeNextDNSFromArgs(sc) - generateNextDNSConfig() - updateListenerConfig(&cfg) - if err := writeConfigFile(); err != nil { - mainLog.Load().Error().Err(err).Msg("failed to write config with NextDNS resolver") - } } // Explicitly passing config, so on system where home directory could not be obtained, @@ -265,6 +260,7 @@ func initCLI() { tasks := []task{ {s.Stop, false}, + {func() error { return doGenerateNextDNSConfig(nextdns) }, true}, {s.Uninstall, false}, {s.Install, false}, {s.Start, true}, @@ -1044,7 +1040,8 @@ func readConfigFile(writeDefaultConfig bool) bool { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } - _ = updateListenerConfig(&cfg) + nop := zerolog.Nop() + _, _ = tryUpdateListenerConfig(&cfg, &nop, true) if err := writeConfigFile(); err != nil { mainLog.Load().Fatal().Msgf("failed to write default config file: %v", err) } else { @@ -1580,14 +1577,14 @@ func mobileListenerPort() int { // or defined but invalid to be used, e.g: using loopback address other // than 127.0.0.1 with systemd-resolved. func updateListenerConfig(cfg *ctrld.Config) bool { - updated, _ := tryUpdateListenerConfig(cfg, true) + updated, _ := tryUpdateListenerConfig(cfg, nil, true) return updated } // tryUpdateListenerConfig tries updating listener config with a working one. // If fatal is true, and there's listen address conflicted, the function do // fatal error. 
-func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { +func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fatal bool) (updated, ok bool) { ok = true lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" @@ -1610,6 +1607,10 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { } updated = updated || lcc[n].IP || lcc[n].Port } + il := mainLog.Load() + if infoLogger != nil { + il = infoLogger + } if isMobile() { // On Mobile, only use first listener, ignore others. firstLn := cfg.FirstListener() @@ -1716,7 +1717,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { listener.Port = 53 } if check.IP { - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) + logMsg(il.Info(), n, "could not listen on address: %s, trying: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) } continue } @@ -1729,7 +1730,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { listener.Port = 53 } if check.IP { - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) + logMsg(il.Info(), n, "could not listen on address: %s, trying localhost: %s", addr, net.JoinHostPort(listener.IP, strconv.Itoa(listener.Port))) } continue } @@ -1741,7 +1742,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { if check.Port { listener.Port = 5354 } - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying current ip with port 5354", addr) + logMsg(il.Info(), n, "could not listen on address: %s, trying current ip with port 5354", addr) continue } if tryPort5354 { @@ -1752,7 +1753,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { if check.Port { listener.Port = 
5354 } - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr) + logMsg(il.Info(), n, "could not listen on address: %s, trying 0.0.0.0:5354", addr) continue } if check.IP && !isZeroIP { // for "0.0.0.0" or "::", we only need to try new port. @@ -1772,7 +1773,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { ok = false break } - logMsg(mainLog.Load().Warn(), n, "could not listen on address: %s, pick a random ip+port", addr) + logMsg(il.Info(), n, "could not listen on address: %s, pick a random ip+port", addr) attempts++ } } @@ -1788,7 +1789,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { // ip address, other than "127.0.0.1", so trying to listen on default route interface // address instead. if ip := net.ParseIP(listener.IP); ip != nil && ip.IsLoopback() && ip.String() != "127.0.0.1" { - logMsg(mainLog.Load().Warn(), n, "using loopback interface do not work with systemd-resolved") + logMsg(il.Info(), n, "using loopback interface do not work with systemd-resolved") found := false if netIface, _ := net.InterfaceByName(defaultIfaceName()); netIface != nil { addrs, _ := netIface.Addrs() @@ -1798,7 +1799,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, fatal bool) (updated, ok bool) { if err := tryListen(addr); err == nil { found = true listener.IP = netIP.IP.String() - logMsg(mainLog.Load().Warn(), n, "use %s as listener address", listener.IP) + logMsg(il.Info(), n, "use %s as listener address", listener.IP) break } } @@ -1956,3 +1957,14 @@ func removeNextDNSFromArgs(sc *service.Config) { } sc.Arguments = a } + +// doGenerateNextDNSConfig generates a working config with nextdns resolver. 
+func doGenerateNextDNSConfig(uid string) error { + if uid == "" { + return nil + } + mainLog.Load().Notice().Msgf("Generating nextdns config: %s", defaultConfigFile) + generateNextDNSConfig(uid) + updateListenerConfig(&cfg) + return writeConfigFile() +} diff --git a/cmd/cli/nextdns.go b/cmd/cli/nextdns.go index 8aebfc8..f4fed47 100644 --- a/cmd/cli/nextdns.go +++ b/cmd/cli/nextdns.go @@ -8,8 +8,8 @@ import ( const nextdnsURL = "https://dns.nextdns.io" -func generateNextDNSConfig() { - if nextdns == "" { +func generateNextDNSConfig(uid string) { + if uid == "" { return } mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver") @@ -23,7 +23,7 @@ func generateNextDNSConfig() { Upstream: map[string]*ctrld.UpstreamConfig{ "0": { Type: ctrld.ResolverTypeDOH3, - Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, nextdns), + Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, uid), Timeout: 5000, }, }, diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index d304cce..d29f374 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -135,7 +135,7 @@ func (p *prog) runWait() { waitOldRunDone() - _, ok := tryUpdateListenerConfig(newCfg, false) + _, ok := tryUpdateListenerConfig(newCfg, nil, false) if !ok { logger.Error().Msg("could not update listener config") continue From c3b4ae9c79fc6fcadfb5446ca24afd7dc5c249f8 Mon Sep 17 00:00:00 2001 From: Ginder Singh Date: Thu, 9 Nov 2023 16:33:48 +0000 Subject: [PATCH 20/52] Older android missing certificate --- cmd/cli/cli.go | 2 +- internal/controld/config.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 13e366d..e91115a 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1149,7 +1149,7 @@ func processCDFlags(cfg *ctrld.Config) error { } if err != nil { if isMobile() { - return errors.New("could not fetch resolver config") + return err } logger.Warn().Err(err).Msg("could not fetch resolver config") return err diff --git a/internal/controld/config.go 
b/internal/controld/config.go index 4e4bc2e..4cc6770 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "os" + "runtime" "strings" "time" @@ -119,7 +120,7 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig return d.DialContext(ctx, network, addrs) } - if router.Name() == ddwrt.Name { + if router.Name() == ddwrt.Name || runtime.GOOS == "android" { transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()} } client := http.Client{ From 294a90a807d030b7c6b9f45bcb668a09de509ef4 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 13 Nov 2023 20:03:46 +0700 Subject: [PATCH 21/52] internal/router/openwrt: ensure dnsmasq cache is disabled Users may have their own dnsmasq cache set via LUCI web, thus ctrld needs to delete the cache-size setting to ensure its dnsmasq config works. --- internal/router/openwrt/openwrt.go | 47 +++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/internal/router/openwrt/openwrt.go b/internal/router/openwrt/openwrt.go index d3bc511..ad98db9 100644 --- a/internal/router/openwrt/openwrt.go +++ b/internal/router/openwrt/openwrt.go @@ -1,9 +1,12 @@ package openwrt import ( + "bytes" + "errors" "fmt" "os" "os/exec" + "strings" "github.com/kardianos/service" @@ -17,7 +20,8 @@ const ( ) type Openwrt struct { - cfg *ctrld.Config + cfg *ctrld.Config + dnsmasqCacheSize string } // New returns a router.Router for configuring/setup/run ctrld on Openwrt routers. @@ -46,6 +50,19 @@ func (o *Openwrt) Setup() error { if o.cfg.FirstListener().IsDirectDnsListener() { return nil } + + // Save current dnsmasq config cache size if present. + if cs, err := uci("get", "dhcp.@dnsmasq[0].cachesize"); err == nil { + o.dnsmasqCacheSize = cs + if _, err := uci("delete", "dhcp.@dnsmasq[0].cachesize"); err != nil { + return err + } + // Commit. 
+ if _, err := uci("commit", "dhcp"); err != nil { + return err + } + } + data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, o.cfg) if err != nil { return err @@ -68,6 +85,18 @@ func (o *Openwrt) Cleanup() error { if err := os.Remove(openwrtDNSMasqConfigPath); err != nil { return err } + + // Restore original value if present. + if o.dnsmasqCacheSize != "" { + if _, err := uci("set", fmt.Sprintf("dhcp.@dnsmasq[0].cachesize=%s", o.dnsmasqCacheSize)); err != nil { + return err + } + // Commit. + if _, err := uci("commit", "dhcp"); err != nil { + return err + } + } + // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return err @@ -81,3 +110,19 @@ func restartDNSMasq() error { } return nil } + +var errUCIEntryNotFound = errors.New("uci: Entry not found") + +func uci(args ...string) (string, error) { + cmd := exec.Command("uci", args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + if strings.HasPrefix(stderr.String(), errUCIEntryNotFound.Error()) { + return "", errUCIEntryNotFound + } + return "", fmt.Errorf("%s:%w", stderr.String(), err) + } + return strings.TrimSpace(stdout.String()), nil +} From d01f5c27777bb1f70ecf76a4aa16d0758da12037 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 9 Nov 2023 22:10:24 +0700 Subject: [PATCH 22/52] cmd/cli: do not stop listener when reloading We could not do a reload if the listener config changes, so do not turn them off to try updating new listener config. 
--- cmd/cli/control_server.go | 14 +++++ cmd/cli/dns_proxy.go | 36 +++--------- cmd/cli/prog.go | 121 ++++++++++++++++++++++++-------------- 3 files changed, 101 insertions(+), 70 deletions(-) diff --git a/cmd/cli/control_server.go b/cmd/cli/control_server.go index 80bc1ab..5ee7112 100644 --- a/cmd/cli/control_server.go +++ b/cmd/cli/control_server.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "os" + "reflect" "sort" "time" @@ -87,6 +88,7 @@ func (p *prog) registerControlServerHandler() { Port: v.Port, } } + oldSvc := p.cfg.Service p.mu.Unlock() if err := p.sendReloadSignal(); err != nil { mainLog.Load().Err(err).Msg("could not send reload signal") @@ -102,6 +104,10 @@ func (p *prog) registerControlServerHandler() { p.mu.Lock() defer p.mu.Unlock() + + // Checking for cases that we could not do a reload. + + // 1. Listener config ip or port changes. for k, v := range p.cfg.Listener { l := listeners[k] if l == nil || l.IP != v.IP || l.Port != v.Port { @@ -109,6 +115,14 @@ func (p *prog) registerControlServerHandler() { return } } + + // 2. Service config changes. + if !reflect.DeepEqual(oldSvc, p.cfg.Service) { + w.WriteHeader(http.StatusCreated) + return + } + + // Otherwise, reload is done. 
w.WriteHeader(http.StatusOK) })) } diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 666bf50..de8aef7 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -4,7 +4,6 @@ import ( "context" "crypto/rand" "encoding/hex" - "errors" "fmt" "net" "net/netip" @@ -44,19 +43,14 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } -var errReload = errors.New("reload") - -func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) error { +func (p *prog) serveDNS(listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") return allocErr } - var failoverRcodes []int - if listenerConfig.Policy != nil { - failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers - } + handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { p.sema.acquire() defer p.sema.release() @@ -83,7 +77,7 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) answer = new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) } else { - answer = p.proxy(ctx, upstreams, failoverRcodes, m, ci) + answer = p.proxy(ctx, upstreams, listenerConfig.Policy.FailoverRcodeNumbers, m, ci) rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) } @@ -93,12 +87,6 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) }) g, ctx := errgroup.WithContext(context.Background()) - // When receiving reload signal, return a non-nil error so other - // goroutines in errgroup.Group could be terminated. 
- g.Go(func() error { - <-reloadCh - return errReload - }) for _, proto := range []string{"udp", "tcp"} { proto := proto if needLocalIPv6Listener() { @@ -142,13 +130,11 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) s, errCh := runDNSServer(addr, proto, handler) defer s.Shutdown() - if !reload { - select { - case err := <-errCh: - return err - case <-time.After(5 * time.Second): - p.started <- struct{}{} - } + select { + case err := <-errCh: + return err + case <-time.After(5 * time.Second): + p.started <- struct{}{} } select { case <-p.stopCh: @@ -159,11 +145,7 @@ func (p *prog) serveDNS(listenerNum string, reload bool, reloadCh chan struct{}) return nil }) } - err := g.Wait() - if errors.Is(err, errReload) { // This is an error for trigger reload, not a real error. - return nil - } - return err + return g.Wait() } // upstreamFor returns the list of upstreams for resolving the given domain, diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index d29f374..be50ea6 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -135,16 +135,32 @@ func (p *prog) runWait() { waitOldRunDone() - _, ok := tryUpdateListenerConfig(newCfg, nil, false) - if !ok { - logger.Error().Msg("could not update listener config") - continue + p.mu.Lock() + curListener := p.cfg.Listener + p.mu.Unlock() + + for n, lc := range newCfg.Listener { + curLc := curListener[n] + if curLc == nil { + continue + } + if lc.IP == "" { + lc.IP = curLc.IP + } + if lc.Port == 0 { + lc.Port = curLc.Port + } } if err := validateConfig(newCfg); err != nil { logger.Err(err).Msg("invalid config") continue } + // This needs to be done here, otherwise, the DNS handler may observe an invalid + // upstream config because its initialization function have not been called yet. 
+ mainLog.Load().Debug().Msg("setup upstream with new config") + setupUpstream(newCfg) + p.mu.Lock() *p.cfg = *newCfg p.mu.Unlock() @@ -170,6 +186,21 @@ func (p *prog) preRun() { } } +func setupUpstream(cfg *ctrld.Config) { + for n := range cfg.Upstream { + uc := cfg.Upstream[n] + uc.Init() + if uc.BootstrapIP == "" { + uc.SetupBootstrapIP() + mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) + } else { + mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) + } + uc.SetCertPool(rootCertPool) + go uc.Ping() + } +} + // run runs the ctrld main components. // // The reload boolean indicates that the function is run when ctrld first start @@ -183,7 +214,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { <-p.waitCh p.preRun() numListeners := len(p.cfg.Listener) - p.started = make(chan struct{}, numListeners) + if !reload { + p.started = make(chan struct{}, numListeners) + } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) if p.cfg.Service.CacheEnable { @@ -194,15 +227,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.cache = cacher } } - p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} - if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil { - n := *mcr - if n == 0 { - p.sema = &noopSemaphore{} - } else { - p.sema = &chanSemaphore{ready: make(chan struct{}, n)} - } - } + var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) @@ -218,24 +243,24 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } p.um = newUpstreamMonitor(p.cfg) - for n := range p.cfg.Upstream { - uc := p.cfg.Upstream[n] - uc.Init() - if uc.BootstrapIP == "" { - uc.SetupBootstrapIP() - mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs()) - } else { - mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n) - } - uc.SetCertPool(rootCertPool) - go uc.Ping() - } - 
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID) - if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { - mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) - format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) - p.ciTable.AddLeaseFile(leaseFile, format) + if !reload { + p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} + if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil { + n := *mcr + if n == 0 { + p.sema = &noopSemaphore{} + } else { + p.sema = &chanSemaphore{ready: make(chan struct{}, n)} + } + } + setupUpstream(p.cfg) + p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID) + if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { + mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) + format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) + p.ciTable.AddLeaseFile(leaseFile, format) + } } // context for managing spawn goroutines. @@ -243,7 +268,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { defer cancelFunc() // Newer versions of android and iOS denies permission which breaks connectivity. 
- if !isMobile() { + if !isMobile() && !reload { wg.Add(1) go func() { defer wg.Done() @@ -255,22 +280,32 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { for listenerNum := range p.cfg.Listener { p.cfg.Listener[listenerNum].Init() - go func(listenerNum string) { + if !reload { + go func(listenerNum string) { + listenerConfig := p.cfg.Listener[listenerNum] + upstreamConfig := p.cfg.Upstream[listenerNum] + if upstreamConfig == nil { + mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + } + addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) + mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) + if err := p.serveDNS(listenerNum); err != nil { + mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) + } + }(listenerNum) + } + go func() { defer func() { cancelFunc() wg.Done() }() - listenerConfig := p.cfg.Listener[listenerNum] - upstreamConfig := p.cfg.Upstream[listenerNum] - if upstreamConfig == nil { - mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum) + select { + case <-p.stopCh: + case <-ctx.Done(): + case <-reloadCh: } - addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) - mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr) - if err := p.serveDNS(listenerNum, reload, reloadCh); err != nil { - mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum) - } - }(listenerNum) + return + }() } if !reload { From 180eae60f2bdbd282c3e4b28ba4e0ec80ab32529 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 14 Nov 2023 00:28:15 +0700 Subject: [PATCH 23/52] all: allowing config defined discover ptr endpoints The default gateway is usually the DNS server in normal home network setup for most users. However, there's case that it is not, causing discover ptr failed. 
This commit add discover_ptr_endpoints config parameter, so users can define what DNS nameservers will be used. --- config.go | 33 +++++++++++++++--------------- docs/config.md | 16 +++++++++++++++ internal/clientinfo/client_info.go | 21 +++++++++++++++++++ resolver.go | 19 ++++++++++++----- 4 files changed, 68 insertions(+), 21 deletions(-) diff --git a/config.go b/config.go index 489a7fd..97a837e 100644 --- a/config.go +++ b/config.go @@ -167,22 +167,23 @@ func (c *Config) FirstUpstream() *UpstreamConfig { // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` - LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` - CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` - CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` - CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` - CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` - MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` - DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` - DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` - DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` - DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"` - DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` - DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` - DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` - Daemon bool `mapstructure:"-" toml:"-"` - AllocateIP bool `mapstructure:"-" 
toml:"-"` + LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` + LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` + MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` + DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` + DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` + DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` + DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"` + DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` + DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` + DiscoverPtrEndpoints []string `mapstructure:"discover_ptr_endpoints" toml:"discover_ptr_endpoints,omitempty"` + DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` + Daemon bool `mapstructure:"-" toml:"-"` + AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. diff --git a/docs/config.md b/docs/config.md index 29dff8d..57a794b 100644 --- a/docs/config.md +++ b/docs/config.md @@ -193,6 +193,22 @@ Perform LAN client discovery using PTR queries. - Required: no - Default: true +### discover_ptr_endpoints +List of DNS nameservers used for PTR discovery. + +Each entry can be either "ip" (default port 53) or "ip:port" pair. 
Invalid entry will be ignored. + +- Type: array of string +- Required: no +- Default: [] + +Example: + +```toml +[service] +discover_ptr_endpoints = ["192.168.1.1", "192.168.2.1:5354"] +``` + ### discover_hosts Perform LAN client discovery using hosts file. diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index f591174..ee1a14f 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -3,7 +3,9 @@ package clientinfo import ( "context" "fmt" + "net" "net/netip" + "strconv" "strings" "sync" "time" @@ -183,6 +185,25 @@ func (t *Table) init() { // PTR lookup. if t.discoverPTR() { t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} + if len(t.svcCfg.DiscoverPtrEndpoints) > 0 { + nss := make([]string, 0, len(t.svcCfg.DiscoverPtrEndpoints)) + for _, ns := range t.svcCfg.DiscoverPtrEndpoints { + host, port := ns, "53" + if h, p, err := net.SplitHostPort(ns); err == nil { + host, port = h, p + } + // Only use valid ip:port pair. + if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { + nss = append(nss, net.JoinHostPort(host, port)) + } else { + ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) + } + } + if len(nss) > 0 { + t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) + ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) + } + } ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") if err := t.ptr.refresh(); err != nil { ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover") diff --git a/resolver.go b/resolver.go index f08263b..8fd26cf 100644 --- a/resolver.go +++ b/resolver.go @@ -78,8 +78,9 @@ type osResolverResult struct { err error } -// Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from -// available nameservers with a roundrobin algorithm. 
+// Resolve resolves DNS queries using pre-configured nameservers. +// Query is sent to all nameservers concurrently, and the first +// success response will be returned. func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { numServers := len(o.nameservers) if numServers == 0 { @@ -269,11 +270,19 @@ func NewPrivateResolver() Resolver { } } nss = nss[:n] - if len(nss) == 0 { + return NewResolverWithNameserver(nss) +} + +// NewResolverWithNameserver returns an OS resolver which uses the given nameservers +// for resolving DNS queries. If nameservers is empty, a dummy resolver will be returned. +// +// Each nameserver must be form "host:port". It's the caller responsibility to ensure all +// nameservers are well formatted by using net.JoinHostPort function. +func NewResolverWithNameserver(nameservers []string) Resolver { + if len(nameservers) == 0 { return &dummyResolver{} } - resolver := &osResolver{nameservers: nss} - return resolver + return &osResolver{nameservers: nameservers} } func newDialer(dnsAddress string) *net.Dialer { From 91d319804b161a079ad62d9643f9385f2bacaf0a Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 14 Nov 2023 09:28:10 +0700 Subject: [PATCH 24/52] cmd/cli: only use failover rcodes if defined --- cmd/cli/dns_proxy.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index de8aef7..69d94f3 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -77,7 +77,11 @@ func (p *prog) serveDNS(listenerNum string) error { answer = new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) } else { - answer = p.proxy(ctx, upstreams, listenerConfig.Policy.FailoverRcodeNumbers, m, ci) + var failoverRcode []int + if listenerConfig.Policy != nil { + failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers + } + answer = p.proxy(ctx, upstreams, failoverRcode, m, ci) rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in 
%s", answer.Len(), rtt) } From cd9c75088415d9515e9d40816a94567ce42b4469 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 14 Nov 2023 10:08:13 +0700 Subject: [PATCH 25/52] cmd/cli: do not run pre run on reload --- cmd/cli/prog.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index be50ea6..867c08a 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -212,7 +212,9 @@ func setupUpstream(cfg *ctrld.Config) { func (p *prog) run(reload bool, reloadCh chan struct{}) { // Wait the caller to signal that we can do our logic. <-p.waitCh - p.preRun() + if !reload { + p.preRun() + } numListeners := len(p.cfg.Listener) if !reload { p.started = make(chan struct{}, numListeners) From 494d8be77728adf83d95f3e732f66321cb128803 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 14 Nov 2023 21:46:23 +0700 Subject: [PATCH 26/52] cmd/cli: skip router setup with "ctrld service start" Either do magic stuff and make things work automatically (normal users), or don't do any of it and just run ctrld as a service (power users). 
--- cmd/cli/cli.go | 44 ++++++++++++++++++++++++++++++-------------- cmd/cli/main.go | 42 ++++++++++++++++++++++-------------------- 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index e91115a..26f228b 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -144,6 +144,8 @@ func initCLI() { _ = runCmd.Flags().MarkHidden("homedir") runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) _ = runCmd.Flags().MarkHidden("iface") + runCmd.Flags().BoolVarP(&setupRouter, "router", "", false, "Do setup router") + _ = runCmd.Flags().MarkHidden("router") runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) rootCmd.AddCommand(runCmd) @@ -253,7 +255,7 @@ func initCLI() { return } - if router.Name() != "" { + if router.Name() != "" && setupRouter { mainLog.Load().Debug().Msg("cleaning up router before installing") _ = p.router.Cleanup() } @@ -307,6 +309,8 @@ func initCLI() { startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) + startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, "Do router setup") + _ = startCmd.Flags().MarkHidden("router") routerCmd := &cobra.Command{ Use: "setup", @@ -591,11 +595,16 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if !cmd.Flags().Changed("iface") { os.Args = append(os.Args, "--iface="+ifaceStartStop) } + if !cmd.Flags().Changed("router") { + os.Args = append(os.Args, fmt.Sprintf("--router=%v", setupRouterStartStop)) + } iface = ifaceStartStop + setupRouter = setupRouterStartStop startCmd.Run(cmd, args) }, } 
startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmdAlias.Flags().BoolVarP(&setupRouterStartStop, "router", "", true, "Do router setup") startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) rootCmd.AddCommand(startCmdAlias) stopCmdAlias := &cobra.Command{ @@ -609,11 +618,16 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if !cmd.Flags().Changed("iface") { os.Args = append(os.Args, "--iface="+ifaceStartStop) } + if !cmd.Flags().Changed("router") { + os.Args = append(os.Args, fmt.Sprintf("--router=%v", setupRouterStartStop)) + } iface = ifaceStartStop + setupRouter = setupRouterStartStop stopCmd.Run(cmd, args) }, } stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) + stopCmdAlias.Flags().BoolVarP(&setupRouterStartStop, "router", "", true, "Do router setup") stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) rootCmd.AddCommand(stopCmdAlias) @@ -974,19 +988,21 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { if cp := router.CertPool(); cp != nil { rootCertPool = cp } - p.onStarted = append(p.onStarted, func() { - mainLog.Load().Debug().Msg("router setup on start") - if err := p.router.Setup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not configure router") - } - }) - p.onStopped = append(p.onStopped, func() { - mainLog.Load().Debug().Msg("router cleanup on stop") - if err := p.router.Cleanup(); err != nil { - mainLog.Load().Error().Err(err).Msg("could not cleanup router") - } - p.resetDNS() - }) + if setupRouter { + p.onStarted = append(p.onStarted, func() { + mainLog.Load().Debug().Msg("router setup on start") + if err := p.router.Setup(); err != nil { + mainLog.Load().Error().Err(err).Msg("could not configure router") + } + }) + p.onStopped = append(p.onStopped, func() { + mainLog.Load().Debug().Msg("router cleanup on stop") + if 
err := p.router.Cleanup(); err != nil { + mainLog.Load().Error().Err(err).Msg("could not cleanup router") + } + p.resetDNS() + }) + } } close(waitCh) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index bf65044..b1287ae 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -14,26 +14,28 @@ import ( ) var ( - configPath string - configBase64 string - daemon bool - listenAddress string - primaryUpstream string - secondaryUpstream string - domains []string - logPath string - homedir string - cacheSize int - cfg ctrld.Config - verbose int - silent bool - cdUID string - cdOrg string - cdDev bool - iface string - ifaceStartStop string - nextdns string - cdUpstreamProto string + configPath string + configBase64 string + daemon bool + listenAddress string + primaryUpstream string + secondaryUpstream string + domains []string + logPath string + homedir string + cacheSize int + cfg ctrld.Config + verbose int + silent bool + cdUID string + cdOrg string + cdDev bool + iface string + ifaceStartStop string + nextdns string + cdUpstreamProto string + setupRouter bool + setupRouterStartStop bool mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter From 4f125cf107b1b1454585cd5ed4289052e41db213 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 15 Nov 2023 14:58:36 +0700 Subject: [PATCH 27/52] cmd/cli: notice users where config file is written/read --- cmd/cli/cli.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 26f228b..5215b14 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -266,6 +266,9 @@ func initCLI() { {s.Uninstall, false}, {s.Install, false}, {s.Start, true}, + // Note that startCmd do not actually write ControlD config, but the config file was + // generated after s.Start, so we notice users here for consistent with nextdns mode. 
+ {noticeWritingControlDConfig, false}, } mainLog.Load().Notice().Msg("Starting service") if doTasks(tasks) { @@ -1042,6 +1045,7 @@ func readConfigFile(writeDefaultConfig bool) bool { // If err == nil, there's a config supplied via `--config`, no default config written. err := v.ReadInConfig() if err == nil { + mainLog.Load().Notice().Msg("Reading config: " + v.ConfigFileUsed()) mainLog.Load().Info().Msg("loading config file from: " + v.ConfigFileUsed()) defaultConfigFile = v.ConfigFileUsed() return true @@ -1984,3 +1988,10 @@ func doGenerateNextDNSConfig(uid string) error { updateListenerConfig(&cfg) return writeConfigFile() } + +func noticeWritingControlDConfig() error { + if cdUID != "" { + mainLog.Load().Notice().Msgf("Generating controld config: %s", defaultConfigFile) + } + return nil +} From 0a30fdea69bf60622ef59981a302c011901c37a7 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 15 Nov 2023 15:07:13 +0700 Subject: [PATCH 28/52] Add listener policy to default generated config So technical users can figure things out based on self-documented commands, without referring to the actual documentation. 
--- config.go | 10 ++++++++++ config_test.go | 7 ++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 97a837e..051fb80 100644 --- a/config.go +++ b/config.go @@ -82,6 +82,16 @@ func InitConfig(v *viper.Viper, name string) { "0": { IP: "", Port: 0, + Policy: &ListenerPolicyConfig{ + Name: "Main Policy", + Networks: []Rule{ + {"network.0": []string{"upstream.0"}}, + }, + Rules: []Rule{ + {"example.com": []string{"upstream.0"}}, + {"*.ads.com": []string{"upstream.1"}}, + }, + }, }, }) v.SetDefault("network", map[string]*NetworkConfig{ diff --git a/config_test.go b/config_test.go index ca57372..4123f00 100644 --- a/config_test.go +++ b/config_test.go @@ -54,7 +54,12 @@ func TestLoadDefaultConfig(t *testing.T) { cfg := defaultConfig(t) validate := validator.New() require.NoError(t, ctrld.ValidateConfig(validate, cfg)) - assert.Len(t, cfg.Listener, 1) + if assert.Len(t, cfg.Listener, 1) { + l0 := cfg.Listener["0"] + require.NotNil(t, l0.Policy) + assert.Len(t, l0.Policy.Networks, 1) + assert.Len(t, l0.Policy.Rules, 2) + } assert.Len(t, cfg.Upstream, 2) } From 856abb71b72448bbe5d715f411c2ac7eeddcfa65 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 16 Nov 2023 18:22:38 +0700 Subject: [PATCH 29/52] cmd/cli: only notice reading config with "ctrld start" While at it, also update the documentation of related functions. 
--- cmd/cli/cli.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 5215b14..8c2c3d4 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -230,7 +230,7 @@ func initCLI() { }() } - tryReadingConfig(writeDefaultConfig) + tryReadingConfigWithNotice(writeDefaultConfig, true) if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) @@ -1041,11 +1041,17 @@ func writeConfigFile() error { return nil } -func readConfigFile(writeDefaultConfig bool) bool { +// readConfigFile reads in config file. +// +// - It writes default config file if config file not found if writeDefaultConfig is true. +// - It emits notice message to user if notice is true. +func readConfigFile(writeDefaultConfig, notice bool) bool { // If err == nil, there's a config supplied via `--config`, no default config written. err := v.ReadInConfig() if err == nil { - mainLog.Load().Notice().Msg("Reading config: " + v.ConfigFileUsed()) + if notice { + mainLog.Load().Notice().Msg("Reading config: " + v.ConfigFileUsed()) + } mainLog.Load().Info().Msg("loading config file from: " + v.ConfigFileUsed()) defaultConfigFile = v.ConfigFileUsed() return true @@ -1452,21 +1458,35 @@ func userHomeDir() (string, error) { return dir, nil } +// tryReadingConfig is like tryReadingConfigWithNotice, with notice set to false. func tryReadingConfig(writeDefaultConfig bool) { + tryReadingConfigWithNotice(writeDefaultConfig, false) +} + +// tryReadingConfigWithNotice tries reading in config files, either specified by user or from default +// locations. If notice is true, emitting a notice message to user which config file was read. +func tryReadingConfigWithNotice(writeDefaultConfig, notice bool) { // --config is specified. if configPath != "" { v.SetConfigFile(configPath) - readConfigFile(false) + readConfigFile(false, notice) return } // no config start or base64 config mode. 
if !writeDefaultConfig { return } - readConfig(writeDefaultConfig) + readConfigWithNotice(writeDefaultConfig, notice) } +// readConfig calls readConfigWithNotice with notice set to false. func readConfig(writeDefaultConfig bool) { + readConfigWithNotice(writeDefaultConfig, false) +} + +// readConfigWithNotice calls readConfigFile with config file set to ctrld.toml +// or config.toml for compatible with earlier versions of ctrld. +func readConfigWithNotice(writeDefaultConfig, notice bool) { configs := []struct { name string written bool @@ -1483,7 +1503,7 @@ func readConfig(writeDefaultConfig bool) { for _, config := range configs { ctrld.SetConfigNameWithPath(v, config.name, dir) v.SetConfigFile(configPath) - if readConfigFile(config.written) { + if readConfigFile(config.written, notice) { break } } From 564c9ef71293cc74abc80b75a97f0828321b6b00 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 16 Nov 2023 19:48:40 +0700 Subject: [PATCH 30/52] cmd/cli: use IP as hostname for ipv4 clients only For Android devices, when it joins the network, it uses ctrld to resolve its private DNS once and never reaches ctrld again. For each time, it uses a different IPv6 address, which causes hundreds/thousands different client IDs created for the same device, which is pointless. --- cmd/cli/dns_proxy.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 69d94f3..77ec44e 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -586,8 +586,17 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo { if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" { ci.Hostname = hostname } else { - ci.Hostname = ci.IP - p.ciTable.StoreVPNClient(ci) + // Only use IP as hostname for IPv4 clients. + // For Android devices, when it joins the network, it uses ctrld to resolve + // its private DNS once and never reaches ctrld again. 
For each time, it uses + // a different IPv6 address, which causes hundreds/thousands different client + // IDs created for the same device, which is pointless. + // + // TODO(cuonglm): investigate whether this can be a false positive for other clients? + if !ctrldnet.IsIPv6(ci.IP) { + ci.Hostname = ci.IP + p.ciTable.StoreVPNClient(ci) + } } } else { ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac) From a2116e5eb530f410519cdca2f0f434117a262527 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 17 Nov 2023 17:35:42 +0700 Subject: [PATCH 31/52] cmd/cli: do not substitute MAC if empty Using IPv4 as hostname is enough to distinguish clients. --- cmd/cli/dns_proxy.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 77ec44e..855e5d3 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -579,10 +579,8 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo { } // If MAC is still empty here, that mean the requests are made from virtual interface, - // like VPN/Wireguard clients, so we use whatever MAC address associated with remoteIP - // (most likely 127.0.0.1), and ci.IP as hostname, so we can distinguish those clients. + // like VPN/Wireguard clients, so we use ci.IP as hostname to distinguish those clients. 
if ci.Mac == "" { - ci.Mac = p.ciTable.LookupMac(remoteIP) if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" { ci.Hostname = hostname } else { From 9e6e647ff8ddfb70d3a5eca2320982d881841436 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Sat, 18 Nov 2023 06:11:27 +0700 Subject: [PATCH 32/52] Use discover_ptr_endpoints for PTR resolver --- cmd/cli/dns_proxy.go | 6 ++++++ cmd/cli/prog.go | 4 ++++ config.go | 23 +++++++++++++++++++++++ internal/clientinfo/client_info.go | 23 +++-------------------- 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 855e5d3..06d0702 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -285,6 +285,12 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) + if upstreamConfig.Type == ctrld.ResolverTypePrivate { + if r := p.ptrResolver; r != nil { + ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR resolver", p.cfg.Service.DiscoverPtrEndpoints) + dnsResolver = r + } + } if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") return nil, err diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 867c08a..fb88c81 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -62,6 +62,7 @@ type prog struct { ciTable *clientinfo.Table um *upstreamMonitor router router.Router + ptrResolver ctrld.Resolver loopMu sync.Mutex loop map[string]bool @@ -229,6 +230,9 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.cache = cacher } } + if r := p.cfg.Service.PtrResolver(); r != nil { + p.ptrResolver = r + } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) diff --git a/config.go b/config.go index 051fb80..8b37078 100644 --- 
a/config.go +++ b/config.go @@ -196,6 +196,29 @@ type ServiceConfig struct { AllocateIP bool `mapstructure:"-" toml:"-"` } +// PtrResolver returns a Resolver used for PTR lookup, based on ServiceConfig.DiscoverPtrEndpoints value. +func (s ServiceConfig) PtrResolver() Resolver { + if len(s.DiscoverPtrEndpoints) > 0 { + nss := make([]string, 0, len(s.DiscoverPtrEndpoints)) + for _, ns := range s.DiscoverPtrEndpoints { + host, port := ns, "53" + if h, p, err := net.SplitHostPort(ns); err == nil { + host, port = h, p + } + // Only use valid ip:port pair. + if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { + nss = append(nss, net.JoinHostPort(host, port)) + } else { + ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for PTR resolver: %q", ns) + } + } + if len(nss) > 0 { + return NewResolverWithNameserver(nss) + } + } + return nil +} + // NetworkConfig specifies configuration for networks where ctrld will handle requests. type NetworkConfig struct { Name string `mapstructure:"name" toml:"name,omitempty"` diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index ee1a14f..0e60643 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -3,9 +3,7 @@ package clientinfo import ( "context" "fmt" - "net" "net/netip" - "strconv" "strings" "sync" "time" @@ -185,24 +183,9 @@ func (t *Table) init() { // PTR lookup. if t.discoverPTR() { t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} - if len(t.svcCfg.DiscoverPtrEndpoints) > 0 { - nss := make([]string, 0, len(t.svcCfg.DiscoverPtrEndpoints)) - for _, ns := range t.svcCfg.DiscoverPtrEndpoints { - host, port := ns, "53" - if h, p, err := net.SplitHostPort(ns); err == nil { - host, port = h, p - } - // Only use valid ip:port pair. 
- if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { - nss = append(nss, net.JoinHostPort(host, port)) - } else { - ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) - } - } - if len(nss) > 0 { - t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) - } + if r := t.svcCfg.PtrResolver(); r != nil { + ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR discover", t.svcCfg.DiscoverPtrEndpoints) + t.ptr.resolver = r } ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") if err := t.ptr.refresh(); err != nil { From 17f6d7a77b0b63f4ffed2737b3566919cf7b4ef5 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Sat, 18 Nov 2023 06:19:16 +0700 Subject: [PATCH 33/52] cmd/cli: notice writing default config in local mode --- cmd/cli/cli.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 8c2c3d4..8b2999f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1075,6 +1075,9 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { if err != nil { mainLog.Load().Fatal().Msgf("failed to get default config file path: %v", err) } + if cdUID == "" && nextdns == "" { + mainLog.Load().Notice().Msg("Generating controld default config: " + fp) + } mainLog.Load().Info().Msg("writing default config file to: " + fp) } return false From 28ec1869fc97743809184d6ec5fd7bcb3ef41f7a Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 21 Nov 2023 21:34:50 +0700 Subject: [PATCH 34/52] internal/router/merlin: hardening pre-run condition The postconf script added by ctrld requires all of these conditions to work correctly: - /proc, /tmp were mounted. - dnsmasq is running. Currently, ctrld is only waiting for NTP ready, which may not ensure both of those conditions are true. Explicitly checking those conditions is a safer approach. 
--- internal/router/merlin/merlin.go | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/internal/router/merlin/merlin.go b/internal/router/merlin/merlin.go index 84ebd1c..8b6a0fc 100644 --- a/internal/router/merlin/merlin.go +++ b/internal/router/merlin/merlin.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "strings" + "time" "unicode" "github.com/kardianos/service" @@ -44,8 +45,24 @@ func (m *Merlin) Uninstall(_ *service.Config) error { } func (m *Merlin) PreRun() error { + // Wait NTP ready. _ = m.Cleanup() - return ntp.WaitNvram() + if err := ntp.WaitNvram(); err != nil { + return err + } + // Wait until directories mounted. + for _, dir := range []string{"/tmp", "/proc"} { + waitDirExists(dir) + } + // Wait dnsmasq started. + for { + out, _ := exec.Command("pidof", "dnsmasq").CombinedOutput() + if len(bytes.TrimSpace(out)) > 0 { + break + } + time.Sleep(time.Second) + } + return nil } func (m *Merlin) Setup() error { @@ -56,9 +73,6 @@ func (m *Merlin) Setup() error { if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" { return nil } - if _, err := nvram.Run("set", nvram.CtrldSetupKey+"=1"); err != nil { - return err - } buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath) // Already setup. if bytes.Contains(buf, []byte(dnsmasq.MerlinPostConfMarker)) { @@ -140,3 +154,12 @@ func merlinParsePostConf(buf []byte) []byte { } return buf } + +func waitDirExists(dir string) { + for { + if _, err := os.Stat(dir); !os.IsNotExist(err) { + return + } + time.Sleep(time.Second) + } +} From 2bebe93e47396f7c7bcb414d3c660b0fd4204c89 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 22 Nov 2023 15:35:40 +0700 Subject: [PATCH 35/52] internal/router: do not disable cache on EdgeOS The dnsmasq cache-size setting on EdgeOS could be re-generated anytime by vyatta router/dhcp components. This conflicts with setting generated by ctrld, causing dnsmasq fails to start. 
It's better to keep dnsmasq cache enabled on EdgeOS, we can turn it off again once we find a reliable way to control cache-size setting. --- internal/router/dnsmasq/dnsmasq.go | 14 +++++++++++--- internal/router/edgeos/edgeos.go | 30 +++++++++--------------------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/internal/router/dnsmasq/dnsmasq.go b/internal/router/dnsmasq/dnsmasq.go index 52cf1a6..9fab8b6 100644 --- a/internal/router/dnsmasq/dnsmasq.go +++ b/internal/router/dnsmasq/dnsmasq.go @@ -22,7 +22,9 @@ server={{ .IP }}#{{ .Port }} add-mac add-subnet=32,128 {{- end}} +{{- if .CacheDisabled}} cache-size=0 +{{- end}} ` const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf" @@ -71,6 +73,10 @@ type Upstream struct { } func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) { + return ConfTmplWitchCacheDisabled(tmplText, cfg, true) +} + +func ConfTmplWitchCacheDisabled(tmplText string, cfg *ctrld.Config, cacheDisabled bool) (string, error) { listener := cfg.FirstListener() if listener == nil { return "", errors.New("missing listener") @@ -80,24 +86,26 @@ func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) { ip = "127.0.0.1" } upstreams := []Upstream{{IP: ip, Port: listener.Port}} - return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo()) + return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo(), cacheDisabled) } func FirewallaConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) { if lc := cfg.FirstListener(); lc != nil && (lc.IP == "0.0.0.0" || lc.IP == "") { - return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo()) + return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo(), true) } return ConfTmpl(tmplText, cfg) } -func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo bool) (string, error) { +func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo, cacheDisabled bool) (string, error) { tmpl := 
template.Must(template.New("").Parse(tmplText)) var to = &struct { SendClientInfo bool Upstreams []Upstream + CacheDisabled bool }{ SendClientInfo: sendClientInfo, Upstreams: upstreams, + CacheDisabled: cacheDisabled, } var sb strings.Builder if err := tmpl.Execute(&sb, to); err != nil { diff --git a/internal/router/edgeos/edgeos.go b/internal/router/edgeos/edgeos.go index 0552882..3e7003b 100644 --- a/internal/router/edgeos/edgeos.go +++ b/internal/router/edgeos/edgeos.go @@ -8,19 +8,18 @@ import ( "os/exec" "strings" - "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" + "github.com/kardianos/service" "github.com/Control-D-Inc/ctrld" - "github.com/kardianos/service" + "github.com/Control-D-Inc/ctrld/internal/router/dnsmasq" ) const ( - Name = "edgeos" - edgeOSDNSMasqDefaultConfigPath = "/etc/dnsmasq.conf" - edgeOSDNSMasqConfigPath = "/etc/dnsmasq.d/dnsmasq-zzz-ctrld.conf" - usgDNSMasqConfigPath = "/etc/dnsmasq.conf" - usgDNSMasqBackupConfigPath = "/etc/dnsmasq.conf.bak" - toggleContentFilteringLink = "https://community.ui.com/questions/UDM-Pro-disable-enable-DNS-filtering/e2cc4060-e56a-4139-b200-62d7f773ff8f" + Name = "edgeos" + edgeOSDNSMasqConfigPath = "/etc/dnsmasq.d/dnsmasq-zzz-ctrld.conf" + usgDNSMasqConfigPath = "/etc/dnsmasq.conf" + usgDNSMasqBackupConfigPath = "/etc/dnsmasq.conf.bak" + toggleContentFilteringLink = "https://community.ui.com/questions/UDM-Pro-disable-enable-DNS-filtering/e2cc4060-e56a-4139-b200-62d7f773ff8f" ) var ErrContentFilteringEnabled = fmt.Errorf(`the "Content Filtering" feature" is enabled, which is conflicted with ctrld.\n @@ -107,13 +106,10 @@ func (e *EdgeOS) setupUSG() error { if strings.HasPrefix(line, "all-servers") { continue } - if strings.HasPrefix(line, "cache-size") { - continue - } sb.WriteString(line) } - data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg) + data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false) if err != nil { return err } @@ -131,11 +127,7 @@ func (e 
*EdgeOS) setupUSG() error { } func (e *EdgeOS) setupUDM() error { - // Disable dnsmasq cache. - if err := dnsmasq.DisableCache(edgeOSDNSMasqDefaultConfigPath); err != nil { - return err - } - data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg) + data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false) if err != nil { return err } @@ -161,10 +153,6 @@ func (e *EdgeOS) cleanupUSG() error { } func (e *EdgeOS) cleanupUDM() error { - // Enable dnsmasq cache. - if err := dnsmasq.EnableCache(edgeOSDNSMasqDefaultConfigPath); err != nil { - return err - } // Remove the custom dnsmasq config if err := os.Remove(edgeOSDNSMasqConfigPath); err != nil { return fmt.Errorf("cleanupUDM: os.Remove: %w", err) From a2cb895cdc5436577b30dc40bbcb7c607a24db12 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 22 Nov 2023 19:07:53 +0700 Subject: [PATCH 36/52] cmd/cli: watch changes to /etc/resolv.conf On some routers, change to network may trigger re-rendering /etc/resolv.conf file, causing requests from router itself stop using ctrld. Fixing this by watching changes to /etc/resolv.conf, then revert them. 
--- cmd/cli/os_linux.go | 68 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/cmd/cli/os_linux.go b/cmd/cli/os_linux.go index 7fb692c..3036d03 100644 --- a/cmd/cli/os_linux.go +++ b/cmd/cli/os_linux.go @@ -9,10 +9,12 @@ import ( "net" "net/netip" "os/exec" + "path/filepath" "strings" "syscall" "time" + "github.com/fsnotify/fsnotify" "github.com/insomniacslk/dhcp/dhcpv4/nclient4" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/dhcpv6/client6" @@ -23,7 +25,10 @@ import ( "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) -const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system" +const ( + resolvConfPath = "/etc/resolv.conf" + resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system" +) // allocate loopback ip // sudo ip a add 127.0.0.2/24 dev lo @@ -64,6 +69,11 @@ func setDNS(iface *net.Interface, nameservers []string) error { Nameservers: ns, SearchDomains: []dnsname.FQDN{}, } + defer func() { + if r.Mode() == "direct" { + go watchResolveConf(osConfig) + } + }() trySystemdResolve := false for i := 0; i < maxSetDNSAttempts; i++ { @@ -299,3 +309,59 @@ func sliceIndex[S ~[]E, E comparable](s S, v E) int { } return -1 } + +// watchResolveConf watches for changes to the /etc/resolv.conf file, +// and reverts them to the original config set by ctrld.
+func watchResolveConf(oc dns.OSConfig) { + mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file") + watcher, err := fsnotify.NewWatcher() + if err != nil { + mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf") + return + } + + // We watch /etc instead of /etc/resolv.conf directly, + // see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well + watchDir := filepath.Dir(resolvConfPath) + if err := watcher.Add(watchDir); err != nil { + mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list") + return + } + + r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter. + if err != nil { + mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator") + return + } + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Name != resolvConfPath { // skip if not /etc/resolv.conf changes. + continue + } + if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { + mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting") + if err := watcher.Remove(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to pause watcher") + continue + } + if err := r.SetDNS(oc); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes") + } + if err := watcher.Add(watchDir); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to continue running watcher") + return + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf") + } + } +} From f9a3f4c045d3cfaaa5cd51bf3053eab184a8f393 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 23 Nov 2023 23:56:49 +0700 Subject: [PATCH 37/52] Implement new flow for LAN and private PTR resolution - Use client info table. - If no sufficient data, use gateway/os/defined local upstreams. 
- If no data is returned, use remote upstream --- cmd/cli/dns_proxy.go | 134 +++++++++++++++++++++++++---- cmd/cli/dns_proxy_test.go | 4 +- cmd/cli/prog.go | 49 +++++++---- config.go | 76 ++++++++-------- config_internal_test.go | 55 ++++++++++++ docs/config.md | 34 ++++---- internal/clientinfo/client_info.go | 79 +++++++++++++---- internal/clientinfo/dhcp.go | 17 ++++ internal/clientinfo/hostsfile.go | 19 ++++ internal/clientinfo/mdns.go | 18 ++++ internal/clientinfo/ptr_lookup.go | 18 ++++ 11 files changed, 396 insertions(+), 107 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 06d0702..1be818f 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -17,6 +17,7 @@ import ( "golang.org/x/sync/errgroup" "tailscale.com/net/interfaces" "tailscale.com/net/netaddr" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/dnscache" @@ -25,6 +26,7 @@ import ( const ( staleTTL = 60 * time.Second + localTTL = 3600 * time.Second // EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option. // https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81 // This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification. 
@@ -81,7 +83,7 @@ func (p *prog) serveDNS(listenerNum string) error { if listenerConfig.Policy != nil { failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers } - answer = p.proxy(ctx, upstreams, failoverRcode, m, ci) + answer = p.proxy(ctx, upstreams, failoverRcode, m, ci, matched) rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) } @@ -251,7 +253,7 @@ macRules: return upstreams, matched } -func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo) *dns.Msg { +func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo, matched bool) *dns.Msg { var staleAnswer *dns.Msg serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) @@ -259,11 +261,84 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} upstreams = []string{upstreamOS} } - if isPrivatePtrLookup(msg) { - ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup -> [%s]", upstreamOS) - upstreamConfigs = []*ctrld.UpstreamConfig{privateUpstreamConfig} - upstreams = []string{upstreamOS} + + // LAN/PTR lookup flow: + // + // 1. If there's matching rule, follow it. + // 2. Try from client info table. + // 3. Try private resolver. + // 4. Try remote upstream. 
+ isLanOrPtrQuery := false + if !matched { + switch { + case isPrivatePtrLookup(msg): + isLanOrPtrQuery = true + ip := ipFromARPA(msg.Question[0].Name) + if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" { + answer := new(dns.Msg) + answer.SetReply(msg) + answer.Compress = true + answer.Answer = []dns.RR{&dns.PTR{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + }, + Ptr: dns.Fqdn(name), + }} + ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table") + ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + Mac: p.ciTable.LookupMac(ip.String()), + IP: ip.String(), + Hostname: name, + }) + return answer + } + upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams) + case isLanHostnameQuery(msg): + isLanOrPtrQuery = true + q := msg.Question[0] + hostname := strings.TrimSuffix(q.Name, ".") + if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil { + answer := new(dns.Msg) + answer.SetReply(msg) + answer.Compress = true + switch { + case ip.Is4(): + answer.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: uint32(localTTL.Seconds()), + }, + A: ip.AsSlice(), + }} + case ip.Is6(): + answer.Answer = []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: uint32(localTTL.Seconds()), + }, + AAAA: ip.AsSlice(), + }} + } + ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table") + ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + Mac: p.ciTable.LookupMac(ip.String()), + IP: ip.String(), + Hostname: hostname, + }) + return answer + } + upstreams, upstreamConfigs = 
p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) + ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams) + } } + // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { for _, upstream := range upstreams { @@ -285,12 +360,6 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) { ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) - if upstreamConfig.Type == ctrld.ResolverTypePrivate { - if r := p.ptrResolver; r != nil { - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR resolver", p.cfg.Service.DiscoverPtrEndpoints) - dnsResolver = r - } - } if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver") return nil, err @@ -344,6 +413,11 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i } continue } + // We are doing LAN/PTR lookup using private resolver, so always process next one. + // Except for the last, we want to send response instead of saying all upstream failed. 
+ if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 { + continue + } if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) { ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream") continue @@ -352,7 +426,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i // set compression, as it is not set by default when unpacking answer.Compress = true - if p.cache != nil { + if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { ttl := ttlFromMsg(answer) now := time.Now() expired := now.Add(time.Duration(ttl) * time.Second) @@ -371,6 +445,16 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i return answer } +func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) { + if len(p.localUpstreams) > 0 { + tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams)) + tmp = append(tmp, p.localUpstreams...) + tmp = append(tmp, upstreams...) + return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp) + } + return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...) +} + func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) for _, upstream := range upstreams { @@ -705,14 +789,34 @@ func ipFromARPA(arpa string) net.IP { return nil } -// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN network. +// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network. 
func isPrivatePtrLookup(m *dns.Msg) bool { if m == nil || len(m.Question) == 0 { return false } q := m.Question[0] if ip := ipFromARPA(q.Name); ip != nil { - return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() + if addr, ok := netip.AddrFromSlice(ip); ok { + return addr.IsPrivate() || + addr.IsLoopback() || + addr.IsLinkLocalUnicast() || + tsaddr.CGNATRange().Contains(addr) + } } return false } + +func isLanHostnameQuery(m *dns.Msg) bool { + if m == nil || len(m.Question) == 0 { + return false + } + q := m.Question[0] + switch q.Qtype { + case dns.TypeA, dns.TypeAAAA: + default: + return false + } + return !strings.Contains(q.Name, ".") || + strings.HasSuffix(q.Name, ".domain") || + strings.HasSuffix(q.Name, ".lan") +} diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index d0e5c74..70197ad 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -153,8 +153,8 @@ func TestCache(t *testing.T) { answer2.SetRcode(msg, dns.RcodeRefused) prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute))) - got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil) - got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil) + got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil, false) + got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil, false) assert.NotSame(t, got1, got2) assert.Equal(t, answer1.Rcode, got1.Rcode) assert.Equal(t, answer2.Rcode, got2.Rcode) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index fb88c81..f828426 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -19,6 +19,7 @@ import ( "github.com/kardianos/service" "github.com/spf13/viper" "tailscale.com/net/interfaces" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/clientinfo" @@ -32,6 +33,7 @@ const ( ctrldControlUnixSock = "ctrld_control.sock" upstreamPrefix = 
"upstream." upstreamOS = upstreamPrefix + "os" + upstreamPrivate = upstreamPrefix + "private" ) var logf = func(format string, args ...any) { @@ -55,14 +57,15 @@ type prog struct { logConn net.Conn cs *controlServer - cfg *ctrld.Config - appCallback *AppCallback - cache dnscache.Cacher - sema semaphore - ciTable *clientinfo.Table - um *upstreamMonitor - router router.Router - ptrResolver ctrld.Resolver + cfg *ctrld.Config + localUpstreams []string + ptrNameservers []string + appCallback *AppCallback + cache dnscache.Cacher + sema semaphore + ciTable *clientinfo.Table + um *upstreamMonitor + router router.Router loopMu sync.Mutex loop map[string]bool @@ -160,7 +163,7 @@ func (p *prog) runWait() { // This needs to be done here, otherwise, the DNS handler may observe an invalid // upstream config because its initialization function have not been called yet. mainLog.Load().Debug().Msg("setup upstream with new config") - setupUpstream(newCfg) + p.setupUpstream(newCfg) p.mu.Lock() *p.cfg = *newCfg @@ -187,7 +190,9 @@ func (p *prog) preRun() { } } -func setupUpstream(cfg *ctrld.Config) { +func (p *prog) setupUpstream(cfg *ctrld.Config) { + localUpstreams := make([]string, 0, len(cfg.Upstream)) + ptrNameservers := make([]string, 0, len(cfg.Upstream)) for n := range cfg.Upstream { uc := cfg.Upstream[n] uc.Init() @@ -199,7 +204,16 @@ func setupUpstream(cfg *ctrld.Config) { } uc.SetCertPool(rootCertPool) go uc.Ping() + + if canBeLocalUpstream(uc.Domain) { + localUpstreams = append(localUpstreams, upstreamPrefix+n) + } + if uc.IsDiscoverable() { + ptrNameservers = append(ptrNameservers, uc.Endpoint) + } } + p.localUpstreams = localUpstreams + p.ptrNameservers = ptrNameservers } // run runs the ctrld main components. 
@@ -230,9 +244,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.cache = cacher } } - if r := p.cfg.Service.PtrResolver(); r != nil { - p.ptrResolver = r - } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) @@ -260,8 +271,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { p.sema = &chanSemaphore{ready: make(chan struct{}, n)} } } - setupUpstream(p.cfg) - p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID) + p.setupUpstream(p.cfg) + p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers) if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" { mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile) format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat) @@ -613,3 +624,11 @@ func defaultRouteIP() string { mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP") return ip } + +// canBeLocalUpstream reports whether the IP address can be used as a local upstream. +func canBeLocalUpstream(addr string) bool { + if ip, err := netip.ParseAddr(addr); err == nil { + return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) + } + return false +} diff --git a/config.go b/config.go index 8b37078..1bd9043 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,7 @@ import ( "math/rand" "net" "net/http" + "net/netip" "net/url" "os" "runtime" @@ -26,6 +27,7 @@ import ( "github.com/spf13/viper" "golang.org/x/sync/singleflight" "tailscale.com/logtail/backoff" + "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld/internal/dnsrcode" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" @@ -177,46 +179,22 @@ func (c *Config) FirstUpstream() *UpstreamConfig { // ServiceConfig specifies the general ctrld config. 
type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` - LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` - CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` - CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` - CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` - CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` - MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` - DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` - DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` - DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` - DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"` - DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` - DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` - DiscoverPtrEndpoints []string `mapstructure:"discover_ptr_endpoints" toml:"discover_ptr_endpoints,omitempty"` - DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` - Daemon bool `mapstructure:"-" toml:"-"` - AllocateIP bool `mapstructure:"-" toml:"-"` -} - -// PtrResolver returns a Resolver used for PTR lookup, based on ServiceConfig.DiscoverPtrEndpoints value. -func (s ServiceConfig) PtrResolver() Resolver { - if len(s.DiscoverPtrEndpoints) > 0 { - nss := make([]string, 0, len(s.DiscoverPtrEndpoints)) - for _, ns := range s.DiscoverPtrEndpoints { - host, port := ns, "53" - if h, p, err := net.SplitHostPort(ns); err == nil { - host, port = h, p - } - // Only use valid ip:port pair. 
- if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { - nss = append(nss, net.JoinHostPort(host, port)) - } else { - ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for PTR resolver: %q", ns) - } - } - if len(nss) > 0 { - return NewResolverWithNameserver(nss) - } - } - return nil + LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` + LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` + MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` + DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"` + DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"` + DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"` + DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"` + DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` + DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` + DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` + Daemon bool `mapstructure:"-" toml:"-"` + AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. @@ -238,6 +216,9 @@ type UpstreamConfig struct { // The caller should not access this field directly. // Use UpstreamSendClientInfo instead. 
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"` + // The caller should not access this field directly. + // Use IsDiscoverable instead. + Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"` g singleflight.Group rebootstrap atomic.Bool @@ -364,6 +345,21 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool { return false } +// IsDiscoverable reports whether the upstream can be used for PTR discovery. +// The caller must ensure uc.Init() was called before calling this. +func (uc *UpstreamConfig) IsDiscoverable() bool { + if uc.Discoverable != nil { + return *uc.Discoverable + } + switch uc.Type { + case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate: + if ip, err := netip.ParseAddr(uc.Domain); err == nil { + return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip) + } + } + return false +} + // BootstrapIPs returns the bootstrap IPs list of upstreams. func (uc *UpstreamConfig) BootstrapIPs() []string { return uc.bootstrapIPs diff --git a/config_internal_test.go b/config_internal_test.go index 89cec19..96beddc 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -279,6 +279,61 @@ func TestUpstreamConfig_UpstreamSendClientInfo(t *testing.T) { } } +func TestUpstreamConfig_IsDiscoverable(t *testing.T) { + tests := []struct { + name string + uc *UpstreamConfig + discoverable bool + }{ + { + "loopback", + &UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy}, + true, + }, + { + "rfc1918", + &UpstreamConfig{Endpoint: "192.168.1.1", Type: ResolverTypeLegacy}, + true, + }, + { + "CGNAT", + &UpstreamConfig{Endpoint: "100.66.67.68", Type: ResolverTypeLegacy}, + true, + }, + { + "Public IP", + &UpstreamConfig{Endpoint: "8.8.8.8", Type: ResolverTypeLegacy}, + false, + }, + { + "override discoverable", + &UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(false)}, + false, + }, + { + "override 
non-public", + &UpstreamConfig{Endpoint: "1.1.1.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(true)}, + true, + }, + { + "non-legacy upstream", + &UpstreamConfig{Endpoint: "https://192.168.1.1/custom-doh", Type: ResolverTypeDOH}, + false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tc.uc.Init() + if got := tc.uc.IsDiscoverable(); got != tc.discoverable { + t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got) + } + }) + } +} + func ptrBool(b bool) *bool { return &b } diff --git a/docs/config.md b/docs/config.md index 57a794b..bc0ead2 100644 --- a/docs/config.md +++ b/docs/config.md @@ -193,22 +193,6 @@ Perform LAN client discovery using PTR queries. - Required: no - Default: true -### discover_ptr_endpoints -List of DNS nameservers used for PTR discovery. - -Each entry can be either "ip" (default port 53) or "ip:port" pair. Invalid entry will be ignored. - -- Type: array of string -- Required: no -- Default: [] - -Example: - -```toml -[service] -discover_ptr_endpoints = ["192.168.1.1", "192.168.2.1:5354"] -``` - ### discover_hosts Perform LAN client discovery using hosts file. @@ -335,6 +319,24 @@ If `ip_stack` is empty, or undefined: - Default value is `both` for non-Control D resolvers. - Default value is `split` for Control D resolvers. +### send_client_info +Specifying whether to include client info when sending query to upstream. + +- Type: boolean +- Required: no +- Default: + - `true` for ControlD upstreams. + - `false` for other upstreams. + +### discoverable +Specifying whether the upstream can be used for PTR discovery. + +- Type: boolean +- Required: no +- Default: + - `true` for loopback/RFC1918/CGNAT IP address. + - `false` for public IP address. + ## Network The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs. 
diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index 0e60643..07e4cf0 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -3,7 +3,9 @@ package clientinfo import ( "context" "fmt" + "net" "net/netip" + "strconv" "strings" "sync" "time" @@ -69,25 +71,27 @@ type Table struct { refreshers []refresher initOnce sync.Once - dhcp *dhcp - merlin *merlinDiscover - arp *arpDiscover - ptr *ptrDiscover - mdns *mdns - hf *hostsFile - vni *virtualNetworkIface - svcCfg ctrld.ServiceConfig - quitCh chan struct{} - selfIP string - cdUID string + dhcp *dhcp + merlin *merlinDiscover + arp *arpDiscover + ptr *ptrDiscover + mdns *mdns + hf *hostsFile + vni *virtualNetworkIface + svcCfg ctrld.ServiceConfig + quitCh chan struct{} + selfIP string + cdUID string + ptrNameservers []string } -func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table { +func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table { return &Table{ - svcCfg: cfg.Service, - quitCh: make(chan struct{}), - selfIP: selfIP, - cdUID: cdUID, + svcCfg: cfg.Service, + quitCh: make(chan struct{}), + selfIP: selfIP, + cdUID: cdUID, + ptrNameservers: ns, } } @@ -183,9 +187,25 @@ func (t *Table) init() { // PTR lookup. if t.discoverPTR() { t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()} - if r := t.svcCfg.PtrResolver(); r != nil { - ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for PTR discover", t.svcCfg.DiscoverPtrEndpoints) - t.ptr.resolver = r + if len(t.ptrNameservers) > 0 { + nss := make([]string, 0, len(t.ptrNameservers)) + for _, ns := range t.ptrNameservers { + host, port := ns, "53" + if h, p, err := net.SplitHostPort(ns); err == nil { + host, port = h, p + } + // Only use valid ip:port pair. 
+ if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil { + nss = append(nss, net.JoinHostPort(host, port)) + } else { + ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns) + } + } + if len(nss) > 0 { + t.ptr.resolver = ctrld.NewResolverWithNameserver(nss) + ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss) + } + } ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery") if err := t.ptr.refresh(); err != nil { @@ -358,6 +378,27 @@ func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) { t.vni.ip2name.Store(ci.IP, ci.Hostname) } +// ipFinder is the interface for retrieving IP address from hostname. +type ipFinder interface { + lookupIPByHostname(name string, v6 bool) string +} + +// LookupIPByHostname returns the ip address of given hostname. +// If v6 is true, return IPv6 instead of default IPv4. +func (t *Table) LookupIPByHostname(hostname string, v6 bool) *netip.Addr { + if t == nil { + return nil + } + for _, finder := range []ipFinder{t.hf, t.ptr, t.mdns, t.dhcp} { + if addr := finder.lookupIPByHostname(hostname, v6); addr != "" { + if ip, err := netip.ParseAddr(addr); err == nil { + return &ip + } + } + } + return nil +} + func (t *Table) discoverDHCP() bool { if t.svcCfg.DiscoverDHCP == nil { return true diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index 7c1b2cf..e036638 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -134,6 +134,23 @@ func (d *dhcp) List() []string { return ips } +func (d *dhcp) lookupIPByHostname(name string, v6 bool) string { + if d == nil { + return "" + } + var ip string + d.ip2name.Range(func(key, value any) bool { + if value == name { + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + ip = addr.String() + return false + } + } + return true + }) + return ip +} + // AddLeaseFile adds given lease file for reading/watching clients info. 
func (d *dhcp) addLeaseFile(name string, format ctrld.LeaseFileFormat) error { if d.watcher == nil { diff --git a/internal/clientinfo/hostsfile.go b/internal/clientinfo/hostsfile.go index baf05fb..8c86987 100644 --- a/internal/clientinfo/hostsfile.go +++ b/internal/clientinfo/hostsfile.go @@ -1,6 +1,7 @@ package clientinfo import ( + "net/netip" "os" "sync" @@ -109,6 +110,24 @@ func (hf *hostsFile) String() string { return "hosts" } +func (hf *hostsFile) lookupIPByHostname(name string, v6 bool) string { + if hf == nil { + return "" + } + hf.mu.Lock() + defer hf.mu.Unlock() + for addr, names := range hf.m { + if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLoopback() { + for _, n := range names { + if n == name && ip.Is6() == v6 { + return ip.String() + } + } + } + } + return "" +} + // isLocalhostName reports whether the given hostname represents localhost. func isLocalhostName(hostname string) bool { switch hostname { diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index 9a5fa85..59e6e9c 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "os" "sync" "syscall" @@ -59,6 +60,23 @@ func (m *mdns) List() []string { return ips } +func (m *mdns) lookupIPByHostname(name string, v6 bool) string { + if m == nil { + return "" + } + var ip string + m.name.Range(func(key, value any) bool { + if value == name { + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + ip = addr.String() + return false + } + } + return true + }) + return ip +} + func (m *mdns) init(quitCh chan struct{}) error { ifaces, err := multicastInterfaces() if err != nil { diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index fea79fb..b6204d5 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -2,6 +2,7 @@ package clientinfo import ( "context" + "net/netip" "sync" "sync/atomic" "time" @@ 
-94,6 +95,23 @@ func (p *ptrDiscover) lookupHostname(ip string) string { return "" } +func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { + if p == nil { + return "" + } + var ip string + p.hostname.Range(func(key, value any) bool { + if value == name { + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + ip = addr.String() + return false + } + } + return true + }) + return ip +} + // checkServer monitors if the resolver can reach its nameserver. When the nameserver // is reachable, set p.serverDown to false, so p.lookupHostname can continue working. func (p *ptrDiscover) checkServer() { From 5897c174d31ad2ccd5048835c64d9428fc6969e0 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 4 Dec 2023 18:39:02 +0700 Subject: [PATCH 38/52] all: fix LAN hostname checking condition The LAN hostname in question is FQDN, "." suffix must be trimmed before checking. While at it, also add tests for LAN/PTR query checking functions. --- cmd/cli/dns_proxy.go | 8 ++-- cmd/cli/dns_proxy_test.go | 67 +++++++++++++++++++++++++++++++ internal/clientinfo/dhcp.go | 3 ++ internal/clientinfo/mdns.go | 3 ++ internal/clientinfo/ptr_lookup.go | 3 ++ 5 files changed, 81 insertions(+), 3 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 1be818f..3c63782 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -806,6 +806,7 @@ func isPrivatePtrLookup(m *dns.Msg) bool { return false } +// isLanHostnameQuery reports whether DNS message is an A/AAAA query with LAN hostname. 
func isLanHostnameQuery(m *dns.Msg) bool { if m == nil || len(m.Question) == 0 { return false @@ -816,7 +817,8 @@ func isLanHostnameQuery(m *dns.Msg) bool { default: return false } - return !strings.Contains(q.Name, ".") || - strings.HasSuffix(q.Name, ".domain") || - strings.HasSuffix(q.Name, ".lan") + name := strings.TrimSuffix(q.Name, ".") + return !strings.Contains(name, ".") || + strings.HasSuffix(name, ".domain") || + strings.HasSuffix(name, ".lan") } diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 70197ad..118914a 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -274,6 +274,7 @@ func newDnsMsgWithClientIP(ip string) *dns.Msg { m.Extra = append(m.Extra, o) return m } + func Test_stripClientSubnet(t *testing.T) { tests := []struct { name string @@ -307,3 +308,69 @@ func Test_stripClientSubnet(t *testing.T) { }) } } + +func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(hostname, typ) + return m +} + +func Test_isLanHostnameQuery(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + isLanHostnameQuery bool + }{ + {"A", newDnsMsgWithHostname("foo", dns.TypeA), true}, + {"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true}, + {"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false}, + {"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false}, + {"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got) + } + }) + } +} + +func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg { + t.Helper() + m := new(dns.Msg) + ptr, err := dns.ReverseAddr(ip) + if err != nil { + t.Fatal(err) + } + m.SetQuestion(ptr, dns.TypePTR) + return m +} + +func Test_isPrivatePtrLookup(t 
*testing.T) { + tests := []struct { + name string + msg *dns.Msg + isPrivatePtrLookup bool + }{ + // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as + {"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true}, + {"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true}, + {"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true}, + {"CGNAT", newDnsMsgPtr("100.66.27.28", t), true}, + {"Loopback", newDnsMsgPtr("127.0.0.1", t), true}, + {"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true}, + {"Public IP", newDnsMsgPtr("8.8.8.8", t), false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got) + } + }) + } +} diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index e036638..a103263 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -143,6 +143,9 @@ func (d *dhcp) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() + if addr.IsLoopback() { // Continue searching if this is loopback address. + return true + } return false } } diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index 59e6e9c..f89e13f 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -69,6 +69,9 @@ func (m *mdns) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() + if addr.IsLoopback() { // Continue searching if this is loopback address. 
+ return true + } return false } } diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index b6204d5..1439752 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -104,6 +104,9 @@ func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() + if addr.IsLoopback() { // Continue searching if this is loopback address. + return true + } return false } } From c3ff8182affd6598b7cab15fed4ee8feeda4229a Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 5 Dec 2023 01:29:31 +0700 Subject: [PATCH 39/52] all: ignoring local interfaces RFC1918 IP for private resolver Otherwise, the discovery may loop with the new PTR query flow. --- cmd/cli/dns_proxy.go | 17 +---------------- cmd/cli/prog.go | 2 +- resolver.go | 28 +++++++++++++++++++++++++++- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 3c63782..e2477f2 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -114,7 +114,7 @@ func (p *prog) serveDNS(listenerNum string) error { // addresses of the machine. So ctrld could receive queries from LAN clients. 
if needRFC1918Listeners(listenerConfig) { g.Go(func() error { - for _, addr := range rfc1918Addresses() { + for _, addr := range ctrld.Rfc1918Addresses() { func() { listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port)) s, errCh := runDNSServer(listenAddr, proto, handler) @@ -737,21 +737,6 @@ func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool { return lc.IP == "127.0.0.1" && lc.Port == 53 } -func rfc1918Addresses() []string { - var res []string - interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { - addrs, _ := i.Addrs() - for _, addr := range addrs { - ipNet, ok := addr.(*net.IPNet) - if !ok || !ipNet.IP.IsPrivate() { - continue - } - res = append(res, ipNet.IP.String()) - } - }) - return res -} - // ipFromARPA parses a FQDN arpa domain and return the IP address if valid. func ipFromARPA(arpa string) net.IP { if arpa, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok { diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index f828426..9fcb42f 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -438,7 +438,7 @@ func (p *prog) setDNS() { nameservers := []string{ns} if needRFC1918Listeners(lc) { - nameservers = append(nameservers, rfc1918Addresses()...) + nameservers = append(nameservers, ctrld.Rfc1918Addresses()...) } if err := setDNS(netIface, nameservers); err != nil { logger.Error().Err(err).Msgf("could not set DNS for interface") diff --git a/resolver.go b/resolver.go index 8fd26cf..750679c 100644 --- a/resolver.go +++ b/resolver.go @@ -5,10 +5,12 @@ import ( "errors" "fmt" "net" + "net/netip" "sync" "time" "github.com/miekg/dns" + "tailscale.com/net/interfaces" ) const ( @@ -245,12 +247,16 @@ func NewBootstrapResolver(servers ...string) Resolver { } // NewPrivateResolver returns an OS resolver, which includes only private DNS servers, -// excluding nameservers from /etc/resolv.conf file. +// excluding: +// +// - Nameservers from /etc/resolv.conf file. +// - Nameservers which is local RFC1918 addresses. 
// // This is useful for doing PTR lookup in LAN network. func NewPrivateResolver() Resolver { nss := nameservers() resolveConfNss := nameserversFromResolvconf() + localRfc1918Addrs := Rfc1918Addresses() n := 0 for _, ns := range nss { host, _, _ := net.SplitHostPort(ns) @@ -263,6 +269,10 @@ func NewPrivateResolver() Resolver { if sliceContains(resolveConfNss, host) { continue } + // Ignoring local RFC 1918 addresses. + if sliceContains(localRfc1918Addrs, host) { + continue + } ip := net.ParseIP(host) if ip != nil && ip.IsPrivate() && !ip.IsLoopback() { nss[n] = ns @@ -285,6 +295,22 @@ func NewResolverWithNameserver(nameservers []string) Resolver { return &osResolver{nameservers: nameservers} } +// Rfc1918Addresses returns the list of local interfaces private IP addresses +func Rfc1918Addresses() []string { + var res []string + interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { + addrs, _ := i.Addrs() + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok || !ipNet.IP.IsPrivate() { + continue + } + res = append(res, ipNet.IP.String()) + } + }) + return res +} + func newDialer(dnsAddress string) *net.Dialer { return &net.Dialer{ Resolver: &net.Resolver{ From 7591a0ccc672d81c0fc6f9372d5480619179c224 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 5 Dec 2023 21:58:34 +0700 Subject: [PATCH 40/52] all: add client id preference config param So client can choose how client id is generated. --- client_info.go | 9 +++++---- cmd/cli/dns_proxy.go | 1 + config.go | 1 + config_test.go | 7 +++++++ docs/config.md | 11 +++++++++++ doh.go | 17 ++++++++++++----- 6 files changed, 37 insertions(+), 9 deletions(-) diff --git a/client_info.go b/client_info.go index f32526a..05d2910 100644 --- a/client_info.go +++ b/client_info.go @@ -5,10 +5,11 @@ type ClientInfoCtxKey struct{} // ClientInfo represents ctrld's clients information. 
type ClientInfo struct { - Mac string - IP string - Hostname string - Self bool + Mac string + IP string + Hostname string + Self bool + ClientIDPref string } // LeaseFileFormat specifies the format of DHCP lease file. diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index e2477f2..26f3931 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -67,6 +67,7 @@ func (p *prog) serveDNS(listenerNum string) error { reqId := requestID() remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) + ci.ClientIDPref = p.cfg.Service.ClientIDPref stripClientSubnet(m) remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String()) diff --git a/config.go b/config.go index 1bd9043..d3509be 100644 --- a/config.go +++ b/config.go @@ -193,6 +193,7 @@ type ServiceConfig struct { DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"` DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"` DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"` + ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"` Daemon bool `mapstructure:"-" toml:"-"` AllocateIP bool `mapstructure:"-" toml:"-"` } diff --git a/config_test.go b/config_test.go index 4123f00..ff20bc2 100644 --- a/config_test.go +++ b/config_test.go @@ -101,6 +101,7 @@ func TestConfigValidation(t *testing.T) { {"lease file format required if lease file exist", configWithExistedLeaseFile(t), true}, {"invalid lease file format", configWithInvalidLeaseFileFormat(t), true}, {"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true}, + {"invalid client id pref", configWithInvalidClientIDPref(t), true}, } for _, tc := range tests { @@ -238,3 +239,9 @@ func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config { cfg.Upstream["0"].Type = 
ctrld.ResolverTypeDOH return cfg } + +func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Service.ClientIDPref = "foo" + return cfg +} diff --git a/docs/config.md b/docs/config.md index bc0ead2..266c17a 100644 --- a/docs/config.md +++ b/docs/config.md @@ -215,6 +215,17 @@ DHCP leases file format. - Valid values: `dnsmasq`, `isc-dhcp` - Default: "" +### client_id_preference +Decides how the client ID is generated. + +If `host` -> client id will be a `hash(hostname)`. +If `mac` -> client id will be `hash(mac)`. + +- Type: string +- Required: no +- Valid values: `mac`, `host` +- Default: "" + ## Upstream The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to. diff --git a/doh.go b/doh.go index 96f8051..25ed2cb 100644 --- a/doh.go +++ b/doh.go @@ -18,11 +18,12 @@ import ( ) const ( - dohMacHeader = "x-cd-mac" - dohIPHeader = "x-cd-ip" - dohHostHeader = "x-cd-host" - dohOsHeader = "x-cd-os" - headerApplicationDNS = "application/dns-message" + dohMacHeader = "x-cd-mac" + dohIPHeader = "x-cd-ip" + dohHostHeader = "x-cd-host" + dohOsHeader = "x-cd-os" + dohClientIDPrefHeader = "x-cd-cpref" + headerApplicationDNS = "application/dns-message" ) // EncodeOsNameMap provides mapping from OS name to a shorter string, used for encoding x-cd-os value. @@ -181,6 +182,12 @@ func addControlDHeaders(req *http.Request, ci *ClientInfo) { if ci.Self { req.Header.Set(dohOsHeader, dohOsHeaderValue()) } + switch ci.ClientIDPref { + case "mac": + req.Header.Set(dohClientIDPrefHeader, "1") + case "host": + req.Header.Set(dohClientIDPrefHeader, "2") + } } // addNextDNSHeaders set DoH/Doh3 HTTP request headers for nextdns upstream. 
From 8939debbc05a82f3db06d4ecf4957184417f6428 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 5 Dec 2023 23:51:00 +0700 Subject: [PATCH 41/52] cmd/cli: do not send test query to external upstreams --- cmd/cli/loop.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index a9d3972..ec25840 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -60,6 +60,11 @@ func (p *prog) checkDnsLoop() { if p.um.isDown("upstream." + n) { continue } + // Do not send test query to external upstream. + if !canBeLocalUpstream(uc.Domain) { + mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n) + continue + } uid := uc.UID() p.loop[uid] = false upstream[uid] = uc From af2c1c87e080781512d5e9ef1755bd7a740827ac Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 5 Dec 2023 19:52:11 +0700 Subject: [PATCH 42/52] cmd/cli: improve logging for new LAN/PTR flow --- cmd/cli/dns_proxy.go | 206 +++++++++++++++++++++++--------------- cmd/cli/dns_proxy_test.go | 59 ++++++++--- 2 files changed, 167 insertions(+), 98 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 26f3931..9513e45 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -45,6 +45,23 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{ Timeout: 2000, } +// proxyRequest contains data for proxying a DNS query to upstream. +type proxyRequest struct { + msg *dns.Msg + ci *ctrld.ClientInfo + failoverRcodes []int + ufr *upstreamForResult +} + +// upstreamForResult represents the result of processing rules for a request. 
+type upstreamForResult struct { + upstreams []string + matchedPolicy string + matchedNetwork string + matchedRule string + matched bool +} + func (p *prog) serveDNS(listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated @@ -74,9 +91,10 @@ func (p *prog) serveDNS(listenerNum string) error { t := time.Now() ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) ctrld.Log(ctx, mainLog.Load().Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain) - upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) + res := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) var answer *dns.Msg - if !matched && listenerConfig.Restricted { + if !res.matched && listenerConfig.Restricted { + ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String()) answer = new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) } else { @@ -84,7 +102,12 @@ func (p *prog) serveDNS(listenerNum string) error { if listenerConfig.Policy != nil { failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers } - answer = p.proxy(ctx, upstreams, failoverRcode, m, ci, matched) + answer = p.proxy(ctx, &proxyRequest{ + msg: m, + ci: ci, + failoverRcodes: failoverRcode, + ufr: res, + }) rtt := time.Since(t) ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) } @@ -162,27 +185,24 @@ func (p *prog) serveDNS(listenerNum string) error { // Though domain policy has higher priority than network policy, it is still // processed later, because policy logging want to know whether a network rule // is disregarded in favor of the domain level rule. 
-func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) ([]string, bool) { +func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) { upstreams := []string{upstreamPrefix + defaultUpstreamNum} matchedPolicy := "no policy" matchedNetwork := "no network" matchedRule := "no rule" matched := false + res = &upstreamForResult{} defer func() { - if !matched && lc.Restricted { - ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", addr.String()) - return - } - if matched { - ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams) - } else { - ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams) - } + res.upstreams = upstreams + res.matched = matched + res.matchedPolicy = matchedPolicy + res.matchedNetwork = matchedNetwork + res.matchedRule = matchedRule }() if lc.Policy == nil { - return upstreams, false + return } do := func(policyUpstreams []string) { @@ -242,7 +262,7 @@ macRules: matchedRule = source do(targets) matched = true - return upstreams, matched + return } } } @@ -251,11 +271,77 @@ macRules: do(networkTargets) } - return upstreams, matched + return } -func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo, matched bool) *dns.Msg { +func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg { + ip := ipFromARPA(msg.Question[0].Name) + if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" { + answer := new(dns.Msg) + answer.SetReply(msg) + answer.Compress = true + answer.Answer = []dns.RR{&dns.PTR{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + }, + Ptr: dns.Fqdn(name), + }} + 
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table") + ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + Mac: p.ciTable.LookupMac(ip.String()), + IP: ip.String(), + Hostname: name, + }) + return answer + } + return nil +} + +func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg { + q := msg.Question[0] + hostname := strings.TrimSuffix(q.Name, ".") + if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil { + answer := new(dns.Msg) + answer.SetReply(msg) + answer.Compress = true + switch { + case ip.Is4(): + answer.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: uint32(localTTL.Seconds()), + }, + A: ip.AsSlice(), + }} + case ip.Is6(): + answer.Answer = []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: uint32(localTTL.Seconds()), + }, + AAAA: ip.AsSlice(), + }} + } + ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table") + ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ + Mac: p.ciTable.LookupMac(ip.String()), + IP: ip.String(), + Hostname: hostname, + }) + return answer + } + return nil +} + +func (p *prog) proxy(ctx context.Context, req *proxyRequest) *dns.Msg { var staleAnswer *dns.Msg + upstreams := req.ufr.upstreams serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) if len(upstreamConfigs) == 0 { @@ -270,85 +356,38 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i // 3. Try private resolver. // 4. Try remote upstream. 
isLanOrPtrQuery := false - if !matched { + if req.ufr.matched { + ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams) + } else { switch { - case isPrivatePtrLookup(msg): + case isPrivatePtrLookup(req.msg): isLanOrPtrQuery = true - ip := ipFromARPA(msg.Question[0].Name) - if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" { - answer := new(dns.Msg) - answer.SetReply(msg) - answer.Compress = true - answer.Answer = []dns.RR{&dns.PTR{ - Hdr: dns.RR_Header{ - Name: msg.Question[0].Name, - Rrtype: dns.TypePTR, - Class: dns.ClassINET, - }, - Ptr: dns.Fqdn(name), - }} - ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table") - ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ - Mac: p.ciTable.LookupMac(ip.String()), - IP: ip.String(), - Hostname: name, - }) + if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil { return answer } upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams) - case isLanHostnameQuery(msg): + case isLanHostnameQuery(req.msg): isLanOrPtrQuery = true - q := msg.Question[0] - hostname := strings.TrimSuffix(q.Name, ".") - if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil { - answer := new(dns.Msg) - answer.SetReply(msg) - answer.Compress = true - switch { - case ip.Is4(): - answer.Answer = []dns.RR{&dns.A{ - Hdr: dns.RR_Header{ - Name: msg.Question[0].Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: uint32(localTTL.Seconds()), - }, - A: ip.AsSlice(), - }} - case ip.Is6(): - answer.Answer = []dns.RR{&dns.AAAA{ - Hdr: dns.RR_Header{ - Name: msg.Question[0].Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: uint32(localTTL.Seconds()), - }, - AAAA: ip.AsSlice(), - }} - } - ctrld.Log(ctx, 
mainLog.Load().Info(), "lan hostname lookup, using client info table") - ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{ - Mac: p.ciTable.LookupMac(ip.String()), - IP: ip.String(), - Hostname: hostname, - }) + if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil { return answer } upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs) ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams) + default: + ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams) } } // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 - if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { + if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { for _, upstream := range upstreams { - cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream)) + cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream)) if cachedValue == nil { continue } answer := cachedValue.Msg.Copy() - answer.SetRcode(msg, answer.Rcode) + answer.SetRcode(req.msg, answer.Rcode) now := time.Now() if cachedValue.Expire.After(now) { ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response") @@ -375,9 +414,9 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i return dnsResolver.Resolve(resolveCtx, msg) } resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { - if upstreamConfig.UpstreamSendClientInfo() && ci != nil { + if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request") - ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci) + ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci) } answer, err := resolve1(n, upstreamConfig, msg) if err != nil { @@ -404,7 +443,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, 
failoverRcodes []i ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n]) continue } - answer := resolve(n, upstreamConfig, msg) + answer := resolve(n, upstreamConfig, req.msg) if answer == nil { if serveStaleCache && staleAnswer != nil { ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response") @@ -417,9 +456,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i // We are doing LAN/PTR lookup using private resolver, so always process next one. // Except for the last, we want to send response instead of saying all upstream failed. if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 { + ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n]) continue } - if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) { + if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) { ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream") continue } @@ -427,7 +467,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i // set compression, as it is not set by default when unpacking answer.Compress = true - if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { + if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR { ttl := ttlFromMsg(answer) now := time.Now() expired := now.Add(time.Duration(ttl) * time.Second) @@ -435,14 +475,14 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i expired = now.Add(time.Duration(cachedTTL) * time.Second) } setCachedAnswerTTL(answer, now, expired) - p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired)) + p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired)) ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response") } return 
answer } ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) answer := new(dns.Msg) - answer.SetRcode(msg, dns.RcodeServerFailure) + answer.SetRcode(req.msg, dns.RcodeServerFailure) return answer } diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 118914a..281d59c 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -67,8 +67,9 @@ func Test_canonicalName(t *testing.T) { func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) - prog := &prog{cfg: cfg} - for _, nc := range prog.cfg.Network { + p := &prog{cfg: cfg} + p.um = newUpstreamMonitor(p.cfg) + for _, nc := range p.cfg.Network { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) if err != nil { @@ -89,14 +90,14 @@ func Test_prog_upstreamFor(t *testing.T) { matched bool testLogMsg string }{ - {"Policy map matches", "192.168.0.1:0", "", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""}, - {"Policy split matches", "192.168.0.1:0", "", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""}, - {"Policy map for other network matches", "192.168.1.2:0", "", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""}, - {"No policy map for listener", "192.168.1.2:0", "", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""}, - {"unenforced loging", "192.168.1.2:0", "", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"}, - {"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"}, - {"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"}, - {"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", prog.cfg.Listener["0"], 
"abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"}, + {"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""}, + {"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""}, + {"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""}, + {"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""}, + {"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"}, + {"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"}, + {"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"}, + {"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"}, } for _, tc := range tests { @@ -115,9 +116,13 @@ func Test_prog_upstreamFor(t *testing.T) { require.NoError(t, err) require.NotNil(t, addr) ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID()) - upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain) - assert.Equal(t, tc.matched, matched) - assert.Equal(t, tc.upstreams, upstreams) + ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain) + p.proxy(ctx, &proxyRequest{ + msg: newDnsMsgWithHostname("foo", dns.TypeA), + ufr: ufr, + }) + assert.Equal(t, tc.matched, ufr.matched) + assert.Equal(t, tc.upstreams, ufr.upstreams) if tc.testLogMsg != "" { assert.Contains(t, logOutput.String(), tc.testLogMsg) } @@ -153,8 
+158,32 @@ func TestCache(t *testing.T) { answer2.SetRcode(msg, dns.RcodeRefused) prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute))) - got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil, false) - got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil, false) + req1 := &proxyRequest{ + msg: msg, + ci: nil, + failoverRcodes: nil, + ufr: &upstreamForResult{ + upstreams: []string{"upstream.1"}, + matchedPolicy: "", + matchedNetwork: "", + matchedRule: "", + matched: false, + }, + } + req2 := &proxyRequest{ + msg: msg, + ci: nil, + failoverRcodes: nil, + ufr: &upstreamForResult{ + upstreams: []string{"upstream.0"}, + matchedPolicy: "", + matchedNetwork: "", + matchedRule: "", + matched: false, + }, + } + got1 := prog.proxy(context.Background(), req1) + got2 := prog.proxy(context.Background(), req2) assert.NotSame(t, got1, got2) assert.Equal(t, answer1.Rcode, got1.Rcode) assert.Equal(t, answer2.Rcode, got2.Rcode) From 0bb51aa71d7845a279191d0c1a08f0c7d87de5af Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 5 Dec 2023 20:24:12 +0700 Subject: [PATCH 43/52] cmd/cli: add loop guard for LAN/PTR queries --- cmd/cli/dns_proxy.go | 13 ++++++++++++- cmd/cli/dns_proxy_test.go | 2 ++ cmd/cli/loop.go | 31 +++++++++++++++++++++++++++++++ cmd/cli/loop_test.go | 38 ++++++++++++++++++++++++++++++++++++++ cmd/cli/prog.go | 4 ++++ 5 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 cmd/cli/loop_test.go diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 9513e45..080bebc 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -275,7 +275,13 @@ macRules: } func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg { - ip := ipFromARPA(msg.Question[0].Name) + cDomainName := msg.Question[0].Name + locked := p.ptrLoopGuard.TryLock(cDomainName) + defer p.ptrLoopGuard.Unlock(cDomainName) + if !locked { + return nil + } + ip 
:= ipFromARPA(cDomainName) if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" { answer := new(dns.Msg) answer.SetReply(msg) @@ -302,6 +308,11 @@ func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg { q := msg.Question[0] hostname := strings.TrimSuffix(q.Name, ".") + locked := p.lanLoopGuard.TryLock(hostname) + defer p.lanLoopGuard.Unlock(hostname) + if !locked { + return nil + } if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil { answer := new(dns.Msg) answer.SetReply(msg) diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 281d59c..82c4f63 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -69,6 +69,8 @@ func Test_prog_upstreamFor(t *testing.T) { cfg := testhelper.SampleConfig(t) p := &prog{cfg: cfg} p.um = newUpstreamMonitor(p.cfg) + p.lanLoopGuard = newLoopGuard() + p.ptrLoopGuard = newLoopGuard() for _, nc := range p.cfg.Network { for _, cidr := range nc.Cidrs { _, ipNet, err := net.ParseCIDR(cidr) diff --git a/cmd/cli/loop.go b/cmd/cli/loop.go index ec25840..06a7e03 100644 --- a/cmd/cli/loop.go +++ b/cmd/cli/loop.go @@ -3,6 +3,7 @@ package cli import ( "context" "strings" + "sync" "time" "github.com/miekg/dns" @@ -15,6 +16,36 @@ const ( loopTestQtype = dns.TypeTXT ) +// newLoopGuard returns new loopGuard. +func newLoopGuard() *loopGuard { + return &loopGuard{inflight: make(map[string]struct{})} +} + +// loopGuard guards against DNS loop, ensuring only one query +// for a given domain is processed at a time. +type loopGuard struct { + mu sync.Mutex + inflight map[string]struct{} +} + +// TryLock marks the domain as being processed. 
+func (lg *loopGuard) TryLock(domain string) bool { + lg.mu.Lock() + defer lg.mu.Unlock() + if _, inflight := lg.inflight[domain]; !inflight { + lg.inflight[domain] = struct{}{} + return true + } + return false +} + +// Unlock marks the domain as being done. +func (lg *loopGuard) Unlock(domain string) { + lg.mu.Lock() + defer lg.mu.Unlock() + delete(lg.inflight, domain) +} + // isLoop reports whether the given upstream config is detected as having DNS loop. func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool { p.loopMu.Lock() diff --git a/cmd/cli/loop_test.go b/cmd/cli/loop_test.go new file mode 100644 index 0000000..e8cfb2a --- /dev/null +++ b/cmd/cli/loop_test.go @@ -0,0 +1,38 @@ +package cli + +import ( + "sync" + "testing" +) + +func Test_loopGuard(t *testing.T) { + lg := newLoopGuard() + key := "foo" + + var mu sync.Mutex + i := 0 + n := 1000 + do := func() { + locked := lg.TryLock(key) + defer lg.Unlock(key) + if locked { + mu.Lock() + i++ + mu.Unlock() + } + } + + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + do() + }() + } + wg.Wait() + + if i == n { + t.Fatalf("i must not be increased %d times", n) + } +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 9fcb42f..55dfafc 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -66,6 +66,8 @@ type prog struct { ciTable *clientinfo.Table um *upstreamMonitor router router.Router + ptrLoopGuard *loopGuard + lanLoopGuard *loopGuard loopMu sync.Mutex loop map[string]bool @@ -236,6 +238,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { } p.onStartedDone = make(chan struct{}) p.loop = make(map[string]bool) + p.lanLoopGuard = newLoopGuard() + p.ptrLoopGuard = newLoopGuard() if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { From 0bb8703f7824beb9ed9c0585fc28799cb6d22675 Mon Sep 17 00:00:00 2001 From: Alex Paguis Date: Tue, 5 Dec 2023 20:05:56 +0000 Subject: [PATCH 44/52] Update document for new 
client_id_preference param --- docs/config.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/config.md b/docs/config.md index 266c17a..81e67ee 100644 --- a/docs/config.md +++ b/docs/config.md @@ -216,11 +216,11 @@ DHCP leases file format. - Default: "" ### client_id_preference -Decide how client ID has is generated. - -If `host` -> client id will be a `hash(hostname)`. -If `mac` -> client id will be `hash(mac)`. +Decide how the client ID is generated. +If `host` -> client id will only use the hostname i.e. `hash(hostname)`. +If `mac` -> client id will only use the MAC address i.e. `hash(mac)`. +Else -> client ID will use both Mac and Hostname i.e. `hash(mac + host)`. - Type: string - Required: no - Valid values: `mac`, `host` From 874ff01ab8c05b3074cf32b69009c4188d53bdb9 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 6 Dec 2023 14:52:20 +0700 Subject: [PATCH 45/52] cmd/cli: ensure log time field is formated with ms --- cmd/cli/main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index b1287ae..aa166ca 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -98,6 +98,7 @@ func initConsoleLogging() { // initLogging initializes global logging setup. func initLogging() { + zerolog.TimeFieldFormat = time.RFC3339 + ".000" initLoggingWithBackup(true) } @@ -136,7 +137,7 @@ func initLoggingWithBackup(doBackup bool) { } writers = append(writers, consoleWriter) multi := zerolog.MultiLevelWriter(writers...) - l := mainLog.Load().Output(multi).With().Timestamp().Logger() + l := mainLog.Load().Output(multi).With().Logger() mainLog.Store(&l) // TODO: find a better way. 
ctrld.ProxyLogger.Store(&l) From cebfd12d5c5f84ec8295bad6d82a9346448df89f Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 6 Dec 2023 15:09:13 +0700 Subject: [PATCH 46/52] internal/clientinfo: ensure RFC1918 address is chosen over others --- internal/clientinfo/dhcp.go | 32 +++++++++++++++++++++++--------- internal/clientinfo/dhcp_test.go | 12 ++++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/internal/clientinfo/dhcp.go b/internal/clientinfo/dhcp.go index a103263..ebbeb77 100644 --- a/internal/clientinfo/dhcp.go +++ b/internal/clientinfo/dhcp.go @@ -8,6 +8,7 @@ import ( "net" "net/netip" "os" + "sort" "strings" "sync" @@ -138,20 +139,33 @@ func (d *dhcp) lookupIPByHostname(name string, v6 bool) string { if d == nil { return "" } - var ip string + var ( + rfc1918Addrs []netip.Addr + others []netip.Addr + ) d.ip2name.Range(func(key, value any) bool { - if value == name { - if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { - ip = addr.String() - if addr.IsLoopback() { // Continue searching if this is loopback address. - return true - } - return false + if value != name { + return true + } + if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { + if addr.IsPrivate() { + rfc1918Addrs = append(rfc1918Addrs, addr) + } else { + others = append(others, addr) } } return true }) - return ip + result := [][]netip.Addr{rfc1918Addrs, others} + for _, addrs := range result { + if len(addrs) > 0 { + sort.Slice(addrs, func(i, j int) bool { + return addrs[i].Less(addrs[j]) + }) + return addrs[0].String() + } + } + return "" } // AddLeaseFile adds given lease file for reading/watching clients info. 
diff --git a/internal/clientinfo/dhcp_test.go b/internal/clientinfo/dhcp_test.go index af3a168..359f441 100644 --- a/internal/clientinfo/dhcp_test.go +++ b/internal/clientinfo/dhcp_test.go @@ -86,3 +86,15 @@ lease 192.168.1.2 { }) } } + +func Test_dhcp_lookupIPByHostname(t *testing.T) { + d := &dhcp{} + want := "192.168.1.123" + d.ip2name.Store(want, "foo") + d.ip2name.Store("127.0.0.1", "foo") + d.ip2name.Store("169.254.123.123", "foo") + + if got := d.lookupIPByHostname("foo", false); got != want { + t.Fatalf("unexpected result, want: %s, got: %s", want, got) + } +} From e92619620d8fd80da684e24498d714efa3d0814f Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 8 Dec 2023 23:04:15 +0700 Subject: [PATCH 47/52] cmd/cli: doing router setup based on "--iface" flag Solving downgrading issue from newer version to v1.3.1, and also easier to explain the logic: either doing "magic stuff" or do nothing. --- cmd/cli/cli.go | 18 ++---------------- cmd/cli/main.go | 42 ++++++++++++++++++++---------------------- 2 files changed, 22 insertions(+), 38 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 8b2999f..3f76c80 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -144,8 +144,6 @@ func initCLI() { _ = runCmd.Flags().MarkHidden("homedir") runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) _ = runCmd.Flags().MarkHidden("iface") - runCmd.Flags().BoolVarP(&setupRouter, "router", "", false, "Do setup router") - _ = runCmd.Flags().MarkHidden("router") runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) rootCmd.AddCommand(runCmd) @@ -255,7 +253,7 @@ func initCLI() { return } - if router.Name() != "" && setupRouter { + if router.Name() != "" && iface != "" { mainLog.Load().Debug().Msg("cleaning up router before installing") _ = p.router.Cleanup() } @@ -312,8 +310,6 @@ func initCLI() { 
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) - startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, "Do router setup") - _ = startCmd.Flags().MarkHidden("router") routerCmd := &cobra.Command{ Use: "setup", @@ -598,16 +594,11 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if !cmd.Flags().Changed("iface") { os.Args = append(os.Args, "--iface="+ifaceStartStop) } - if !cmd.Flags().Changed("router") { - os.Args = append(os.Args, fmt.Sprintf("--router=%v", setupRouterStartStop)) - } iface = ifaceStartStop - setupRouter = setupRouterStartStop startCmd.Run(cmd, args) }, } startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`) - startCmdAlias.Flags().BoolVarP(&setupRouterStartStop, "router", "", true, "Do router setup") startCmdAlias.Flags().AddFlagSet(startCmd.Flags()) rootCmd.AddCommand(startCmdAlias) stopCmdAlias := &cobra.Command{ @@ -621,16 +612,11 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if !cmd.Flags().Changed("iface") { os.Args = append(os.Args, "--iface="+ifaceStartStop) } - if !cmd.Flags().Changed("router") { - os.Args = append(os.Args, fmt.Sprintf("--router=%v", setupRouterStartStop)) - } iface = ifaceStartStop - setupRouter = setupRouterStartStop stopCmd.Run(cmd, args) }, } stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`) - stopCmdAlias.Flags().BoolVarP(&setupRouterStartStop, "router", "", true, "Do router setup") stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags()) rootCmd.AddCommand(stopCmdAlias) @@ -991,7 +977,7 @@ func 
run(appCallback *AppCallback, stopCh chan struct{}) { if cp := router.CertPool(); cp != nil { rootCertPool = cp } - if setupRouter { + if iface != "" { p.onStarted = append(p.onStarted, func() { mainLog.Load().Debug().Msg("router setup on start") if err := p.router.Setup(); err != nil { diff --git a/cmd/cli/main.go b/cmd/cli/main.go index aa166ca..3f1ef8b 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -14,28 +14,26 @@ import ( ) var ( - configPath string - configBase64 string - daemon bool - listenAddress string - primaryUpstream string - secondaryUpstream string - domains []string - logPath string - homedir string - cacheSize int - cfg ctrld.Config - verbose int - silent bool - cdUID string - cdOrg string - cdDev bool - iface string - ifaceStartStop string - nextdns string - cdUpstreamProto string - setupRouter bool - setupRouterStartStop bool + configPath string + configBase64 string + daemon bool + listenAddress string + primaryUpstream string + secondaryUpstream string + domains []string + logPath string + homedir string + cacheSize int + cfg ctrld.Config + verbose int + silent bool + cdUID string + cdOrg string + cdDev bool + iface string + ifaceStartStop string + nextdns string + cdUpstreamProto string mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter From 684019c2e3be57128562199533aab33136eae8a6 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 8 Dec 2023 23:23:33 +0700 Subject: [PATCH 48/52] all: force re-bootstrapping with timeout error --- cmd/cli/dns_proxy.go | 6 ++++++ config.go | 5 +++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 080bebc..0a68071 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/hex" + "errors" "fmt" "net" "net/netip" @@ -438,6 +439,11 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *dns.Msg { go p.um.checkUpstream(upstreams[n], upstreamConfig) } 
} + // For timeout error (i.e: context deadline exceed), force re-bootstrapping. + var e net.Error + if errors.As(err, &e) && e.Timeout() { + upstreamConfig.ReBootstrap() + } return nil } return answer diff --git a/config.go b/config.go index d3509be..aeb6e27 100644 --- a/config.go +++ b/config.go @@ -426,8 +426,9 @@ func (uc *UpstreamConfig) ReBootstrap() { return } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { - ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip") - uc.rebootstrap.Store(true) + if uc.rebootstrap.CompareAndSwap(false, true) { + ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip") + } return true, nil }) } From dfbcb1489d78fb83a9c550f92e33bb6ed3c2d711 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 11 Dec 2023 23:17:38 +0700 Subject: [PATCH 49/52] cmd/cli: improving loop guard test We see number of failed test in Github Action, mostly on MacOS or Windows due to the fact that goroutines are scheduled to be run consequently. This commit improves the test, ensuring at least 2 goroutines were started before increasing the counting. --- cmd/cli/loop_test.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cmd/cli/loop_test.go b/cmd/cli/loop_test.go index e8cfb2a..b2c8404 100644 --- a/cmd/cli/loop_test.go +++ b/cmd/cli/loop_test.go @@ -2,6 +2,7 @@ package cli import ( "sync" + "sync/atomic" "testing" ) @@ -9,16 +10,19 @@ func Test_loopGuard(t *testing.T) { lg := newLoopGuard() key := "foo" - var mu sync.Mutex - i := 0 + var i atomic.Int64 + var started atomic.Int64 n := 1000 do := func() { locked := lg.TryLock(key) defer lg.Unlock(key) + started.Add(1) + for started.Load() < 2 { + // Wait until at least 2 goroutines started, otherwise, on system with heavy load, + // or having only 1 CPU, all goroutines can be scheduled to run consequently. 
+ } if locked { - mu.Lock() - i++ - mu.Unlock() + i.Add(1) } } @@ -32,7 +36,7 @@ func Test_loopGuard(t *testing.T) { } wg.Wait() - if i == n { + if i.Load() == int64(n) { t.Fatalf("i must not be increased %d times", n) } } From 41846b6d4c79ac6a7b594d2a2076026cf85906a8 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 12 Dec 2023 18:34:41 +0700 Subject: [PATCH 50/52] all: add config to enable/disable answering WAN clients --- cmd/cli/dns_proxy.go | 27 ++++++++++++++++++++++++--- cmd/cli/dns_proxy_test.go | 26 ++++++++++++++++++++++++++ config.go | 9 +++++---- docs/config.md | 9 ++++++++- 4 files changed, 63 insertions(+), 8 deletions(-) diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 0a68071..2b0f94d 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -77,12 +77,21 @@ func (p *prog) serveDNS(listenerNum string) error { if len(m.Question) == 0 { answer := new(dns.Msg) answer.SetRcode(m, dns.RcodeFormatError) + _ = w.WriteMsg(answer) + return + } + reqId := requestID() + ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) + if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) { + ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String()) + answer := new(dns.Msg) + answer.SetRcode(m, dns.RcodeRefused) + _ = w.WriteMsg(answer) return } go p.detectLoop(m) q := m.Question[0] domain := canonicalName(q.Name) - reqId := requestID() remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ci := p.getClientInfo(remoteIP, m) ci.ClientIDPref = p.cfg.Service.ClientIDPref @@ -90,7 +99,6 @@ func (p *prog) serveDNS(listenerNum string) error { remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci) fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String()) t := time.Now() - ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId) ctrld.Log(ctx, mainLog.Load().Debug(), "%s received query: %s %s", 
fmtSrcToDest, dns.TypeToString[q.Qtype], domain) res := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain) var answer *dns.Msg @@ -113,7 +121,7 @@ func (p *prog) serveDNS(listenerNum string) error { ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt) } if err := w.WriteMsg(answer); err != nil { - ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveUDP: failed to send DNS response to client") + ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client") } }) @@ -865,3 +873,16 @@ func isLanHostnameQuery(m *dns.Msg) bool { strings.HasSuffix(name, ".domain") || strings.HasSuffix(name, ".lan") } + +// isWanClient reports whether the input is a WAN address. +func isWanClient(na net.Addr) bool { + var ip netip.Addr + if ap, err := netip.ParseAddrPort(na.String()); err == nil { + ip = ap.Addr() + } + return !ip.IsLoopback() && + !ip.IsPrivate() && + !ip.IsLinkLocalUnicast() && + !ip.IsLinkLocalMulticast() && + !tsaddr.CGNATRange().Contains(ip) +} diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go index 82c4f63..bd73d17 100644 --- a/cmd/cli/dns_proxy_test.go +++ b/cmd/cli/dns_proxy_test.go @@ -405,3 +405,29 @@ func Test_isPrivatePtrLookup(t *testing.T) { }) } } + +func Test_isWanClient(t *testing.T) { + tests := []struct { + name string + addr net.Addr + isWanClient bool + }{ + // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as + {"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false}, + {"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false}, + {"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false}, + {"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false}, + {"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false}, + {"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false}, + {"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true}, 
+ } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isWanClient(tc.addr); tc.isWanClient != got { + t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got) + } + }) + } +} diff --git a/config.go b/config.go index aeb6e27..5baa10d 100644 --- a/config.go +++ b/config.go @@ -240,10 +240,11 @@ type UpstreamConfig struct { // ListenerConfig specifies the networks configuration that ctrld will run on. type ListenerConfig struct { - IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"` - Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"` - Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"` - Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"` + IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"` + Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"` + Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"` + AllowWanClients bool `mapstructure:"allow_wan_clients" toml:"allow_wan_clients,omitempty"` + Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"` } // IsDirectDnsListener reports whether ctrld can be a direct listener on port 53. diff --git a/docs/config.md b/docs/config.md index 81e67ee..e5b3945 100644 --- a/docs/config.md +++ b/docs/config.md @@ -405,7 +405,14 @@ Port number that the listener will listen on for incoming requests. If `port` is - Default: 0 or 53 or 5354 (depending on platform) ### restricted -If set to `true` makes the listener `REFUSE` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`. +If set to `true`, makes the listener answer `REFUSED` to DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
+ +- Type: bool +- Required: no +- Default: false + +### allow_wan_clients +By default, the listener answers `REFUSED` to DNS queries from WAN clients. If set to `true`, the listener replies to them. - Type: bool - Required: no From 122600bff2025a29beb41bd271501e25f2411116 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 16 Nov 2023 21:05:11 +0700 Subject: [PATCH 51/52] cmd/cli: remove redundant return statement --- cmd/cli/prog.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 55dfafc..878681e 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -325,7 +325,6 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) { case <-ctx.Done(): case <-reloadCh: } - return }() } From 0084e9ef26245e063b34a866797f37226fb02c9a Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 6 Dec 2023 15:39:00 +0700 Subject: [PATCH 52/52] internal/clientinfo: silent staticcheck S1008 The code is written for readability purpose. --- internal/clientinfo/mdns.go | 1 + internal/clientinfo/ptr_lookup.go | 1 + 2 files changed, 2 insertions(+) diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index f89e13f..3f0a311 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -69,6 +69,7 @@ func (m *mdns) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() + //lint:ignore S1008 This is used for readable. if addr.IsLoopback() { // Continue searching if this is loopback address.
return true } diff --git a/internal/clientinfo/ptr_lookup.go b/internal/clientinfo/ptr_lookup.go index 1439752..8e6b3f7 100644 --- a/internal/clientinfo/ptr_lookup.go +++ b/internal/clientinfo/ptr_lookup.go @@ -104,6 +104,7 @@ func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string { if value == name { if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 { ip = addr.String() + //lint:ignore S1008 This is used for readable. if addr.IsLoopback() { // Continue searching if this is loopback address. return true }