diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 0e2b642..aaa720f 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -234,6 +234,8 @@ func initCLI() { runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") + runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = runCmd.Flags().MarkHidden("dev") runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") _ = runCmd.Flags().MarkHidden("homedir") runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) @@ -352,6 +354,8 @@ func initCLI() { startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") + startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = startCmd.Flags().MarkHidden("dev") startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`) _ = startCmd.Flags().MarkHidden("router") @@ -706,7 +710,7 @@ func processCDFlags() { } logger := mainLog.With().Str("mode", "cd").Logger() logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) - resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version) + resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { s, err := service.New(&prog{}, svcConfig) if err != nil { diff --git a/cmd/ctrld/cli_router.go b/cmd/ctrld/cli_router.go index 3d54c1c..7688e9a 100644 --- a/cmd/ctrld/cli_router.go +++ b/cmd/ctrld/cli_router.go @@ -77,6 +77,8 @@ func initRouterCLI() { routerCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file") routerCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") routerCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") + routerCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain") + _ = routerCmd.Flags().MarkHidden("dev") routerCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) tmpl := routerCmd.UsageTemplate() diff --git a/cmd/ctrld/main.go b/cmd/ctrld/main.go index 9903a0b..bc3edf0 100644 --- a/cmd/ctrld/main.go +++ b/cmd/ctrld/main.go @@ -27,6 +27,7 @@ var ( verbose int silent bool cdUID string + cdDev bool iface string ifaceStartStop string setupRouter bool diff --git a/internal/controld/config.go b/internal/controld/config.go index 22c18b9..eef98f9 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -17,9 +17,11 @@ import ( ) const ( - apiDomain = "api.controld.com" - resolverDataURL = "https://api.controld.com/utility" - InvalidConfigCode = 40401 + apiDomainCom = "api.controld.com" + apiDomainDev = "api.controld.dev" + resolverDataURLCom = "https://api.controld.com/utility" + resolverDataURLDev = "https://api.controld.dev/utility" + InvalidConfigCode = 40401 ) // ResolverConfig represents Control D resolver data. @@ -54,9 +56,13 @@ type utilityRequest struct { } // FetchResolverConfig fetch Control D config for given uid. -func FetchResolverConfig(uid, version string) (*ResolverConfig, error) { +func FetchResolverConfig(uid, version string, cdDev bool) (*ResolverConfig, error) { body, _ := json.Marshal(utilityRequest{UID: uid}) - req, err := http.NewRequest("POST", resolverDataURL, bytes.NewReader(body)) + apiUrl := resolverDataURLCom + if cdDev { + apiUrl = resolverDataURLDev + } + req, err := http.NewRequest("POST", apiUrl, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("http.NewRequest: %w", err) } @@ -67,6 +73,10 @@ func FetchResolverConfig(uid, version string) (*ResolverConfig, error) { req.Header.Add("Content-Type", "application/json") transport := http.DefaultTransport.(*http.Transport).Clone() transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + apiDomain := apiDomainCom + if cdDev { + apiDomain = apiDomainDev + } ips := ctrld.LookupIP(apiDomain) if len(ips) == 0 { ctrld.ProxyLog.Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr) diff --git a/internal/controld/config_test.go b/internal/controld/config_test.go index d46cf95..2c00247 100644 --- a/internal/controld/config_test.go +++ b/internal/controld/config_test.go @@ -13,16 +13,18 @@ func TestFetchResolverConfig(t *testing.T) { tests := []struct { name string uid string + dev bool wantErr bool }{ - {"valid", "p2", false}, - {"invalid uid", "abcd1234", true}, + {"valid com", "p2", false, false}, + {"valid dev", "p2", true, false}, + {"invalid uid", "abcd1234", false, true}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - got, err := FetchResolverConfig(tc.uid, "dev-test") + got, err := FetchResolverConfig(tc.uid, "dev-test", tc.dev) require.False(t, (err != nil) != tc.wantErr, err) if !tc.wantErr { assert.NotEmpty(t, got.DOH)