all: add flag to use dev domain for testing

This commit is contained in:
Cuong Manh Le
2023-06-03 10:21:35 +07:00
committed by Cuong Manh Le
parent 25eae187db
commit c941f9c621
5 changed files with 28 additions and 9 deletions
+5 -1
View File
@@ -234,6 +234,8 @@ func initCLI() {
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = runCmd.Flags().MarkHidden("dev")
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
_ = runCmd.Flags().MarkHidden("homedir")
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
@@ -352,6 +354,8 @@ func initCLI() {
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = startCmd.Flags().MarkHidden("dev")
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
_ = startCmd.Flags().MarkHidden("router")
@@ -706,7 +710,7 @@ func processCDFlags() {
}
logger := mainLog.With().Str("mode", "cd").Logger()
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version)
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
s, err := service.New(&prog{}, svcConfig)
if err != nil {
+2
View File
@@ -77,6 +77,8 @@ func initRouterCLI() {
routerCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
routerCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
routerCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
routerCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = routerCmd.Flags().MarkHidden("dev")
routerCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
tmpl := routerCmd.UsageTemplate()
+1
View File
@@ -27,6 +27,7 @@ var (
verbose int
silent bool
cdUID string
cdDev bool
iface string
ifaceStartStop string
setupRouter bool
+15 -5
View File
@@ -17,9 +17,11 @@ import (
)
const (
apiDomain = "api.controld.com"
resolverDataURL = "https://api.controld.com/utility"
InvalidConfigCode = 40401
apiDomainCom = "api.controld.com"
apiDomainDev = "api.controld.dev"
resolverDataURLCom = "https://api.controld.com/utility"
resolverDataURLDev = "https://api.controld.dev/utility"
InvalidConfigCode = 40401
)
// ResolverConfig represents Control D resolver data.
@@ -54,9 +56,13 @@ type utilityRequest struct {
}
// FetchResolverConfig fetch Control D config for given uid.
func FetchResolverConfig(uid, version string) (*ResolverConfig, error) {
func FetchResolverConfig(uid, version string, cdDev bool) (*ResolverConfig, error) {
body, _ := json.Marshal(utilityRequest{UID: uid})
req, err := http.NewRequest("POST", resolverDataURL, bytes.NewReader(body))
apiUrl := resolverDataURLCom
if cdDev {
apiUrl = resolverDataURLDev
}
req, err := http.NewRequest("POST", apiUrl, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("http.NewRequest: %w", err)
}
@@ -67,6 +73,10 @@ func FetchResolverConfig(uid, version string) (*ResolverConfig, error) {
req.Header.Add("Content-Type", "application/json")
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
apiDomain := apiDomainCom
if cdDev {
apiDomain = apiDomainDev
}
ips := ctrld.LookupIP(apiDomain)
if len(ips) == 0 {
ctrld.ProxyLog.Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr)
+5 -3
View File
@@ -13,16 +13,18 @@ func TestFetchResolverConfig(t *testing.T) {
tests := []struct {
name string
uid string
dev bool
wantErr bool
}{
{"valid", "p2", false},
{"invalid uid", "abcd1234", true},
{"valid com", "p2", false, false},
{"valid dev", "p2", true, false},
{"invalid uid", "abcd1234", false, true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := FetchResolverConfig(tc.uid, "dev-test")
got, err := FetchResolverConfig(tc.uid, "dev-test", tc.dev)
require.False(t, (err != nil) != tc.wantErr, err)
if !tc.wantErr {
assert.NotEmpty(t, got.DOH)