diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index bf7b9a6..8d733b3 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -158,6 +158,7 @@ func initCLI() { Run: func(cmd *cobra.Command, args []string) { checkStrFlagEmpty(cmd, cdUidFlagName) checkStrFlagEmpty(cmd, cdOrgFlagName) + validateCdAndNextDNSFlags() sc := &service.Config{} *sc = *svcConfig osArgs := os.Args[2:] @@ -231,6 +232,15 @@ func initCLI() { initLogging() + if nextdns != "" { + removeNextDNSFromArgs(sc) + generateNextDNSConfig() + updateListenerConfig() + if err := writeConfigFile(); err != nil { + mainLog.Load().Error().Err(err).Msg("failed to write config with NextDNS resolver") + } + } + // Explicitly passing config, so on system where home directory could not be obtained, // or sub-process env is different with the parent, we still behave correctly and use // the expected config file. @@ -281,7 +291,7 @@ func initCLI() { } }, } - // Keep these flags in sync with runCmd above, except for "-d". + // Keep these flags in sync with runCmd above, except for "-d"/"--nextdns". startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") @@ -295,6 +305,7 @@ func initCLI() { startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") routerCmd := &cobra.Command{ Use: "setup", @@ -1216,6 +1227,11 @@ func selfCheckStatus(s service.Service) service.Status { return service.StatusUnknown } + // Not a ctrld upstream, return status as-is. + if cfg.FirstUpstream().VerifyDomain() == "" { + return status + } + mainLog.Load().Debug().Msg("ctrld listener is ready") mainLog.Load().Debug().Msg("performing self-check") bo := backoff.NewBackoff("self-check", logf, 10*time.Second) @@ -1489,6 +1505,7 @@ func mobileListenerPort() int { func updateListenerConfig() (updated bool) { lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" + nextdnsMode := nextdns != "" for n, listener := range cfg.Listener { lcc[n] = &listenerConfigCheck{} if listener.IP == "" { @@ -1500,7 +1517,8 @@ func updateListenerConfig() (updated bool) { lcc[n].Port = true } // In cd mode, we always try to pick an ip:port pair to work. - if cdMode { + // Same if nextdns resolver is used. + if cdMode || nextdnsMode { lcc[n].IP = true lcc[n].Port = true } @@ -1801,3 +1819,32 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name) } } + +func validateCdAndNextDNSFlags() { + if (cdUID != "" || cdOrg != "") && nextdns != "" { + mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName) + } +} + +// removeNextDNSFromArgs removes the --nextdns from command line arguments. +func removeNextDNSFromArgs(sc *service.Config) { + a := sc.Arguments[:0] + skip := false + for _, x := range sc.Arguments { + if skip { + skip = false + continue + } + // For "--nextdns XXX", skip it and mark next arg skipped. + if x == "--"+nextdnsFlagName { + skip = true + continue + } + // For "--nextdns=XXX", just skip it. + if strings.HasPrefix(x, "--"+nextdnsFlagName+"=") { + continue + } + a = append(a, x) + } + sc.Arguments = a +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index f4439a5..f79b15f 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -32,6 +32,7 @@ var ( cdDev bool iface string ifaceStartStop string + nextdns string mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter @@ -39,8 +40,9 @@ var ( ) const ( - cdUidFlagName = "cd" - cdOrgFlagName = "cd-org" + cdUidFlagName = "cd" + cdOrgFlagName = "cd-org" + nextdnsFlagName = "nextdns" ) func init() { diff --git a/cmd/cli/nextdns.go b/cmd/cli/nextdns.go new file mode 100644 index 0000000..8aebfc8 --- /dev/null +++ b/cmd/cli/nextdns.go @@ -0,0 +1,31 @@ +package cli + +import ( + "fmt" + + "github.com/Control-D-Inc/ctrld" +) + +const nextdnsURL = "https://dns.nextdns.io" + +func generateNextDNSConfig() { + if nextdns == "" { + return + } + mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver") + cfg = ctrld.Config{ + Listener: map[string]*ctrld.ListenerConfig{ + "0": { + IP: "0.0.0.0", + Port: 53, + }, + }, + Upstream: map[string]*ctrld.UpstreamConfig{ + "0": { + Type: ctrld.ResolverTypeDOH3, + Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, nextdns), + Timeout: 5000, + }, + }, + } +} diff --git a/config.go b/config.go index c9c1acc..489a7fd 100644 --- a/config.go +++ b/config.go @@ -323,7 +323,7 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool { } switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: - if uc.isControlD() { + if uc.isControlD() || uc.isNextDNS() { return true } } @@ -520,6 +520,16 @@ func (uc *UpstreamConfig) isControlD() bool { return false } +func (uc *UpstreamConfig) isNextDNS() bool { + domain := uc.Domain + if domain == "" { + if u, err := url.Parse(uc.Endpoint); err == nil { + domain = u.Hostname() + } + } + return domain == "dns.nextdns.io" +} + func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { uc.transportOnce.Do(func() { uc.SetupTransport() diff --git a/doh.go b/doh.go index e0aa363..96f8051 100644 --- a/doh.go +++ b/doh.go @@ -76,7 +76,6 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver { endpoint: uc.u, isDoH3: uc.Type == ResolverTypeDOH3, http3RoundTripper: uc.http3RoundTripper, - sendClientInfo: uc.UpstreamSendClientInfo(), uc: uc, } return r @@ -87,9 +86,9 @@ type dohResolver struct { endpoint *url.URL isDoH3 bool http3RoundTripper http.RoundTripper - sendClientInfo bool } +// Resolve performs DNS query with given DNS message using DOH protocol. func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { data, err := msg.Pack() if err != nil { @@ -106,7 +105,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if err != nil { return nil, fmt.Errorf("could not create request: %w", err) } - addHeader(ctx, req, r.sendClientInfo) + addHeader(ctx, req, r.uc) dnsTyp := uint16(0) if len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype @@ -146,26 +145,19 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return answer, nil } -func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { +func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) { req.Header.Set("Content-Type", headerApplicationDNS) req.Header.Set("Accept", headerApplicationDNS) - req.Header.Set(dohOsHeader, dohOsHeaderValue()) printed := false - if sendClientInfo { + if uc.UpstreamSendClientInfo() { if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil { printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != "" - if ci.Mac != "" { - req.Header.Set(dohMacHeader, ci.Mac) - } - if ci.IP != "" { - req.Header.Set(dohIPHeader, ci.IP) - } - if ci.Hostname != "" { - req.Header.Set(dohHostHeader, ci.Hostname) - } - if ci.Self { - req.Header.Set(dohOsHeader, dohOsHeaderValue()) + switch { + case uc.isControlD(): + addControlDHeaders(req, ci) + case uc.isNextDNS(): + addNextDNSHeaders(req, ci) } } } @@ -173,3 +165,35 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header") } } + +// addControlDHeaders set DoH/Doh3 HTTP request headers for ControlD upstream. +func addControlDHeaders(req *http.Request, ci *ClientInfo) { + req.Header.Set(dohOsHeader, dohOsHeaderValue()) + if ci.Mac != "" { + req.Header.Set(dohMacHeader, ci.Mac) + } + if ci.IP != "" { + req.Header.Set(dohIPHeader, ci.IP) + } + if ci.Hostname != "" { + req.Header.Set(dohHostHeader, ci.Hostname) + } + if ci.Self { + req.Header.Set(dohOsHeader, dohOsHeaderValue()) + } +} + +// addNextDNSHeaders set DoH/Doh3 HTTP request headers for nextdns upstream. +// https://github.com/nextdns/nextdns/blob/v1.41.0/resolver/doh.go#L100 +func addNextDNSHeaders(req *http.Request, ci *ClientInfo) { + if ci.Mac != "" { + // https: //github.com/nextdns/nextdns/blob/v1.41.0/run.go#L543 + req.Header.Set("X-Device-Model", "mac:"+ci.Mac[:8]) + } + if ci.IP != "" { + req.Header.Set("X-Device-Ip", ci.IP) + } + if ci.Hostname != "" { + req.Header.Set("X-Device-Name", ci.Hostname) + } +}