diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 31fdb77..28dfd82 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -171,9 +171,38 @@ func initCLI() { setDependencies(sc) sc.Arguments = append([]string{"run"}, osArgs...) if cdUID != "" { - if _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); err != nil { + rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev) + if err != nil { mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID) } + // validateCdRemoteConfig clobbers v, saving it here to restore later. + oldV := v + if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil { + if errors.As(err, &viper.ConfigParseError{}) { + if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 { + tmpDir := os.TempDir() + tmpConfFile := filepath.Join(tmpDir, "ctrld.toml") + errorLogged := false + // Write remote config to a temporary file to get details error. + if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil { + if de := decoderErrorFromTomlFile(tmpConfFile); de != nil { + row, col := de.Position() + mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error()) + errorLogged = true + } + _ = os.Remove(tmpConfFile) + } + // If we could not log details error, emit what we have already got. + if !errorLogged { + mainLog.Load().Error().Msgf("failed to parse custom config: %v", err) + } + } + } else { + mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err) + } + mainLog.Load().Warn().Msg("disregarding invalid custom config") + } + v = oldV } else if uid := cdUIDFromProvToken(); uid != "" { cdUID = uid removeProvTokenFromArgs(sc) @@ -912,7 +941,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { writeDefaultConfig := !noConfigStart && configBase64 == "" tryReadingConfig(writeDefaultConfig) - readBase64Config(configBase64) + if err := readBase64Config(configBase64); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config") + } processNoConfigFlags(noConfigStart) p.mu.Lock() if err := v.Unmarshal(&cfg); err != nil { @@ -1141,7 +1172,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { } // If error is viper.ConfigFileNotFoundError, write default config. - if _, ok := err.(viper.ConfigFileNotFoundError); ok { + if errors.As(err, &viper.ConfigFileNotFoundError{}) { if err := v.Unmarshal(&cfg); err != nil { mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err) } @@ -1162,13 +1193,11 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { return false } - if _, ok := err.(viper.ConfigParseError); ok { - if f, _ := os.Open(v.ConfigFileUsed()); f != nil { - var i any - if err, ok := toml.NewDecoder(f).Decode(&i).(*toml.DecodeError); ok { - row, col := err.Position() - mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err) - } + // If error is viper.ConfigParseError, emit details line and column number. + if errors.As(err, &viper.ConfigParseError{}) { + if de := decoderErrorFromTomlFile(v.ConfigFileUsed()); de != nil { + row, col := de.Position() + mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err) } } @@ -1177,13 +1206,27 @@ func readConfigFile(writeDefaultConfig, notice bool) bool { return false } -func readBase64Config(configBase64 string) { +// decoderErrorFromTomlFile parses the invalid toml file, returning the details decoder error. +func decoderErrorFromTomlFile(cf string) *toml.DecodeError { + if f, _ := os.Open(cf); f != nil { + defer f.Close() + var i any + var de *toml.DecodeError + if err := toml.NewDecoder(f).Decode(&i); err != nil && errors.As(err, &de) { + return de + } + } + return nil +} + +// readBase64Config reads ctrld config from the base64 input string. +func readBase64Config(configBase64 string) error { if configBase64 == "" { - return + return nil } configStr, err := base64.StdEncoding.DecodeString(configBase64) if err != nil { - mainLog.Load().Fatal().Msgf("invalid base64 config: %v", err) + return fmt.Errorf("invalid base64 config: %w", err) } // readBase64Config is called when: @@ -1194,9 +1237,7 @@ func readBase64Config(configBase64 string) { // So we need to re-create viper instance to discard old one. v = viper.NewWithOptions(viper.KeyDelimiter("::")) v.SetConfigType("toml") - if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil { - mainLog.Load().Fatal().Msgf("failed to read base64 config: %v", err) - } + return v.ReadConfig(bytes.NewReader(configStr)) } func processNoConfigFlags(noConfigStart bool) { @@ -1286,42 +1327,61 @@ func processCDFlags(cfg *ctrld.Config) error { // Fetch config, unmarshal to cfg. if resolverConfig.Ctrld.CustomConfig != "" { logger.Info().Msg("using defined custom config of Control-D resolver") - readBase64Config(resolverConfig.Ctrld.CustomConfig) - if err := v.Unmarshal(&cfg); err != nil { - mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err) + if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil { + setListenerDefaultValue(cfg) + return nil } - } else { - cfg.Network = make(map[string]*ctrld.NetworkConfig) - cfg.Network["0"] = &ctrld.NetworkConfig{ - Name: "Network 0", - Cidrs: []string{"0.0.0.0/0"}, - } - cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) - cfg.Upstream["0"] = &ctrld.UpstreamConfig{ - Endpoint: resolverConfig.DOH, - Type: cdUpstreamProto, - Timeout: 5000, - } - rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) - for _, domain := range resolverConfig.Exclude { - rules = append(rules, ctrld.Rule{domain: []string{}}) - } - cfg.Listener = make(map[string]*ctrld.ListenerConfig) - lc := &ctrld.ListenerConfig{ - Policy: &ctrld.ListenerPolicyConfig{ - Name: "My Policy", - Rules: rules, - }, - } - cfg.Listener["0"] = lc + mainLog.Load().Err(err).Msg("disregarding invalid custom config") } + + cfg.Network = make(map[string]*ctrld.NetworkConfig) + cfg.Network["0"] = &ctrld.NetworkConfig{ + Name: "Network 0", + Cidrs: []string{"0.0.0.0/0"}, + } + cfg.Upstream = make(map[string]*ctrld.UpstreamConfig) + cfg.Upstream["0"] = &ctrld.UpstreamConfig{ + Endpoint: resolverConfig.DOH, + Type: cdUpstreamProto, + Timeout: 5000, + } + rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude)) + for _, domain := range resolverConfig.Exclude { + rules = append(rules, ctrld.Rule{domain: []string{}}) + } + cfg.Listener = make(map[string]*ctrld.ListenerConfig) + lc := &ctrld.ListenerConfig{ + Policy: &ctrld.ListenerPolicyConfig{ + Name: "My Policy", + Rules: rules, + }, + } + cfg.Listener["0"] = lc + // Set default value. + setListenerDefaultValue(cfg) + + return nil +} + +// setListenerDefaultValue sets the default value for cfg.Listener if none existed. +func setListenerDefaultValue(cfg *ctrld.Config) { if len(cfg.Listener) == 0 { cfg.Listener = map[string]*ctrld.ListenerConfig{ "0": {IP: "", Port: 0}, } } - return nil +} + +// validateCdRemoteConfig validates the custom config from ControlD if defined. +func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error { + if rc.Ctrld.CustomConfig == "" { + return nil + } + if err := readBase64Config(rc.Ctrld.CustomConfig); err != nil { + return err + } + return v.Unmarshal(&cfg) } func processListenFlag() {