diff --git a/config.go b/config.go index cb38096..56bb68d 100644 --- a/config.go +++ b/config.go @@ -285,6 +285,7 @@ type Rule map[string][]string // Init initialized necessary values for an UpstreamConfig. func (uc *UpstreamConfig) Init() { + uc.initDoHScheme() uc.uid = upstreamUID() if u, err := url.Parse(uc.Endpoint); err == nil { uc.Domain = u.Host @@ -631,6 +632,18 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { return "tcp-tls", "udp" } +// initDoHScheme initializes the endpoint scheme for DoH/DoH3 upstream if not present. +func (uc *UpstreamConfig) initDoHScheme() { + switch uc.Type { + case ResolverTypeDOH, ResolverTypeDOH3: + default: + return + } + if !strings.HasPrefix(uc.Endpoint, "https://") { + uc.Endpoint = "https://" + uc.Endpoint + } +} + // Init initialized necessary values for an ListenerConfig. func (lc *ListenerConfig) Init() { if lc.Policy != nil { @@ -683,6 +696,7 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) { return } + uc.initDoHScheme() // DoH/DoH3 requires endpoint is an HTTP url. if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 { u, err := url.Parse(uc.Endpoint) @@ -690,10 +704,6 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) { sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "") return } - if u.Scheme != "http" && u.Scheme != "https" { - sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "") - return - } } } diff --git a/config_test.go b/config_test.go index d66556f..55a19f3 100644 --- a/config_test.go +++ b/config_test.go @@ -102,6 +102,7 @@ func TestConfigValidation(t *testing.T) { {"invalid lease file format", configWithInvalidLeaseFileFormat(t), true}, {"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true}, {"invalid client id pref", configWithInvalidClientIDPref(t), true}, + {"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false}, } for _, tc := range tests { @@ -167,6 +168,12 @@ func invalidUpstreamType(t *testing.T) *ctrld.Config { return cfg } +func dohUpstreamEndpointWithoutScheme(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Upstream["0"].Endpoint = "freedns.controld.com/p1" + return cfg +} + func invalidUpstreamTimeout(t *testing.T) *ctrld.Config { cfg := defaultConfig(t) cfg.Upstream["0"].Timeout = -1 @@ -258,7 +265,7 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config { func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config { cfg := defaultConfig(t) - cfg.Upstream["0"].Endpoint = "1.1.1.1" + cfg.Upstream["0"].Endpoint = "/1.1.1.1" cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH return cfg }