cmd/ctrld: update config when "--cd" present

This commit is contained in:
Cuong Manh Le
2023-01-06 23:56:26 +07:00
committed by Cuong Manh Le
parent 6edd42629e
commit 9f90811567
3 changed files with 50 additions and 30 deletions

View File

@@ -3,6 +3,7 @@ package main
import (
"bytes"
"encoding/base64"
"fmt"
"log"
"net"
"os"
@@ -13,7 +14,6 @@ import (
"github.com/go-playground/validator/v10"
"github.com/kardianos/service"
"github.com/pelletier/go-toml"
"github.com/spf13/cobra"
"github.com/spf13/viper"
@@ -64,8 +64,8 @@ func initCLI() {
log.Fatal("Cannot run in daemon mode. Please install a Windows service.")
}
noConfigStart := isNoConfigStart(cmd) && cdUID != ""
writeDefaultConfig := !noConfigStart && configBase64 == "" && cdUID == ""
noConfigStart := isNoConfigStart(cmd)
writeDefaultConfig := !noConfigStart && configBase64 == ""
configs := []struct {
name string
written bool
@@ -84,10 +84,10 @@ func initCLI() {
readBase64Config()
processNoConfigFlags(noConfigStart)
processCDFlags()
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
}
processCDFlags(writeDefaultConfig)
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
log.Fatalf("invalid config: %v", err)
}
@@ -151,25 +151,35 @@ func initCLI() {
Short: "Start the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
cfg := &service.Config{}
*cfg = *svcConfig
sc := &service.Config{}
*sc = *svcConfig
osArgs := os.Args[2:]
if os.Args[1] == "service" {
osArgs = os.Args[3:]
}
cfg.Arguments = append([]string{"run"}, osArgs...)
sc.Arguments = append([]string{"run"}, osArgs...)
if dir, err := os.UserHomeDir(); err == nil {
// WorkingDirectory is not supported on Windows.
cfg.WorkingDirectory = dir
sc.WorkingDirectory = dir
// No config path, generating config in HOME directory.
noConfigStart := isNoConfigStart(cmd) && cdUID != ""
writeDefaultConfig := !noConfigStart && configBase64 == "" && cdUID == ""
noConfigStart := isNoConfigStart(cmd)
writeDefaultConfig := !noConfigStart && configBase64 == ""
if configPath == "" && writeDefaultConfig {
defaultConfigFile = filepath.Join(dir, defaultConfigFile)
readConfigFile(true)
}
// On Windows, the service will be run as SYSTEM, so if ctrld start as Admin,
// the written config won't be writable by SYSTEM account, we have to update
// the config here when "--cd" is supplied.
if runtime.GOOS == "windows" && cdUID != "" {
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
}
processCDFlags(writeDefaultConfig)
}
}
s, err := service.New(&prog{}, cfg)
s, err := service.New(&prog{}, sc)
if err != nil {
stderrMsg(err.Error())
return
@@ -311,12 +321,10 @@ func initCLI() {
}
func writeConfigFile() {
c := v.AllSettings()
bs, err := toml.Marshal(c)
if err != nil {
log.Fatalf("unable to marshal config to toml: %v", err)
if cfu := v.ConfigFileUsed(); cfu != "" {
defaultConfigFile = cfu
}
if err := os.WriteFile(defaultConfigFile, bs, 0600); err != nil {
if err := v.WriteConfigAs(defaultConfigFile); err != nil {
log.Printf("failed to write config file: %v\n", err)
}
}
@@ -325,6 +333,7 @@ func readConfigFile(writeDefaultConfig bool) bool {
// If err == nil, there's a config supplied via `--config`, no default config written.
err := v.ReadInConfig()
if err == nil {
fmt.Println("loading config file from: ", v.ConfigFileUsed())
return true
}
@@ -390,7 +399,7 @@ func processNoConfigFlags(noConfigStart bool) {
processLogAndCacheFlags()
}
func processCDFlags() {
func processCDFlags(writeConfig bool) {
if cdUID == "" {
return
}
@@ -399,25 +408,22 @@ func processCDFlags() {
log.Fatalf("failed to fetch resolver config: %v", err)
}
upstream := map[string]*ctrld.UpstreamConfig{
"0": {
Name: resolverConfig.DOH,
Endpoint: resolverConfig.DOH,
Type: ctrld.ResolverTypeDOH,
},
}
v.Set("upstream", upstream)
processListenFlag()
u0 := cfg.Upstream["0"]
u0.Name = resolverConfig.DOH
u0.Endpoint = resolverConfig.DOH
u0.Type = ctrld.ResolverTypeDOH
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
cfg.Listener["0"].Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
processLogAndCacheFlags()
if writeConfig {
v.Set("listener", cfg.Listener)
v.Set("upstream", cfg.Upstream)
writeConfigFile()
}
}
func processListenFlag() {

View File

@@ -314,6 +314,17 @@ Above policy will:
- Forward requests on `listener.0` for `test.com` to `upstream.2`. If timeout is reached, retry on `upstream.1`.
- All other requests on `listener.0` that do not match above conditions will be forwarded to `upstream.0`.
An empty upstream would not route the request to any defined upstreams, and use the OS default resolver.
```toml
[listener.0.policy]
name = "OS Resolver"
rules = [
{"*.local" = []},
]
```
#### name
`name` is the name for the policy.

View File

@@ -41,6 +41,9 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) {
if err != nil {
return nil, fmt.Errorf("http.NewRequest: %w", err)
}
q := req.URL.Query()
q.Set("platform", "ctrld")
req.URL.RawQuery = q.Encode()
req.Header.Add("Content-Type", "application/json")
client := http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)