diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 938860a..4123e28 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -593,18 +593,24 @@ func processNoConfigFlags(noConfigStart bool) { } processListenFlag() + endpointAndTyp := func(endpoint string) (string, string) { + typ := ctrld.ResolverTypeFromEndpoint(endpoint) + return strings.TrimPrefix(endpoint, "quic://"), typ + } + pEndpoint, pType := endpointAndTyp(primaryUpstream) upstream := map[string]*ctrld.UpstreamConfig{ "0": { - Name: primaryUpstream, - Endpoint: primaryUpstream, - Type: ctrld.ResolverTypeDOH, + Name: pEndpoint, + Endpoint: pEndpoint, + Type: pType, }, } if secondaryUpstream != "" { + sEndpoint, sType := endpointAndTyp(secondaryUpstream) upstream["1"] = &ctrld.UpstreamConfig{ - Name: secondaryUpstream, - Endpoint: secondaryUpstream, - Type: ctrld.ResolverTypeLegacy, + Name: sEndpoint, + Endpoint: sEndpoint, + Type: sType, } rules := make([]ctrld.Rule, 0, len(domains)) for _, domain := range domains { diff --git a/config.go b/config.go index 2cb0e38..fb97901 100644 --- a/config.go +++ b/config.go @@ -380,3 +380,27 @@ func availableNameservers() []string { } return nss[:n] } + +// ResolverTypeFromEndpoint tries guessing the resolver type with a given endpoint +// using following rules: +// +// - If endpoint is an IP address -> ResolverTypeLegacy +// - If endpoint starts with "https://" -> ResolverTypeDOH +// - If endpoint starts with "quic://" -> ResolverTypeDOQ +// - For anything else -> ResolverTypeDOT +func ResolverTypeFromEndpoint(endpoint string) string { + switch { + case strings.HasPrefix(endpoint, "https://"): + return ResolverTypeDOH + case strings.HasPrefix(endpoint, "quic://"): + return ResolverTypeDOQ + } + host := endpoint + if strings.Contains(endpoint, ":") { + host, _, _ = net.SplitHostPort(host) + } + if ip := net.ParseIP(host); ip != nil { + return ResolverTypeLegacy + } + return ResolverTypeDOT +} diff --git a/resolver_test.go b/resolver_test.go index a5a93c4..531570b 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -27,3 +27,27 @@ func Test_osResolver_Resolve(t *testing.T) { case <-ctx.Done(): } } + +func Test_upstreamTypeFromEndpoint(t *testing.T) { + tests := []struct { + name string + endpoint string + resolverType string + }{ + {"doh", "https://freedns.controld.com/p2", ResolverTypeDOH}, + {"doq", "quic://p2.freedns.controld.com", ResolverTypeDOQ}, + {"dot", "p2.freedns.controld.com", ResolverTypeDOT}, + {"legacy", "8.8.8.8:53", ResolverTypeLegacy}, + {"legacy ipv6", "[2404:6800:4005:809::200e]:53", ResolverTypeLegacy}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if rt := ResolverTypeFromEndpoint(tc.endpoint); rt != tc.resolverType { + t.Errorf("mismatch, want: %s, got: %s", tc.resolverType, rt) + } + }) + } +}