mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
cmd/cli: validate remote config during "ctrld start"
On BSD, the service is made un-killable since v1.3.4 by using daemon command "-r" option. However, when reading remote config, the ctrld will fatally exit if the config is malformed. This causes daemon respawn new ctrld process immediately, causing the "ctrld start" command hang forever because of restart loop. Since "ctrld start" already fetch the resolver config for validating uid, it should validate the remote config, too. This allows better error message printed to users, let them know that the config is invalid. Further, if the remote config was invalid, we should disregard it and generating the default working one in cd mode.
This commit is contained in:
committed by
Cuong Manh Le
parent
cc6ae290f8
commit
64bcd2f00d
148
cmd/cli/cli.go
148
cmd/cli/cli.go
@@ -171,9 +171,38 @@ func initCLI() {
|
||||
setDependencies(sc)
|
||||
sc.Arguments = append([]string{"run"}, osArgs...)
|
||||
if cdUID != "" {
|
||||
if _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); err != nil {
|
||||
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID)
|
||||
}
|
||||
// validateCdRemoteConfig clobbers v, saving it here to restore later.
|
||||
oldV := v
|
||||
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
|
||||
if errors.As(err, &viper.ConfigParseError{}) {
|
||||
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
|
||||
tmpDir := os.TempDir()
|
||||
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
|
||||
errorLogged := false
|
||||
// Write remote config to a temporary file to get details error.
|
||||
if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil {
|
||||
if de := decoderErrorFromTomlFile(tmpConfFile); de != nil {
|
||||
row, col := de.Position()
|
||||
mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error())
|
||||
errorLogged = true
|
||||
}
|
||||
_ = os.Remove(tmpConfFile)
|
||||
}
|
||||
// If we could not log details error, emit what we have already got.
|
||||
if !errorLogged {
|
||||
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
|
||||
}
|
||||
mainLog.Load().Warn().Msg("disregarding invalid custom config")
|
||||
}
|
||||
v = oldV
|
||||
} else if uid := cdUIDFromProvToken(); uid != "" {
|
||||
cdUID = uid
|
||||
removeProvTokenFromArgs(sc)
|
||||
@@ -912,7 +941,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
readBase64Config(configBase64)
|
||||
if err := readBase64Config(configBase64); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config")
|
||||
}
|
||||
processNoConfigFlags(noConfigStart)
|
||||
p.mu.Lock()
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
@@ -1141,7 +1172,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
|
||||
}
|
||||
|
||||
// If error is viper.ConfigFileNotFoundError, write default config.
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
if errors.As(err, &viper.ConfigFileNotFoundError{}) {
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
@@ -1162,13 +1193,11 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, ok := err.(viper.ConfigParseError); ok {
|
||||
if f, _ := os.Open(v.ConfigFileUsed()); f != nil {
|
||||
var i any
|
||||
if err, ok := toml.NewDecoder(f).Decode(&i).(*toml.DecodeError); ok {
|
||||
row, col := err.Position()
|
||||
mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err)
|
||||
}
|
||||
// If error is viper.ConfigParseError, emit details line and column number.
|
||||
if errors.As(err, &viper.ConfigParseError{}) {
|
||||
if de := decoderErrorFromTomlFile(v.ConfigFileUsed()); de != nil {
|
||||
row, col := de.Position()
|
||||
mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1177,13 +1206,27 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func readBase64Config(configBase64 string) {
|
||||
// decoderErrorFromTomlFile parses the invalid toml file, returning the details decoder error.
|
||||
func decoderErrorFromTomlFile(cf string) *toml.DecodeError {
|
||||
if f, _ := os.Open(cf); f != nil {
|
||||
defer f.Close()
|
||||
var i any
|
||||
var de *toml.DecodeError
|
||||
if err := toml.NewDecoder(f).Decode(&i); err != nil && errors.As(err, &de) {
|
||||
return de
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readBase64Config reads ctrld config from the base64 input string.
|
||||
func readBase64Config(configBase64 string) error {
|
||||
if configBase64 == "" {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
configStr, err := base64.StdEncoding.DecodeString(configBase64)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Msgf("invalid base64 config: %v", err)
|
||||
return fmt.Errorf("invalid base64 config: %w", err)
|
||||
}
|
||||
|
||||
// readBase64Config is called when:
|
||||
@@ -1194,9 +1237,7 @@ func readBase64Config(configBase64 string) {
|
||||
// So we need to re-create viper instance to discard old one.
|
||||
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
v.SetConfigType("toml")
|
||||
if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to read base64 config: %v", err)
|
||||
}
|
||||
return v.ReadConfig(bytes.NewReader(configStr))
|
||||
}
|
||||
|
||||
func processNoConfigFlags(noConfigStart bool) {
|
||||
@@ -1286,42 +1327,61 @@ func processCDFlags(cfg *ctrld.Config) error {
|
||||
// Fetch config, unmarshal to cfg.
|
||||
if resolverConfig.Ctrld.CustomConfig != "" {
|
||||
logger.Info().Msg("using defined custom config of Control-D resolver")
|
||||
readBase64Config(resolverConfig.Ctrld.CustomConfig)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil {
|
||||
setListenerDefaultValue(cfg)
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
cfg.Network = make(map[string]*ctrld.NetworkConfig)
|
||||
cfg.Network["0"] = &ctrld.NetworkConfig{
|
||||
Name: "Network 0",
|
||||
Cidrs: []string{"0.0.0.0/0"},
|
||||
}
|
||||
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
|
||||
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: cdUpstreamProto,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
for _, domain := range resolverConfig.Exclude {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
|
||||
lc := &ctrld.ListenerConfig{
|
||||
Policy: &ctrld.ListenerPolicyConfig{
|
||||
Name: "My Policy",
|
||||
Rules: rules,
|
||||
},
|
||||
}
|
||||
cfg.Listener["0"] = lc
|
||||
mainLog.Load().Err(err).Msg("disregarding invalid custom config")
|
||||
}
|
||||
|
||||
cfg.Network = make(map[string]*ctrld.NetworkConfig)
|
||||
cfg.Network["0"] = &ctrld.NetworkConfig{
|
||||
Name: "Network 0",
|
||||
Cidrs: []string{"0.0.0.0/0"},
|
||||
}
|
||||
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
|
||||
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: cdUpstreamProto,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
for _, domain := range resolverConfig.Exclude {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
|
||||
lc := &ctrld.ListenerConfig{
|
||||
Policy: &ctrld.ListenerPolicyConfig{
|
||||
Name: "My Policy",
|
||||
Rules: rules,
|
||||
},
|
||||
}
|
||||
cfg.Listener["0"] = lc
|
||||
|
||||
// Set default value.
|
||||
setListenerDefaultValue(cfg)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setListenerDefaultValue sets the default value for cfg.Listener if none existed.
|
||||
func setListenerDefaultValue(cfg *ctrld.Config) {
|
||||
if len(cfg.Listener) == 0 {
|
||||
cfg.Listener = map[string]*ctrld.ListenerConfig{
|
||||
"0": {IP: "", Port: 0},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCdRemoteConfig validates the custom config from ControlD if defined.
|
||||
func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error {
|
||||
if rc.Ctrld.CustomConfig == "" {
|
||||
return nil
|
||||
}
|
||||
if err := readBase64Config(rc.Ctrld.CustomConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return v.Unmarshal(&cfg)
|
||||
}
|
||||
|
||||
func processListenFlag() {
|
||||
|
||||
Reference in New Issue
Block a user