diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index f62cb87..a63b3cc 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -17,6 +17,7 @@ import ( "github.com/spf13/viper" "github.com/Control-D-Inc/ctrld" + "github.com/Control-D-Inc/ctrld/internal/controld" ) var ( @@ -61,15 +62,15 @@ func initCLI() { log.Fatal("Cannot run in daemon mode. Please install a Windows service.") } - noConfigStart := isNoConfigStart(cmd) - + noConfigStart := isNoConfigStart(cmd) && cdUID != "" + writeDefaultConfig := !noConfigStart && configBase64 == "" configs := []struct { name string written bool }{ // For compatibility, we check for config.toml first, but only read it if exists. {"config", false}, - {"ctrld", !noConfigStart && configBase64 == ""}, + {"ctrld", writeDefaultConfig}, } for _, config := range configs { ctrld.SetConfigName(v, config.name) @@ -81,6 +82,7 @@ func initCLI() { readBase64Config() processNoConfigFlags(noConfigStart) + processCDFlags() if err := v.Unmarshal(&cfg); err != nil { log.Fatalf("failed to unmarshal config: %v", err) } @@ -138,6 +140,7 @@ func initCLI() { runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy") runCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file") runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") + runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") rootCmd.AddCommand(runCmd) @@ -183,6 +186,7 @@ func initCLI() { startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy") startCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file") startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items") + startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") stopCmd := &cobra.Command{ Use: "stop", @@ -308,21 +312,24 @@ func writeConfigFile() { } } -func readConfigFile(configWritten bool) bool { +func readConfigFile(writeDefaultConfig bool) bool { + // If err == nil, there's a config supplied via `--config`, no default config written. err := v.ReadInConfig() if err == nil { return true } - if !configWritten { + if !writeDefaultConfig { return false } + // If error is viper.ConfigFileNotFoundError, write default config. if _, ok := err.(viper.ConfigFileNotFoundError); ok { writeConfigFile() defaultConfigWritten = true return false } + // Otherwise, report fatal error and exit. log.Fatalf("failed to decode config file: %v", err) return false } @@ -347,21 +354,7 @@ func processNoConfigFlags(noConfigStart bool) { if listenAddress == "" || primaryUpstream == "" { log.Fatal(`"listen" and "primary_upstream" flags must be set in no config mode`) } - host, portStr, err := net.SplitHostPort(listenAddress) - if err != nil { - log.Fatalf("invalid listener address: %v", err) - } - port, err := strconv.Atoi(portStr) - if err != nil { - log.Fatalf("invalid port number: %v", err) - } - lc := &ctrld.ListenerConfig{ - IP: host, - Port: port, - } - v.Set("listener", map[string]*ctrld.ListenerConfig{ - "0": lc, - }) + processListenFlag() upstream := map[string]*ctrld.UpstreamConfig{ "0": { @@ -380,10 +373,67 @@ func processNoConfigFlags(noConfigStart bool) { for _, domain := range domains { rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}}) } + lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"] lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules} } v.Set("upstream", upstream) + processLogAndCacheFlags() +} + +func processCDFlags() { + if cdUID == "" { + return + } + resolverConfig, err := controld.FetchResolverConfig(cdUID) + if err != nil { + log.Fatalf("failed to fetch resolver config: %v", err) + } + + upstream := map[string]*ctrld.UpstreamConfig{ + "0": { + BootstrapIP: resolverConfig.IP(supportsIPv6()), + Name: resolverConfig.DOH, + Endpoint: resolverConfig.DOH, + Type: ctrld.ResolverTypeDOH, + }, + } + v.Set("upstream", upstream) + + processListenFlag() + + rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) + for _, domain := range resolverConfig.Exclude { + rules = append(rules, ctrld.Rule{domain: []string{}}) + } + lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"] + lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules} + + processLogAndCacheFlags() +} + +func processListenFlag() { + if listenAddress == "" { + return + } + host, portStr, err := net.SplitHostPort(listenAddress) + if err != nil { + log.Fatalf("invalid listener address: %v", err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + log.Fatalf("invalid port number: %v", err) + } + lc := &ctrld.ListenerConfig{ + IP: host, + Port: port, + } + v.Set("listener", map[string]*ctrld.ListenerConfig{ + "0": lc, + }) +} + +func processLogAndCacheFlags() { sc := ctrld.ServiceConfig{} if logPath != "" { sc.LogLevel = "debug" diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index b203668..679d292 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -143,6 +143,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i } } upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) + if len(upstreamConfigs) == 0 { + upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + upstreams = []string{"upstream.os"} + } resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) dnsResolver, err := ctrld.NewResolver(upstreamConfig) @@ -204,9 +208,6 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U upstreamNum := strings.TrimPrefix(upstream, "upstream.") upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum]) } - if len(upstreamConfigs) == 0 { - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - } return upstreamConfigs } diff --git a/cmd/ctrld/main.go b/cmd/ctrld/main.go index 4918020..d4362ba 100644 --- a/cmd/ctrld/main.go +++ b/cmd/ctrld/main.go @@ -31,6 +31,8 @@ var ( rootLogger = zerolog.New(io.Discard) mainLog = rootLogger proxyLog = rootLogger + + cdUID string ) func main() { diff --git a/docs/basic_mode.md b/docs/basic_mode.md index dda3801..5e2bcda 100644 --- a/docs/basic_mode.md +++ b/docs/basic_mode.md @@ -25,6 +25,8 @@ Usage: Flags: --base64_config string base64 encoded config + --cache_size int Enable cache with size items + --cd string Control D resolver uid -c, --config string Path to config file -d, --daemon Run as daemon --domains strings list of domain to apply in a split DNS policy diff --git a/docs/controld_config.md b/docs/controld_config.md new file mode 100644 index 0000000..20c9c5b --- /dev/null +++ b/docs/controld_config.md @@ -0,0 +1,15 @@ +# Control D config + +`ctrld` can build a Control D config and run with the specific resolver data. + +For example: + +```shell +ctrld run --cd p2 +``` + +Above command will fetch the `p2` resolver data from Control D API and use that data for running `ctrld`: + + - The resolver `doh` endpoint will be used as the primary upstream. + - The resolver `exclude` list will be used to create a rule policy which will steer them to the default OS resolver. +``` diff --git a/internal/controld/config.go b/internal/controld/config.go new file mode 100644 index 0000000..69b41a8 --- /dev/null +++ b/internal/controld/config.go @@ -0,0 +1,90 @@ +package controld + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" +) + +const resolverDataURL = "https://api.controld.com/utility" + +// ResolverConfig represents Control D resolver data. +type ResolverConfig struct { + V4 []string `json:"v4"` + V6 []string `json:"v6"` + DOH string `json:"doh"` + Exclude []string `json:"exclude"` +} + +func (r *ResolverConfig) IP(v6 bool) string { + ip4 := r.v4() + ip6 := r.v6() + if v6 && ip6 != "" { + return ip6 + } + return ip4 +} + +func (r *ResolverConfig) v4() string { + for _, ip := range r.V4 { + return ip + } + return "" +} + +func (r *ResolverConfig) v6() string { + for _, ip := range r.V6 { + return ip + } + return "" +} + +type utilityResponse struct { + Success bool `json:"success"` + Body struct { + Resolver ResolverConfig `json:"resolver"` + } `json:"body"` +} + +type utilityErrorResponse struct { + Error struct { + Message string `json:"message"` + } `json:"error"` +} + +type utilityRequest struct { + UID string `json:"uid"` +} + +// FetchResolverConfig fetch Control D config for given uid. +func FetchResolverConfig(uid string) (*ResolverConfig, error) { + body, _ := json.Marshal(utilityRequest{UID: uid}) + req, err := http.NewRequest("POST", resolverDataURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("http.NewRequest: %w", err) + } + req.Header.Add("Content-Type", "application/json") + client := http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("client.Do: %w", err) + } + defer resp.Body.Close() + d := json.NewDecoder(resp.Body) + if resp.StatusCode != http.StatusOK { + errResp := &utilityErrorResponse{} + if err := d.Decode(errResp); err != nil { + return nil, err + } + return nil, errors.New(errResp.Error.Message) + } + + ur := &utilityResponse{} + if err := d.Decode(ur); err != nil { + return nil, err + } + return &ur.Body.Resolver, nil +} diff --git a/internal/controld/config_test.go b/internal/controld/config_test.go new file mode 100644 index 0000000..3c09ed7 --- /dev/null +++ b/internal/controld/config_test.go @@ -0,0 +1,33 @@ +//go:build controld + +package controld + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const utilityURL = "https://api.controld.com/utility" + +func TestFetchResolverConfig(t *testing.T) { + tests := []struct { + name string + uid string + wantErr bool + }{ + {"valid", "p2", false}, + {"invalid uid", "abcd1234", true}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := FetchResolverConfig(tc.uid) + assert.False(t, (err != nil) != tc.wantErr) + if !tc.wantErr { + assert.NotEmpty(t, got.DOH) + } + }) + } +}