cmd/cli: validate remote config during "ctrld start"

On BSD, the service is made un-killable since v1.3.4 by using daemon
command "-r" option. However, when reading remote config, the ctrld will
fatally exit if the config is malformed. This causes daemon respawn new
ctrld process immediately, causing the "ctrld start" command hang
forever because of restart loop.

Since "ctrld start" already fetch the resolver config for validating
uid, it should validate the remote config, too. This allows better error
message printed to users, let them know that the config is invalid.

Further, if the remote config was invalid, we should disregard it and
generating the default working one in cd mode.
This commit is contained in:
Cuong Manh Le
2024-03-05 16:27:11 +07:00
committed by Cuong Manh Le
parent cc6ae290f8
commit 64bcd2f00d

View File

@@ -171,9 +171,38 @@ func initCLI() {
setDependencies(sc)
sc.Arguments = append([]string{"run"}, osArgs...)
if cdUID != "" {
if _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); err != nil {
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
if err != nil {
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID)
}
// validateCdRemoteConfig clobbers v, saving it here to restore later.
oldV := v
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
if errors.As(err, &viper.ConfigParseError{}) {
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
tmpDir := os.TempDir()
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
errorLogged := false
// Write remote config to a temporary file to get details error.
if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil {
if de := decoderErrorFromTomlFile(tmpConfFile); de != nil {
row, col := de.Position()
mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error())
errorLogged = true
}
_ = os.Remove(tmpConfFile)
}
// If we could not log details error, emit what we have already got.
if !errorLogged {
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
}
}
} else {
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
}
mainLog.Load().Warn().Msg("disregarding invalid custom config")
}
v = oldV
} else if uid := cdUIDFromProvToken(); uid != "" {
cdUID = uid
removeProvTokenFromArgs(sc)
@@ -912,7 +941,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
writeDefaultConfig := !noConfigStart && configBase64 == ""
tryReadingConfig(writeDefaultConfig)
readBase64Config(configBase64)
if err := readBase64Config(configBase64); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config")
}
processNoConfigFlags(noConfigStart)
p.mu.Lock()
if err := v.Unmarshal(&cfg); err != nil {
@@ -1141,7 +1172,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
}
// If error is viper.ConfigFileNotFoundError, write default config.
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
if errors.As(err, &viper.ConfigFileNotFoundError{}) {
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
}
@@ -1162,13 +1193,11 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
return false
}
if _, ok := err.(viper.ConfigParseError); ok {
if f, _ := os.Open(v.ConfigFileUsed()); f != nil {
var i any
if err, ok := toml.NewDecoder(f).Decode(&i).(*toml.DecodeError); ok {
row, col := err.Position()
mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err)
}
// If error is viper.ConfigParseError, emit details line and column number.
if errors.As(err, &viper.ConfigParseError{}) {
if de := decoderErrorFromTomlFile(v.ConfigFileUsed()); de != nil {
row, col := de.Position()
mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err)
}
}
@@ -1177,13 +1206,27 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
return false
}
func readBase64Config(configBase64 string) {
// decoderErrorFromTomlFile parses the invalid toml file, returning the details decoder error.
func decoderErrorFromTomlFile(cf string) *toml.DecodeError {
if f, _ := os.Open(cf); f != nil {
defer f.Close()
var i any
var de *toml.DecodeError
if err := toml.NewDecoder(f).Decode(&i); err != nil && errors.As(err, &de) {
return de
}
}
return nil
}
// readBase64Config reads ctrld config from the base64 input string.
func readBase64Config(configBase64 string) error {
if configBase64 == "" {
return
return nil
}
configStr, err := base64.StdEncoding.DecodeString(configBase64)
if err != nil {
mainLog.Load().Fatal().Msgf("invalid base64 config: %v", err)
return fmt.Errorf("invalid base64 config: %w", err)
}
// readBase64Config is called when:
@@ -1194,9 +1237,7 @@ func readBase64Config(configBase64 string) {
// So we need to re-create viper instance to discard old one.
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
v.SetConfigType("toml")
if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil {
mainLog.Load().Fatal().Msgf("failed to read base64 config: %v", err)
}
return v.ReadConfig(bytes.NewReader(configStr))
}
func processNoConfigFlags(noConfigStart bool) {
@@ -1286,42 +1327,61 @@ func processCDFlags(cfg *ctrld.Config) error {
// Fetch config, unmarshal to cfg.
if resolverConfig.Ctrld.CustomConfig != "" {
logger.Info().Msg("using defined custom config of Control-D resolver")
readBase64Config(resolverConfig.Ctrld.CustomConfig)
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil {
setListenerDefaultValue(cfg)
return nil
}
} else {
cfg.Network = make(map[string]*ctrld.NetworkConfig)
cfg.Network["0"] = &ctrld.NetworkConfig{
Name: "Network 0",
Cidrs: []string{"0.0.0.0/0"},
}
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
Endpoint: resolverConfig.DOH,
Type: cdUpstreamProto,
Timeout: 5000,
}
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
lc := &ctrld.ListenerConfig{
Policy: &ctrld.ListenerPolicyConfig{
Name: "My Policy",
Rules: rules,
},
}
cfg.Listener["0"] = lc
mainLog.Load().Err(err).Msg("disregarding invalid custom config")
}
cfg.Network = make(map[string]*ctrld.NetworkConfig)
cfg.Network["0"] = &ctrld.NetworkConfig{
Name: "Network 0",
Cidrs: []string{"0.0.0.0/0"},
}
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
Endpoint: resolverConfig.DOH,
Type: cdUpstreamProto,
Timeout: 5000,
}
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
lc := &ctrld.ListenerConfig{
Policy: &ctrld.ListenerPolicyConfig{
Name: "My Policy",
Rules: rules,
},
}
cfg.Listener["0"] = lc
// Set default value.
setListenerDefaultValue(cfg)
return nil
}
// setListenerDefaultValue sets the default value for cfg.Listener if none existed.
func setListenerDefaultValue(cfg *ctrld.Config) {
if len(cfg.Listener) == 0 {
cfg.Listener = map[string]*ctrld.ListenerConfig{
"0": {IP: "", Port: 0},
}
}
return nil
}
// validateCdRemoteConfig validates the custom config from ControlD if defined.
func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error {
if rc.Ctrld.CustomConfig == "" {
return nil
}
if err := readBase64Config(rc.Ctrld.CustomConfig); err != nil {
return err
}
return v.Unmarshal(&cfg)
}
func processListenFlag() {