From 7ce62ccaecc646651f6c77ce4eb31d557bf64b12 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 9 Aug 2023 10:01:00 +0700 Subject: [PATCH 1/2] Validate DoH/DoH3 endpoint properly When resolver type is doh/doh3, the endpoint must be a valid http url. Updates #149 --- cmd/cli/cli.go | 2 ++ config.go | 29 ++++++++++++++++++++++++++++- config_test.go | 8 ++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 894efc0..ca48258 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1383,6 +1383,8 @@ func fieldErrorMsg(fe validator.FieldError) string { return fmt.Sprintf("invalid IP format: %s", fe.Value()) case "file": return fmt.Sprintf("filed does not exist: %s", fe.Value()) + case "http_url": + return fmt.Sprintf("invalid http/https url: %s", fe.Value()) } return "" } diff --git a/config.go b/config.go index 70b165b..eef5af0 100644 --- a/config.go +++ b/config.go @@ -193,7 +193,7 @@ type NetworkConfig struct { type UpstreamConfig struct { Name string `mapstructure:"name" toml:"name,omitempty"` Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"` - Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"` + Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty"` BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"` Domain string `mapstructure:"-" toml:"-"` IPStack string `mapstructure:"ip_stack" toml:"ip_stack,omitempty" validate:"ipstack"` @@ -589,6 +589,7 @@ func ValidateConfig(validate *validator.Validate, cfg *Config) error { _ = validate.RegisterValidation("dnsrcode", validateDnsRcode) _ = validate.RegisterValidation("ipstack", validateIpStack) _ = validate.RegisterValidation("iporempty", validateIpOrEmpty) + validate.RegisterStructValidation(upstreamConfigStructLevelValidation, UpstreamConfig{}) return validate.Struct(cfg) } @@ -613,6 +614,32 @@ func validateIpOrEmpty(fl validator.FieldLevel) bool { return net.ParseIP(val) != nil } +func upstreamConfigStructLevelValidation(sl validator.StructLevel) { + uc := sl.Current().Addr().Interface().(*UpstreamConfig) + if uc.Type == ResolverTypeOS { + return + } + + // Endpoint is required for non os resolver. + if uc.Endpoint == "" { + sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "required_unless", "") + return + } + + // DoH/DoH3 requires endpoint is an HTTP url. + if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 { + u, err := url.Parse(uc.Endpoint) + if err != nil || u.Host == "" { + sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "") + return + } + if u.Scheme != "http" && u.Scheme != "https" { + sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "") + return + } + } +} + func defaultPortFor(typ string) string { switch typ { case ResolverTypeDOH, ResolverTypeDOH3: diff --git a/config_test.go b/config_test.go index 83a3386..ca57372 100644 --- a/config_test.go +++ b/config_test.go @@ -95,6 +95,7 @@ func TestConfigValidation(t *testing.T) { {"non-existed lease file", configWithNonExistedLeaseFile(t), true}, {"lease file format required if lease file exist", configWithExistedLeaseFile(t), true}, {"invalid lease file format", configWithInvalidLeaseFileFormat(t), true}, + {"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true}, } for _, tc := range tests { @@ -225,3 +226,10 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config { cfg.Service.DHCPLeaseFileFormat = "invalid" return cfg } + +func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Upstream["0"].Endpoint = "1.1.1.1" + cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH + return cfg +} From f39512b4c04b160ef124cd554a474856acbdbe69 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 9 Aug 2023 10:11:23 +0700 Subject: [PATCH 2/2] cmd/ctrld: only write to config file if listener config changed Updates #149 --- cmd/cli/cli.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index ca48258..8f54d1f 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -208,16 +208,18 @@ func initCLI() { processCDFlags() } - updateListenerConfig() + updated := updateListenerConfig() if cdUID != "" { processLogAndCacheFlags() } - if err := writeConfigFile(); err != nil { - mainLog.Load().Fatal().Err(err).Msg("failed to write config file") - } else { - mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) + if updated { + if err := writeConfigFile(); err != nil { + mainLog.Load().Fatal().Err(err).Msg("failed to write config file") + } else { + mainLog.Load().Info().Msg("writing config file to: " + defaultConfigFile) + } } if newLogPath := cfg.Service.LogPath; newLogPath != "" && oldLogPath != newLogPath { @@ -1412,8 +1414,8 @@ type listenerConfigCheck struct { // updateListenerConfig updates the config for listeners if not defined, // or defined but invalid to be used, e.g: using loopback address other -// than 127.0.0.1 with sytemd-resolved. -func updateListenerConfig() { +// than 127.0.0.1 with systemd-resolved. +func updateListenerConfig() (updated bool) { lcc := make(map[string]*listenerConfigCheck) cdMode := cdUID != "" for n, listener := range cfg.Listener { @@ -1431,6 +1433,7 @@ func updateListenerConfig() { lcc[n].IP = true lcc[n].Port = true } + updated = updated || lcc[n].IP || lcc[n].Port } var closers []io.Closer @@ -1603,6 +1606,7 @@ func updateListenerConfig() { } } } + return } func dirWritable(dir string) (bool, error) {