all: add support for provision token

This commit is contained in:
Cuong Manh Le
2023-07-27 23:05:27 +00:00
committed by Cuong Manh Le
parent 82d887f52d
commit c271896551
4 changed files with 114 additions and 11 deletions

View File

@@ -202,6 +202,9 @@ func initCLI() {
}
oldLogPath := cfg.Service.LogPath
if uid := cdUIDFromProvToken(); uid != "" {
cdUID = uid
}
if cdUID != "" {
processCDFlags()
}
@@ -311,6 +314,7 @@ func initCLI() {
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
runCmd.Flags().StringVarP(&cdOrg, "cd-org", "", "", "Control D provision token")
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = runCmd.Flags().MarkHidden("dev")
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
@@ -337,6 +341,12 @@ func initCLI() {
}
setDependencies(sc)
sc.Arguments = append([]string{"run"}, osArgs...)
if uid := cdUIDFromProvToken(); uid != "" {
cdUID = uid
removeProvTokenFromArgs(sc)
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
sc.Arguments = append(sc.Arguments, "--cd="+cdUID)
}
p := &prog{
router: router.New(&cfg, cdUID != ""),
@@ -427,8 +437,7 @@ func initCLI() {
return
}
domain := cfg.Upstream["0"].VerifyDomain()
status = selfCheckStatus(status, domain)
status = selfCheckStatus(status)
switch status {
case service.StatusRunning:
mainLog.Load().Notice().Msg("Service started")
@@ -462,6 +471,7 @@ func initCLI() {
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
startCmd.Flags().StringVarP(&cdOrg, "cd-org", "", "", "Control D provision token")
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = startCmd.Flags().MarkHidden("dev")
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
@@ -1100,11 +1110,7 @@ func defaultIfaceName() string {
return dri
}
func selfCheckStatus(status service.Status, domain string) service.Status {
if domain == "" {
// Nothing to do, return the status as-is.
return status
}
func selfCheckStatus(status service.Status) service.Status {
dir, err := userHomeDir()
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to check ctrld listener status: could not get home directory")
@@ -1146,6 +1152,7 @@ func selfCheckStatus(status service.Status, domain string) service.Status {
c := new(dns.Client)
var (
lcChanged map[string]*ctrld.ListenerConfig
ucChanged map[string]*ctrld.UpstreamConfig
mu sync.Mutex
)
@@ -1155,6 +1162,11 @@ func selfCheckStatus(status service.Status, domain string) service.Status {
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to update new config")
}
domain := cfg.FirstUpstream().VerifyDomain()
if domain == "" {
// Nothing to do, return the status as-is.
return status
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
mainLog.Load().Error().Err(err).Msg("could not watch config change")
@@ -1169,6 +1181,10 @@ func selfCheckStatus(status service.Status, domain string) service.Status {
mainLog.Load().Error().Msgf("failed to unmarshal listener config: %v", err)
return
}
if err := v.UnmarshalKey("upstream", &ucChanged); err != nil {
mainLog.Load().Error().Msgf("failed to unmarshal upstream config: %v", err)
return
}
})
v.WatchConfig()
var (
@@ -1180,8 +1196,15 @@ func selfCheckStatus(status service.Status, domain string) service.Status {
if lcChanged != nil {
cfg.Listener = lcChanged
}
if ucChanged != nil {
cfg.Upstream = ucChanged
}
mu.Unlock()
lc := cfg.FirstListener()
domain = cfg.FirstUpstream().VerifyDomain()
if domain == "" {
continue
}
m := new(dns.Msg)
m.SetQuestion(domain+".", dns.TypeA)
@@ -1599,3 +1622,44 @@ func osVersion() string {
}
return oi.String()
}
// cdUIDFromProvToken fetch UID from ControlD API using provision token.
func cdUIDFromProvToken() string {
// --cd flag supersedes --cd-org, ignore it if both are supplied.
if cdUID != "" {
return ""
}
// --cd-org is empty, nothing to do.
if cdOrg == "" {
return ""
}
// Process provision token if provided.
resolverConfig, err := controld.FetchResolverUID(cdOrg, rootCmd.Version, cdDev)
if err != nil {
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg)
}
return resolverConfig.UID
}
// removeProvTokenFromArgs removes the --cd-org from command line arguments.
func removeProvTokenFromArgs(sc *service.Config) {
a := sc.Arguments[:0]
skip := false
for _, x := range sc.Arguments {
if skip {
skip = false
continue
}
// For "--cd-org XXX", skip it and mark next arg skipped.
if x == "--cd-org" {
skip = true
continue
}
// For "--cd-org=XXX", just skip it.
if strings.HasPrefix(x, "--cd-org=") {
continue
}
a = append(a, x)
}
sc.Arguments = a
}

View File

@@ -28,6 +28,7 @@ var (
verbose int
silent bool
cdUID string
cdOrg string
cdDev bool
iface string
ifaceStartStop string

View File

@@ -144,6 +144,25 @@ func (c *Config) FirstListener() *ListenerConfig {
return c.Listener[strconv.Itoa(listeners[0])]
}
// FirstUpstream returns the first upstream of current config. Upstreams are sorted numerically.
//
// It panics if Config has no upstreams configured.
func (c *Config) FirstUpstream() *UpstreamConfig {
upstreams := make([]int, 0, len(c.Upstream))
for k := range c.Upstream {
n, err := strconv.Atoi(k)
if err != nil {
continue
}
upstreams = append(upstreams, n)
}
if len(upstreams) == 0 {
panic("missing listener config")
}
sort.Ints(upstreams)
return c.Upstream[strconv.Itoa(upstreams[0])]
}
// ServiceConfig specifies the general ctrld config.
type ServiceConfig struct {
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`

View File

@@ -6,8 +6,10 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"time"
@@ -33,6 +35,7 @@ type ResolverConfig struct {
CustomConfig string `json:"custom_config"`
} `json:"ctrld"`
Exclude []string `json:"exclude"`
UID string `json:"uid"`
}
type utilityResponse struct {
@@ -58,19 +61,35 @@ type utilityRequest struct {
ClientID string `json:"client_id,omitempty"`
}
type utilityOrgRequest struct {
ProvToken string `json:"prov_token"`
Hostname string `json:"hostname"`
}
// FetchResolverConfig fetch Control D config for given uid.
func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) {
uid, clientID := ParseRawUID(rawUID)
uReq := utilityRequest{UID: uid}
req := utilityRequest{UID: uid}
if clientID != "" {
uReq.ClientID = clientID
req.ClientID = clientID
}
body, _ := json.Marshal(uReq)
body, _ := json.Marshal(req)
return postUtilityAPI(version, cdDev, bytes.NewReader(body))
}
// FetchResolverUID fetch resolver uid from provision token.
func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) {
hostname, _ := os.Hostname()
body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname})
return postUtilityAPI(version, cdDev, bytes.NewReader(body))
}
func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig, error) {
apiUrl := resolverDataURLCom
if cdDev {
apiUrl = resolverDataURLDev
}
req, err := http.NewRequest("POST", apiUrl, bytes.NewReader(body))
req, err := http.NewRequest("POST", apiUrl, body)
if err != nil {
return nil, fmt.Errorf("http.NewRequest: %w", err)
}