diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 1a93c9d..1e9c541 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -147,6 +147,7 @@ func initCLI() { runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") + runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = runCmd.Flags().MarkHidden("dev") runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") @@ -319,7 +320,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c } else if uid := cdUIDFromProvToken(); uid != "" { cdUID = uid mainLog.Load().Debug().Msg("using uid from provision token") - removeProvTokenFromArgs(sc) + removeOrgFlagsFromArgs(sc) // Pass --cd flag to "ctrld run" command, so the provision token takes no effect. sc.Arguments = append(sc.Arguments, "--cd="+cdUID) } @@ -440,6 +441,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid") startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token") + startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API") startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) @@ -2363,17 +2365,30 @@ func cdUIDFromProvToken() string { if cdOrg == "" { return "" } - + // Validate custom hostname if provided. + if customHostname != "" && !validHostname(customHostname) { + mainLog.Load().Fatal().Msgf("invalid custom hostname: %q", customHostname) + } + req := &controld.UtilityOrgRequest{ProvToken: cdOrg, Hostname: customHostname} // Process provision token if provided. - resolverConfig, err := controld.FetchResolverUID(cdOrg, rootCmd.Version, cdDev) + resolverConfig, err := controld.FetchResolverUID(req, rootCmd.Version, cdDev) if err != nil { mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg) } return resolverConfig.UID } -// removeProvTokenFromArgs removes the --cd-org from command line arguments. -func removeProvTokenFromArgs(sc *service.Config) { +// removeOrgFlagsFromArgs removes organization flags from command line arguments. +// The flags are: +// +// - "--cd-org" +// - "--custom-hostname" +// +// This is necessary because "ctrld run" only need a valid UID, which could be fetched +// using "--cd-org". So if "ctrld start" have already been called with "--cd-org", we +// already have a valid UID to pass to "ctrld run", so we don't have to force "ctrld run" +// to re-do the already done job. +func removeOrgFlagsFromArgs(sc *service.Config) { a := sc.Arguments[:0] skip := false for _, x := range sc.Arguments { @@ -2381,13 +2396,14 @@ func removeProvTokenFromArgs(sc *service.Config) { skip = false continue } - // For "--cd-org XXX", skip it and mark next arg skipped. - if x == "--"+cdOrgFlagName { + // For "--cd-org XXX"/"--custom-hostname XXX", skip them and mark next arg skipped. + if x == "--"+cdOrgFlagName || x == "--"+customHostnameFlagName { skip = true continue } - // For "--cd-org=XXX", just skip it. - if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") { + // For "--cd-org=XXX"/"--custom-hostname=XXX", just skip them. + if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") || + strings.HasPrefix(x, "--"+customHostnameFlagName+"=") { continue } a = append(a, x) diff --git a/cmd/cli/hostname.go b/cmd/cli/hostname.go new file mode 100644 index 0000000..d28435d --- /dev/null +++ b/cmd/cli/hostname.go @@ -0,0 +1,14 @@ +package cli + +import "regexp" + +// validHostname reports whether hostname is a valid hostname. +// A valid hostname contains 3 -> 64 characters and conform to RFC1123. +func validHostname(hostname string) bool { + hostnameLen := len(hostname) + if hostnameLen < 3 || hostnameLen > 64 { + return false + } + validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`) + return validHostnameRfc1123.MatchString(hostname) +} diff --git a/cmd/cli/hostname_test.go b/cmd/cli/hostname_test.go new file mode 100644 index 0000000..f44b231 --- /dev/null +++ b/cmd/cli/hostname_test.go @@ -0,0 +1,35 @@ +package cli + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_validHostname(t *testing.T) { + tests := []struct { + name string + hostname string + valid bool + }{ + {"localhost", "localhost", true}, + {"localdomain", "localhost.localdomain", true}, + {"localhost6", "localhost6.localdomain6", true}, + {"ip6", "ip6-localhost", true}, + {"non-domain", "controld", true}, + {"domain", "controld.com", true}, + {"empty", "", false}, + {"min length", "fo", false}, + {"max length", strings.Repeat("a", 65), false}, + {"special char", "foo!", false}, + {"non-ascii", "fooΩ", false}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.hostname, func(t *testing.T) { + t.Parallel() + assert.True(t, validHostname(tc.hostname) == tc.valid) + }) + } +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 146c58d..b8f6d8d 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -29,6 +29,7 @@ var ( silent bool cdUID string cdOrg string + customHostname string cdDev bool iface string ifaceStartStop string @@ -45,9 +46,10 @@ var ( ) const ( - cdUidFlagName = "cd" - cdOrgFlagName = "cd-org" - nextdnsFlagName = "nextdns" + cdUidFlagName = "cd" + cdOrgFlagName = "cd-org" + customHostnameFlagName = "custom-hostname" + nextdnsFlagName = "nextdns" ) func init() { diff --git a/internal/controld/config.go b/internal/controld/config.go index 01e114b..1bc2512 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "net" @@ -64,7 +65,8 @@ type utilityRequest struct { ClientID string `json:"client_id,omitempty"` } -type utilityOrgRequest struct { +// UtilityOrgRequest contains request data for calling Org API. +type UtilityOrgRequest struct { ProvToken string `json:"prov_token"` Hostname string `json:"hostname"` } @@ -81,9 +83,15 @@ func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, e } // FetchResolverUID fetch resolver uid from provision token. -func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) { - hostname, _ := os.Hostname() - body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname}) +func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) { + if req == nil { + return nil, errors.New("invalid request") + } + hostname := req.Hostname + if hostname == "" { + hostname, _ = os.Hostname() + } + body, _ := json.Marshal(UtilityOrgRequest{ProvToken: req.ProvToken, Hostname: hostname}) return postUtilityAPI(version, cdDev, false, bytes.NewReader(body)) }