all: add custom hostname support for provisoning

This commit is contained in:
Cuong Manh Le
2024-10-23 16:00:09 +07:00
committed by Cuong Manh Le
parent 65de7edcde
commit 9d666be5d4
5 changed files with 91 additions and 16 deletions

View File

@@ -147,6 +147,7 @@ func initCLI() {
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = runCmd.Flags().MarkHidden("dev")
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
@@ -319,7 +320,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
} else if uid := cdUIDFromProvToken(); uid != "" {
cdUID = uid
mainLog.Load().Debug().Msg("using uid from provision token")
removeProvTokenFromArgs(sc)
removeOrgFlagsFromArgs(sc)
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
sc.Arguments = append(sc.Arguments, "--cd="+cdUID)
}
@@ -440,6 +441,7 @@ NOTE: running "ctrld start" without any arguments will start already installed c
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = startCmd.Flags().MarkHidden("dev")
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
@@ -2363,17 +2365,30 @@ func cdUIDFromProvToken() string {
if cdOrg == "" {
return ""
}
// Validate custom hostname if provided.
if customHostname != "" && !validHostname(customHostname) {
mainLog.Load().Fatal().Msgf("invalid custom hostname: %q", customHostname)
}
req := &controld.UtilityOrgRequest{ProvToken: cdOrg, Hostname: customHostname}
// Process provision token if provided.
resolverConfig, err := controld.FetchResolverUID(cdOrg, rootCmd.Version, cdDev)
resolverConfig, err := controld.FetchResolverUID(req, rootCmd.Version, cdDev)
if err != nil {
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid with provision token: %s", cdOrg)
}
return resolverConfig.UID
}
// removeProvTokenFromArgs removes the --cd-org from command line arguments.
func removeProvTokenFromArgs(sc *service.Config) {
// removeOrgFlagsFromArgs removes organization flags from command line arguments.
// The flags are:
//
// - "--cd-org"
// - "--custom-hostname"
//
// This is necessary because "ctrld run" only need a valid UID, which could be fetched
// using "--cd-org". So if "ctrld start" have already been called with "--cd-org", we
// already have a valid UID to pass to "ctrld run", so we don't have to force "ctrld run"
// to re-do the already done job.
func removeOrgFlagsFromArgs(sc *service.Config) {
a := sc.Arguments[:0]
skip := false
for _, x := range sc.Arguments {
@@ -2381,13 +2396,14 @@ func removeProvTokenFromArgs(sc *service.Config) {
skip = false
continue
}
// For "--cd-org XXX", skip it and mark next arg skipped.
if x == "--"+cdOrgFlagName {
// For "--cd-org XXX"/"--custom-hostname XXX", skip them and mark next arg skipped.
if x == "--"+cdOrgFlagName || x == "--"+customHostnameFlagName {
skip = true
continue
}
// For "--cd-org=XXX", just skip it.
if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") {
// For "--cd-org=XXX"/"--custom-hostname=XXX", just skip them.
if strings.HasPrefix(x, "--"+cdOrgFlagName+"=") ||
strings.HasPrefix(x, "--"+customHostnameFlagName+"=") {
continue
}
a = append(a, x)

14
cmd/cli/hostname.go Normal file
View File

@@ -0,0 +1,14 @@
package cli
import "regexp"
// validHostname reports whether hostname is a valid hostname.
// A valid hostname contains 3 -> 64 characters and conform to RFC1123.
func validHostname(hostname string) bool {
hostnameLen := len(hostname)
if hostnameLen < 3 || hostnameLen > 64 {
return false
}
validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
return validHostnameRfc1123.MatchString(hostname)
}

35
cmd/cli/hostname_test.go Normal file
View File

@@ -0,0 +1,35 @@
package cli
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_validHostname(t *testing.T) {
tests := []struct {
name string
hostname string
valid bool
}{
{"localhost", "localhost", true},
{"localdomain", "localhost.localdomain", true},
{"localhost6", "localhost6.localdomain6", true},
{"ip6", "ip6-localhost", true},
{"non-domain", "controld", true},
{"domain", "controld.com", true},
{"empty", "", false},
{"min length", "fo", false},
{"max length", strings.Repeat("a", 65), false},
{"special char", "foo!", false},
{"non-ascii", "fooΩ", false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.hostname, func(t *testing.T) {
t.Parallel()
assert.True(t, validHostname(tc.hostname) == tc.valid)
})
}
}

View File

@@ -29,6 +29,7 @@ var (
silent bool
cdUID string
cdOrg string
customHostname string
cdDev bool
iface string
ifaceStartStop string
@@ -45,9 +46,10 @@ var (
)
const (
cdUidFlagName = "cd"
cdOrgFlagName = "cd-org"
nextdnsFlagName = "nextdns"
cdUidFlagName = "cd"
cdOrgFlagName = "cd-org"
customHostnameFlagName = "custom-hostname"
nextdnsFlagName = "nextdns"
)
func init() {

View File

@@ -5,6 +5,7 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@@ -64,7 +65,8 @@ type utilityRequest struct {
ClientID string `json:"client_id,omitempty"`
}
type utilityOrgRequest struct {
// UtilityOrgRequest contains request data for calling Org API.
type UtilityOrgRequest struct {
ProvToken string `json:"prov_token"`
Hostname string `json:"hostname"`
}
@@ -81,9 +83,15 @@ func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, e
}
// FetchResolverUID fetch resolver uid from provision token.
func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) {
hostname, _ := os.Hostname()
body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname})
func FetchResolverUID(req *UtilityOrgRequest, version string, cdDev bool) (*ResolverConfig, error) {
if req == nil {
return nil, errors.New("invalid request")
}
hostname := req.Hostname
if hostname == "" {
hostname, _ = os.Hostname()
}
body, _ := json.Marshal(UtilityOrgRequest{ProvToken: req.ProvToken, Hostname: hostname})
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
}