Allow DoH/DoH3 endpoint without scheme

This commit is contained in:
Cuong Manh Le
2024-03-05 20:48:13 +07:00
committed by Cuong Manh Le
parent 203a2ec8b8
commit a9672dfff5
2 changed files with 22 additions and 5 deletions

View File

@@ -285,6 +285,7 @@ type Rule map[string][]string
// Init initialized necessary values for an UpstreamConfig.
func (uc *UpstreamConfig) Init() {
uc.initDoHScheme()
uc.uid = upstreamUID()
if u, err := url.Parse(uc.Endpoint); err == nil {
uc.Domain = u.Host
@@ -631,6 +632,18 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
return "tcp-tls", "udp"
}
// initDoHScheme initializes the endpoint scheme for DoH/DoH3 upstream if not present.
func (uc *UpstreamConfig) initDoHScheme() {
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
default:
return
}
if !strings.HasPrefix(uc.Endpoint, "https://") {
uc.Endpoint = "https://" + uc.Endpoint
}
}
// Init initialized necessary values for an ListenerConfig.
func (lc *ListenerConfig) Init() {
if lc.Policy != nil {
@@ -683,6 +696,7 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
return
}
uc.initDoHScheme()
// DoH/DoH3 requires endpoint is an HTTP url.
if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 {
u, err := url.Parse(uc.Endpoint)
@@ -690,10 +704,6 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
return
}
if u.Scheme != "http" && u.Scheme != "https" {
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
return
}
}
}

View File

@@ -102,6 +102,7 @@ func TestConfigValidation(t *testing.T) {
{"invalid lease file format", configWithInvalidLeaseFileFormat(t), true},
{"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true},
{"invalid client id pref", configWithInvalidClientIDPref(t), true},
{"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false},
}
for _, tc := range tests {
@@ -167,6 +168,12 @@ func invalidUpstreamType(t *testing.T) *ctrld.Config {
return cfg
}
func dohUpstreamEndpointWithoutScheme(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Upstream["0"].Endpoint = "freedns.controld.com/p1"
return cfg
}
func invalidUpstreamTimeout(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Upstream["0"].Timeout = -1
@@ -258,7 +265,7 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Upstream["0"].Endpoint = "1.1.1.1"
cfg.Upstream["0"].Endpoint = "/1.1.1.1"
cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH
return cfg
}