cmd/ctrld: make "--cd" always owerwrites the config

While at it, also make the toml encoded config format nicer.
This commit is contained in:
Cuong Manh Le
2023-01-10 10:17:49 +07:00
committed by Cuong Manh Le
parent 9f90811567
commit 279e938b2a
3 changed files with 72 additions and 44 deletions

View File

@@ -14,6 +14,7 @@ import (
"github.com/go-playground/validator/v10"
"github.com/kardianos/service"
"github.com/pelletier/go-toml/v2"
"github.com/spf13/cobra"
"github.com/spf13/viper"
@@ -87,7 +88,7 @@ func initCLI() {
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
}
processCDFlags(writeDefaultConfig)
processCDFlags()
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
log.Fatalf("invalid config: %v", err)
}
@@ -173,10 +174,7 @@ func initCLI() {
// the written config won't be writable by SYSTEM account, we have to update
// the config here when "--cd" is supplied.
if runtime.GOOS == "windows" && cdUID != "" {
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
}
processCDFlags(writeDefaultConfig)
processCDFlags()
}
}
s, err := service.New(&prog{}, sc)
@@ -324,7 +322,24 @@ func writeConfigFile() {
if cfu := v.ConfigFileUsed(); cfu != "" {
defaultConfigFile = cfu
}
if err := v.WriteConfigAs(defaultConfigFile); err != nil {
f, err := os.OpenFile(defaultConfigFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644))
if err != nil {
log.Printf("failed to open config file: %v\n", err)
os.Exit(1)
}
defer f.Close()
if cdUID != "" {
if _, err := f.WriteString("# AUTO-GENERATED VIA CD FLAG - DO NOT MODIFY\n\n"); err != nil {
log.Printf("failed to write header to config file: %v\n", err)
os.Exit(1)
}
}
enc := toml.NewEncoder(f).SetIndentTables(true)
if err := enc.Encode(v.AllSettings()); err != nil {
log.Printf("failed to encode config file: %v\n", err)
os.Exit(1)
}
if err := f.Close(); err != nil {
log.Printf("failed to write config file: %v\n", err)
}
}
@@ -399,7 +414,7 @@ func processNoConfigFlags(noConfigStart bool) {
processLogAndCacheFlags()
}
func processCDFlags(writeConfig bool) {
func processCDFlags() {
if cdUID == "" {
return
}
@@ -408,22 +423,36 @@ func processCDFlags(writeConfig bool) {
log.Fatalf("failed to fetch resolver config: %v", err)
}
u0 := cfg.Upstream["0"]
u0.Name = resolverConfig.DOH
u0.Endpoint = resolverConfig.DOH
u0.Type = ctrld.ResolverTypeDOH
cfg = ctrld.Config{}
cfg.Network = make(map[string]*ctrld.NetworkConfig)
cfg.Network["0"] = &ctrld.NetworkConfig{
Name: "Netowrk 0",
Cidrs: []string{"0.0.0.0/0"},
}
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
Endpoint: resolverConfig.DOH,
Type: ctrld.ResolverTypeDOH,
Timeout: 5000,
}
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
cfg.Listener["0"].Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
if writeConfig {
v.Set("listener", cfg.Listener)
v.Set("upstream", cfg.Upstream)
writeConfigFile()
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
cfg.Listener["0"] = &ctrld.ListenerConfig{
IP: "127.0.0.1",
Port: 53,
Policy: &ctrld.ListenerPolicyConfig{
Name: "My Policy",
Rules: rules,
},
}
v.Set("network", cfg.Network)
v.Set("upstream", cfg.Upstream)
v.Set("listener", cfg.Listener)
writeConfigFile()
}
func processListenFlag() {

View File

@@ -3,18 +3,17 @@ package main
import (
"net"
"sync"
"golang.org/x/net/nettest"
)
const controldIPv6Test = "ipv6.controld.io"
var (
stackOnce sync.Once
ipv6Enabled bool
)
func probeStack() {
// TODO(cuonglm): use nettest.SupportsIPv6 once https://github.com/golang/go/issues/57386 fixed.
if _, err := nettest.RoutedInterface("ip6", net.FlagUp); err == nil {
if _, err := net.Dial("tcp6", controldIPv6Test); err == nil {
ipv6Enabled = true
}
}

View File

@@ -68,7 +68,7 @@ func InitConfig(v *viper.Viper, name string) {
// Config represents ctrld supported configuration.
type Config struct {
Service ServiceConfig `mapstructure:"service"`
Service ServiceConfig `mapstructure:"service" toml:"service,omitempty"`
Network map[string]*NetworkConfig `mapstructure:"network" toml:"network" validate:"min=1,dive"`
Upstream map[string]*UpstreamConfig `mapstructure:"upstream" toml:"upstream" validate:"min=1,dive"`
Listener map[string]*ListenerConfig `mapstructure:"listener" toml:"listener" validate:"min=1,dive"`
@@ -76,49 +76,49 @@ type Config struct {
// ServiceConfig specifies the general ctrld config.
type ServiceConfig struct {
LogLevel string `mapstructure:"log_level" toml:"log_level"`
LogPath string `mapstructure:"log_path" toml:"log_path"`
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable"`
CacheSize int `mapstructure:"cache_size" toml:"cache_size"`
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override"`
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale"`
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
Daemon bool `mapstructure:"-" toml:"-"`
AllocateIP bool `mapstructure:"-" toml:"-"`
}
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
type NetworkConfig struct {
Name string `mapstructure:"name" toml:"name"`
Cidrs []string `mapstructure:"cidrs" toml:"cidrs" validate:"dive,cidr"`
Name string `mapstructure:"name" toml:"name,omitempty"`
Cidrs []string `mapstructure:"cidrs" toml:"cidrs,omitempty" validate:"dive,cidr"`
IPNets []*net.IPNet `mapstructure:"-" toml:"-"`
}
// UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to.
type UpstreamConfig struct {
Name string `mapstructure:"name" toml:"name"`
Type string `mapstructure:"type" toml:"type" validate:"oneof=doh doh3 dot doq os legacy"`
Endpoint string `mapstructure:"endpoint" toml:"endpoint" validate:"required_unless=Type os"`
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip"`
Name string `mapstructure:"name" toml:"name,omitempty"`
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
Domain string `mapstructure:"-" toml:"-"`
Timeout int `mapstructure:"timeout" toml:"timeout" validate:"gte=0"`
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
transport *http.Transport `mapstructure:"-" toml:"-"`
http3RoundTripper *http3.RoundTripper `mapstructure:"-" toml:"-"`
}
// ListenerConfig specifies the networks configuration that ctrld will run on.
type ListenerConfig struct {
IP string `mapstructure:"ip" toml:"ip" validate:"ip"`
Port int `mapstructure:"port" toml:"port" validate:"gt=0"`
Restricted bool `mapstructure:"restricted" toml:"restricted"`
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy"`
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"ip"`
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gt=0"`
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
}
// ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests.
type ListenerPolicyConfig struct {
Name string `mapstructure:"name" toml:"name"`
Networks []Rule `mapstructure:"networks" toml:"networks" validate:"dive,len=1"`
Rules []Rule `mapstructure:"rules" toml:"rules" validate:"dive,len=1"`
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes" validate:"dive,dnsrcode"`
Name string `mapstructure:"name" toml:"name,omitempty"`
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
}