From 904b23eeac7408d402c54e643c28d61ace1043f4 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Tue, 17 Oct 2023 00:56:52 +0700 Subject: [PATCH] cmd/cli: add --proto flag to set upstream type in cd mode --- cmd/cli/cli.go | 21 +++++++++++++++++++-- cmd/cli/main.go | 1 + 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 8d733b3..d0d64c9 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -144,6 +144,7 @@ func initCLI() { _ = runCmd.Flags().MarkHidden("homedir") runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) _ = runCmd.Flags().MarkHidden("iface") + runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) rootCmd.AddCommand(runCmd) @@ -177,6 +178,9 @@ func initCLI() { // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. sc.Arguments = append(sc.Arguments, "--cd="+cdUID) } + if cdUID != "" { + validateCdUpstreamProtocol() + } p := &prog{ router: router.New(&cfg, cdUID != ""), @@ -306,6 +310,7 @@ func initCLI() { _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id") + startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`) routerCmd := &cobra.Command{ Use: "setup", @@ -788,6 +793,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { cdUID = uid } if cdUID != "" { + validateCdUpstreamProtocol() err := processCDFlags() if err != nil { appCallback.Exit(err.Error()) @@ -1099,7 +1105,7 @@ func processCDFlags() error { cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) cfg.Upstream["0"] = &ctrld.UpstreamConfig{ Endpoint: resolverConfig.DOH, - Type: ctrld.ResolverTypeDOH, + Type: cdUpstreamProto, Timeout: 5000, } rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) @@ -1816,7 +1822,18 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) { return } if fl.Value.String() == "" { - mainLog.Load().Fatal().Msgf(`flag "--%s"" value must be non-empty`, fl.Name) + mainLog.Load().Fatal().Msgf(`flag "--%s" value must be non-empty`, fl.Name) + } +} + +func validateCdUpstreamProtocol() { + if cdUID == "" { + return + } + switch cdUpstreamProto { + case ctrld.ResolverTypeDOH, ctrld.ResolverTypeDOH3: + default: + mainLog.Load().Fatal().Msg(`flag "--protocol" must be "doh" or "doh3"`) } } diff --git a/cmd/cli/main.go b/cmd/cli/main.go index f79b15f..bf65044 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -33,6 +33,7 @@ var ( iface string ifaceStartStop string nextdns string + cdUpstreamProto string mainLog atomic.Pointer[zerolog.Logger] consoleWriter zerolog.ConsoleWriter