From c27189655166d3a59856505fb2c26f06b7028f97 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 27 Jul 2023 23:05:27 +0000 Subject: [PATCH] all: add support for provision token --- cmd/ctrld/cli.go | 78 +++++++++++++++++++++++++++++++++---- cmd/ctrld/main.go | 1 + config.go | 19 +++++++++ internal/controld/config.go | 27 +++++++++++-- 4 files changed, 114 insertions(+), 11 deletions(-) diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 55fcaea..2369323 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -202,6 +202,9 @@ func initCLI() { } oldLogPath := cfg.Service.LogPath + if uid := cdUIDFromProvToken(); uid != "" { + cdUID = uid + } if cdUID != "" { processCDFlags() } @@ -311,6 +314,7 @@ func initCLI() { runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") + runCmd.Flags().StringVarP(&cdOrg, "cd-org", "", "", "Control D provision token") runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = runCmd.Flags().MarkHidden("dev") runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") @@ -337,6 +341,12 @@ func initCLI() { } setDependencies(sc) sc.Arguments = append([]string{"run"}, osArgs...) + if uid := cdUIDFromProvToken(); uid != "" { + cdUID = uid + removeProvTokenFromArgs(sc) + // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. + sc.Arguments = append(sc.Arguments, "--cd="+cdUID) + } p := &prog{ router: router.New(&cfg, cdUID != ""), @@ -427,8 +437,7 @@ func initCLI() { return } - domain := cfg.Upstream["0"].VerifyDomain() - status = selfCheckStatus(status, domain) + status = selfCheckStatus(status) switch status { case service.StatusRunning: mainLog.Load().Notice().Msg("Service started") @@ -462,6 +471,7 @@ func initCLI() { startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") + startCmd.Flags().StringVarP(&cdOrg, "cd-org", "", "", "Control D provision token") startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) @@ -1100,11 +1110,7 @@ func defaultIfaceName() string { return dri } -func selfCheckStatus(status service.Status, domain string) service.Status { - if domain == "" { - // Nothing to do, return the status as-is. - return status - } +func selfCheckStatus(status service.Status) service.Status { dir, err := userHomeDir() if err != nil { mainLog.Load().Error().Err(err).Msg("failed to check ctrld listener status: could not get home directory") @@ -1146,6 +1152,7 @@ func selfCheckStatus(status service.Status, domain string) service.Status { c := new(dns.Client) var ( lcChanged map[string]*ctrld.ListenerConfig + ucChanged map[string]*ctrld.UpstreamConfig mu sync.Mutex ) @@ -1155,6 +1162,11 @@ func selfCheckStatus(status service.Status, domain string) service.Status { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Err(err).Msg("failed to update new config") } + domain := cfg.FirstUpstream().VerifyDomain() + if domain == "" { + // Nothing to do, return the status as-is. + return status + } watcher, err := fsnotify.NewWatcher() if err != nil { mainLog.Load().Error().Err(err).Msg("could not watch config change") @@ -1169,6 +1181,10 @@ func selfCheckStatus(status service.Status, domain string) service.Status { mainLog.Load().Error().Msgf("failed to unmarshal listener config: %v", err) return } + if err := v.UnmarshalKey("upstream", &ucChanged); err != nil { + mainLog.Load().Error().Msgf("failed to unmarshal upstream config: %v", err) + return + } }) v.WatchConfig() var ( @@ -1180,8 +1196,15 @@ func selfCheckStatus(status service.Status, domain string) service.Status { if lcChanged != nil { cfg.Listener = lcChanged } + if ucChanged != nil { + cfg.Upstream = ucChanged + } mu.Unlock() lc := cfg.FirstListener() + domain = cfg.FirstUpstream().VerifyDomain() + if domain == "" { + continue + } m := new(dns.Msg) m.SetQuestion(domain+".", dns.TypeA) @@ -1599,3 +1622,44 @@ func osVersion() string { } return oi.String() } + +// cdUIDFromProvToken fetch UID from ControlD API using provision token. +func cdUIDFromProvToken() string { + // --cd flag supersedes --cd-org, ignore it if both are supplied. + if cdUID != "" { + return "" + } + // --cd-org is empty, nothing to do. + if cdOrg == "" { + return "" + } + // Process provision token if provided. + resolverConfig, err := controld.FetchResolverUID(cdOrg, rootCmd.Version, cdDev) + if err != nil { + mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg) + } + return resolverConfig.UID +} + +// removeProvTokenFromArgs removes the --cd-org from command line arguments. +func removeProvTokenFromArgs(sc *service.Config) { + a := sc.Arguments[:0] + skip := false + for _, x := range sc.Arguments { + if skip { + skip = false + continue + } + // For "--cd-org XXX", skip it and mark next arg skipped. + if x == "--cd-org" { + skip = true + continue + } + // For "--cd-org=XXX", just skip it. + if strings.HasPrefix(x, "--cd-org=") { + continue + } + a = append(a, x) + } + sc.Arguments = a +} diff --git a/cmd/ctrld/main.go b/cmd/ctrld/main.go index 75e7d2b..80160ec 100644 --- a/cmd/ctrld/main.go +++ b/cmd/ctrld/main.go @@ -28,6 +28,7 @@ var ( verbose int silent bool cdUID string + cdOrg string cdDev bool iface string ifaceStartStop string diff --git a/config.go b/config.go index 0019e00..3e2efe4 100644 --- a/config.go +++ b/config.go @@ -144,6 +144,25 @@ func (c *Config) FirstListener() *ListenerConfig { return c.Listener[strconv.Itoa(listeners[0])] } +// FirstUpstream returns the first upstream of current config. Upstreams are sorted numerically. +// +// It panics if Config has no upstreams configured. +func (c *Config) FirstUpstream() *UpstreamConfig { + upstreams := make([]int, 0, len(c.Upstream)) + for k := range c.Upstream { + n, err := strconv.Atoi(k) + if err != nil { + continue + } + upstreams = append(upstreams, n) + } + if len(upstreams) == 0 { + panic("missing listener config") + } + sort.Ints(upstreams) + return c.Upstream[strconv.Itoa(upstreams[0])] +} + // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` diff --git a/internal/controld/config.go b/internal/controld/config.go index 320fd4c..4e4bc2e 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -6,8 +6,10 @@ import ( "crypto/tls" "encoding/json" "fmt" + "io" "net" "net/http" + "os" "strings" "time" @@ -33,6 +35,7 @@ type ResolverConfig struct { CustomConfig string `json:"custom_config"` } `json:"ctrld"` Exclude []string `json:"exclude"` + UID string `json:"uid"` } type utilityResponse struct { @@ -58,19 +61,35 @@ type utilityRequest struct { ClientID string `json:"client_id,omitempty"` } +type utilityOrgRequest struct { + ProvToken string `json:"prov_token"` + Hostname string `json:"hostname"` +} + // FetchResolverConfig fetch Control D config for given uid. func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) { uid, clientID := ParseRawUID(rawUID) - uReq := utilityRequest{UID: uid} + req := utilityRequest{UID: uid} if clientID != "" { - uReq.ClientID = clientID + req.ClientID = clientID } - body, _ := json.Marshal(uReq) + body, _ := json.Marshal(req) + return postUtilityAPI(version, cdDev, bytes.NewReader(body)) +} + +// FetchResolverUID fetch resolver uid from provision token. +func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) { + hostname, _ := os.Hostname() + body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname}) + return postUtilityAPI(version, cdDev, bytes.NewReader(body)) +} + +func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig, error) { apiUrl := resolverDataURLCom if cdDev { apiUrl = resolverDataURLDev } - req, err := http.NewRequest("POST", apiUrl, bytes.NewReader(body)) + req, err := http.NewRequest("POST", apiUrl, body) if err != nil { return nil, fmt.Errorf("http.NewRequest: %w", err) }