From 9f908115679315e35de23daa2072ba81b2f3994c Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 6 Jan 2023 23:56:26 +0700 Subject: [PATCH] cmd/ctrld: update config when "--cd" present --- cmd/ctrld/cli.go | 66 ++++++++++++++++++++----------------- docs/config.md | 11 +++++++ internal/controld/config.go | 3 ++ 3 files changed, 50 insertions(+), 30 deletions(-) diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 293d30e..47002dc 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -3,6 +3,7 @@ package main import ( "bytes" "encoding/base64" + "fmt" "log" "net" "os" @@ -13,7 +14,6 @@ import ( "github.com/go-playground/validator/v10" "github.com/kardianos/service" - "github.com/pelletier/go-toml" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -64,8 +64,8 @@ func initCLI() { log.Fatal("Cannot run in daemon mode. Please install a Windows service.") } - noConfigStart := isNoConfigStart(cmd) && cdUID != "" - writeDefaultConfig := !noConfigStart && configBase64 == "" && cdUID == "" + noConfigStart := isNoConfigStart(cmd) + writeDefaultConfig := !noConfigStart && configBase64 == "" configs := []struct { name string written bool @@ -84,10 +84,10 @@ func initCLI() { readBase64Config() processNoConfigFlags(noConfigStart) - processCDFlags() if err := v.Unmarshal(&cfg); err != nil { log.Fatalf("failed to unmarshal config: %v", err) } + processCDFlags(writeDefaultConfig) if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil { log.Fatalf("invalid config: %v", err) } @@ -151,25 +151,35 @@ func initCLI() { Short: "Start the ctrld service", Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { - cfg := &service.Config{} - *cfg = *svcConfig + sc := &service.Config{} + *sc = *svcConfig osArgs := os.Args[2:] if os.Args[1] == "service" { osArgs = os.Args[3:] } - cfg.Arguments = append([]string{"run"}, osArgs...) + sc.Arguments = append([]string{"run"}, osArgs...) if dir, err := os.UserHomeDir(); err == nil { // WorkingDirectory is not supported on Windows. - cfg.WorkingDirectory = dir + sc.WorkingDirectory = dir // No config path, generating config in HOME directory. - noConfigStart := isNoConfigStart(cmd) && cdUID != "" - writeDefaultConfig := !noConfigStart && configBase64 == "" && cdUID == "" + noConfigStart := isNoConfigStart(cmd) + writeDefaultConfig := !noConfigStart && configBase64 == "" if configPath == "" && writeDefaultConfig { defaultConfigFile = filepath.Join(dir, defaultConfigFile) readConfigFile(true) } + + // On Windows, the service will be run as SYSTEM, so if ctrld start as Admin, + // the written config won't be writable by SYSTEM account, we have to update + // the config here when "--cd" is supplied. + if runtime.GOOS == "windows" && cdUID != "" { + if err := v.Unmarshal(&cfg); err != nil { + log.Fatalf("failed to unmarshal config: %v", err) + } + processCDFlags(writeDefaultConfig) + } } - s, err := service.New(&prog{}, cfg) + s, err := service.New(&prog{}, sc) if err != nil { stderrMsg(err.Error()) return @@ -311,12 +321,10 @@ func initCLI() { } func writeConfigFile() { - c := v.AllSettings() - bs, err := toml.Marshal(c) - if err != nil { - log.Fatalf("unable to marshal config to toml: %v", err) + if cfu := v.ConfigFileUsed(); cfu != "" { + defaultConfigFile = cfu } - if err := os.WriteFile(defaultConfigFile, bs, 0600); err != nil { + if err := v.WriteConfigAs(defaultConfigFile); err != nil { log.Printf("failed to write config file: %v\n", err) } } @@ -325,6 +333,7 @@ func readConfigFile(writeDefaultConfig bool) bool { // If err == nil, there's a config supplied via `--config`, no default config written. err := v.ReadInConfig() if err == nil { + fmt.Println("loading config file from: ", v.ConfigFileUsed()) return true } @@ -390,7 +399,7 @@ func processNoConfigFlags(noConfigStart bool) { processLogAndCacheFlags() } -func processCDFlags() { +func processCDFlags(writeConfig bool) { if cdUID == "" { return } @@ -399,25 +408,22 @@ func processCDFlags() { log.Fatalf("failed to fetch resolver config: %v", err) } - upstream := map[string]*ctrld.UpstreamConfig{ - "0": { - Name: resolverConfig.DOH, - Endpoint: resolverConfig.DOH, - Type: ctrld.ResolverTypeDOH, - }, - } - v.Set("upstream", upstream) - - processListenFlag() + u0 := cfg.Upstream["0"] + u0.Name = resolverConfig.DOH + u0.Endpoint = resolverConfig.DOH + u0.Type = ctrld.ResolverTypeDOH rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) for _, domain := range resolverConfig.Exclude { rules = append(rules, ctrld.Rule{domain: []string{}}) } - lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"] - lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules} + cfg.Listener["0"].Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules} - processLogAndCacheFlags() + if writeConfig { + v.Set("listener", cfg.Listener) + v.Set("upstream", cfg.Upstream) + writeConfigFile() + } } func processListenFlag() { diff --git a/docs/config.md b/docs/config.md index e8b30e8..4f12736 100644 --- a/docs/config.md +++ b/docs/config.md @@ -314,6 +314,17 @@ Above policy will: - Forward requests on `listener.0` for `test.com` to `upstream.2`. If timeout is reached, retry on `upstream.1`. - All other requests on `listener.0` that do not match above conditions will be forwarded to `upstream.0`. +An empty upstream would not route the request to any defined upstreams, and use the OS default resolver. + +```toml +[listener.0.policy] +name = "OS Resolver" + +rules = [ + {"*.local" = []}, +] +``` + #### name `name` is the name for the policy. diff --git a/internal/controld/config.go b/internal/controld/config.go index c4ed22a..ca6303b 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -41,6 +41,9 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) { if err != nil { return nil, fmt.Errorf("http.NewRequest: %w", err) } + q := req.URL.Query() + q.Set("platform", "ctrld") + req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") client := http.Client{Timeout: 5 * time.Second} resp, err := client.Do(req)