diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 47002dc..af42c25 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -14,6 +14,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/kardianos/service" + "github.com/pelletier/go-toml/v2" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -87,7 +88,7 @@ func initCLI() { if err := v.Unmarshal(&cfg); err != nil { log.Fatalf("failed to unmarshal config: %v", err) } - processCDFlags(writeDefaultConfig) + processCDFlags() if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil { log.Fatalf("invalid config: %v", err) } @@ -173,10 +174,7 @@ func initCLI() { // the written config won't be writable by SYSTEM account, we have to update // the config here when "--cd" is supplied. if runtime.GOOS == "windows" && cdUID != "" { - if err := v.Unmarshal(&cfg); err != nil { - log.Fatalf("failed to unmarshal config: %v", err) - } - processCDFlags(writeDefaultConfig) + processCDFlags() } } s, err := service.New(&prog{}, sc) @@ -324,7 +322,24 @@ func writeConfigFile() { if cfu := v.ConfigFileUsed(); cfu != "" { defaultConfigFile = cfu } - if err := v.WriteConfigAs(defaultConfigFile); err != nil { + f, err := os.OpenFile(defaultConfigFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644)) + if err != nil { + log.Printf("failed to open config file: %v\n", err) + os.Exit(1) + } + defer f.Close() + if cdUID != "" { + if _, err := f.WriteString("# AUTO-GENERATED VIA CD FLAG - DO NOT MODIFY\n\n"); err != nil { + log.Printf("failed to write header to config file: %v\n", err) + os.Exit(1) + } + } + enc := toml.NewEncoder(f).SetIndentTables(true) + if err := enc.Encode(v.AllSettings()); err != nil { + log.Printf("failed to encode config file: %v\n", err) + os.Exit(1) + } + if err := f.Close(); err != nil { log.Printf("failed to write config file: %v\n", err) } } @@ -399,7 +414,7 @@ func processNoConfigFlags(noConfigStart bool) { processLogAndCacheFlags() } -func processCDFlags(writeConfig bool) { +func processCDFlags() { if cdUID == "" { return } @@ -408,22 +423,36 @@ func processCDFlags(writeConfig bool) { log.Fatalf("failed to fetch resolver config: %v", err) } - u0 := cfg.Upstream["0"] - u0.Name = resolverConfig.DOH - u0.Endpoint = resolverConfig.DOH - u0.Type = ctrld.ResolverTypeDOH - + cfg = ctrld.Config{} + cfg.Network = make(map[string]*ctrld.NetworkConfig) + cfg.Network["0"] = &ctrld.NetworkConfig{ + Name: "Netowrk 0", + Cidrs: []string{"0.0.0.0/0"}, + } + cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) + cfg.Upstream["0"] = &ctrld.UpstreamConfig{ + Endpoint: resolverConfig.DOH, + Type: ctrld.ResolverTypeDOH, + Timeout: 5000, + } rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) for _, domain := range resolverConfig.Exclude { rules = append(rules, ctrld.Rule{domain: []string{}}) } - cfg.Listener["0"].Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules} - - if writeConfig { - v.Set("listener", cfg.Listener) - v.Set("upstream", cfg.Upstream) - writeConfigFile() + cfg.Listener = make(map[string]*ctrld.ListenerConfig) + cfg.Listener["0"] = &ctrld.ListenerConfig{ + IP: "127.0.0.1", + Port: 53, + Policy: &ctrld.ListenerPolicyConfig{ + Name: "My Policy", + Rules: rules, + }, } + + v.Set("network", cfg.Network) + v.Set("upstream", cfg.Upstream) + v.Set("listener", cfg.Listener) + writeConfigFile() } func processListenFlag() { diff --git a/cmd/ctrld/net.go b/cmd/ctrld/net.go index ce98405..ef2d47d 100644 --- a/cmd/ctrld/net.go +++ b/cmd/ctrld/net.go @@ -3,18 +3,17 @@ package main import ( "net" "sync" - - "golang.org/x/net/nettest" ) +const controldIPv6Test = "ipv6.controld.io" + var ( stackOnce sync.Once ipv6Enabled bool ) func probeStack() { - // TODO(cuonglm): use nettest.SupportsIPv6 once https://github.com/golang/go/issues/57386 fixed. - if _, err := nettest.RoutedInterface("ip6", net.FlagUp); err == nil { + if _, err := net.Dial("tcp6", controldIPv6Test); err == nil { ipv6Enabled = true } } diff --git a/config.go b/config.go index a5514ea..8916cb4 100644 --- a/config.go +++ b/config.go @@ -68,7 +68,7 @@ func InitConfig(v *viper.Viper, name string) { // Config represents ctrld supported configuration. type Config struct { - Service ServiceConfig `mapstructure:"service"` + Service ServiceConfig `mapstructure:"service" toml:"service,omitempty"` Network map[string]*NetworkConfig `mapstructure:"network" toml:"network" validate:"min=1,dive"` Upstream map[string]*UpstreamConfig `mapstructure:"upstream" toml:"upstream" validate:"min=1,dive"` Listener map[string]*ListenerConfig `mapstructure:"listener" toml:"listener" validate:"min=1,dive"` @@ -76,49 +76,49 @@ type Config struct { // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level"` - LogPath string `mapstructure:"log_path" toml:"log_path"` - CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable"` - CacheSize int `mapstructure:"cache_size" toml:"cache_size"` - CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override"` - CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale"` + LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` + LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` Daemon bool `mapstructure:"-" toml:"-"` AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. type NetworkConfig struct { - Name string `mapstructure:"name" toml:"name"` - Cidrs []string `mapstructure:"cidrs" toml:"cidrs" validate:"dive,cidr"` + Name string `mapstructure:"name" toml:"name,omitempty"` + Cidrs []string `mapstructure:"cidrs" toml:"cidrs,omitempty" validate:"dive,cidr"` IPNets []*net.IPNet `mapstructure:"-" toml:"-"` } // UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to. type UpstreamConfig struct { - Name string `mapstructure:"name" toml:"name"` - Type string `mapstructure:"type" toml:"type" validate:"oneof=doh doh3 dot doq os legacy"` - Endpoint string `mapstructure:"endpoint" toml:"endpoint" validate:"required_unless=Type os"` - BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip"` + Name string `mapstructure:"name" toml:"name,omitempty"` + Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"` + Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"` + BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"` Domain string `mapstructure:"-" toml:"-"` - Timeout int `mapstructure:"timeout" toml:"timeout" validate:"gte=0"` + Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"` transport *http.Transport `mapstructure:"-" toml:"-"` http3RoundTripper *http3.RoundTripper `mapstructure:"-" toml:"-"` } // ListenerConfig specifies the networks configuration that ctrld will run on. type ListenerConfig struct { - IP string `mapstructure:"ip" toml:"ip" validate:"ip"` - Port int `mapstructure:"port" toml:"port" validate:"gt=0"` - Restricted bool `mapstructure:"restricted" toml:"restricted"` - Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy"` + IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"ip"` + Port int `mapstructure:"port" toml:"port,omitempty" validate:"gt=0"` + Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"` + Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"` } // ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests. type ListenerPolicyConfig struct { - Name string `mapstructure:"name" toml:"name"` - Networks []Rule `mapstructure:"networks" toml:"networks" validate:"dive,len=1"` - Rules []Rule `mapstructure:"rules" toml:"rules" validate:"dive,len=1"` - FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes" validate:"dive,dnsrcode"` + Name string `mapstructure:"name" toml:"name,omitempty"` + Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"` + Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"` + FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"` FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"` }