diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index d07c145..3f606b7 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1489,22 +1489,24 @@ func processNoConfigFlags(noConfigStart bool) { return endpoint, typ } pEndpoint, pType := endpointAndTyp(primaryUpstream) - upstream := map[string]*ctrld.UpstreamConfig{ - "0": { - Name: pEndpoint, - Endpoint: pEndpoint, - Type: pType, - Timeout: 5000, - }, + puc := &ctrld.UpstreamConfig{ + Name: pEndpoint, + Endpoint: pEndpoint, + Type: pType, + Timeout: 5000, } + puc.Init() + upstream := map[string]*ctrld.UpstreamConfig{"0": puc} if secondaryUpstream != "" { sEndpoint, sType := endpointAndTyp(secondaryUpstream) - upstream["1"] = &ctrld.UpstreamConfig{ + suc := &ctrld.UpstreamConfig{ Name: sEndpoint, Endpoint: sEndpoint, Type: sType, Timeout: 5000, } + suc.Init() + upstream["1"] = suc rules := make([]ctrld.Rule, 0, len(domains)) for _, domain := range domains { rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}}) diff --git a/config.go b/config.go index 6c66f62..3f9b2f8 100644 --- a/config.go +++ b/config.go @@ -65,6 +65,7 @@ const ( endpointPrefixHTTPS = "https://" endpointPrefixQUIC = "quic://" endpointPrefixH3 = "h3://" + endpointPrefixSdns = "sdns://" ) var ( @@ -233,7 +234,7 @@ type NetworkConfig struct { // UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to. type UpstreamConfig struct { Name string `mapstructure:"name" toml:"name,omitempty"` - Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy sdns"` + Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy sdns ''"` Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty"` BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"` Domain string `mapstructure:"-" toml:"-"` @@ -687,6 +688,9 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { // initDoHScheme initializes the endpoint scheme for DoH/DoH3 upstream if not present. func (uc *UpstreamConfig) initDoHScheme() { + if strings.HasPrefix(uc.Endpoint, endpointPrefixH3) && uc.Type == "" { + uc.Type = ResolverTypeDOH3 + } switch uc.Type { case ResolverTypeDOH: case ResolverTypeDOH3: @@ -703,6 +707,9 @@ func (uc *UpstreamConfig) initDoHScheme() { // initDnsStamps initializes upstream config based on encoded DNS Stamps Endpoint. func (uc *UpstreamConfig) initDnsStamps() error { + if strings.HasPrefix(uc.Endpoint, endpointPrefixSdns) && uc.Type == "" { + uc.Type = ResolverTypeSDNS + } if uc.Type != ResolverTypeSDNS { return nil } @@ -794,6 +801,12 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) { return } + // Empty type is ok only for endpoints starts with "h3://" and "sdns://". + if uc.Type == "" && !strings.HasPrefix(uc.Endpoint, endpointPrefixH3) && !strings.HasPrefix(uc.Endpoint, endpointPrefixSdns) { + sl.ReportError(uc.Endpoint, "type", "type", "oneof", "doh doh3 dot doq os legacy sdns") + return + } + // initDoHScheme/initDnsStamps may change upstreams information, // so restoring changed values after validation to keep original one. defer func(ep, typ string) { @@ -835,6 +848,7 @@ func defaultPortFor(typ string) string { // - If endpoint starts with "https://" -> ResolverTypeDOH // - If endpoint starts with "quic://" -> ResolverTypeDOQ // - If endpoint starts with "h3://" -> ResolverTypeDOH3 +// - If endpoint starts with "sdns://" -> ResolverTypeSDNS // - For anything else -> ResolverTypeDOT func ResolverTypeFromEndpoint(endpoint string) string { switch { @@ -844,6 +858,8 @@ func ResolverTypeFromEndpoint(endpoint string) string { return ResolverTypeDOQ case strings.HasPrefix(endpoint, endpointPrefixH3): return ResolverTypeDOH3 + case strings.HasPrefix(endpoint, endpointPrefixSdns): + return ResolverTypeSDNS } host := endpoint if strings.Contains(endpoint, ":") { diff --git a/config_internal_test.go b/config_internal_test.go index 41edd32..7b09da3 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -200,6 +200,26 @@ func TestUpstreamConfig_Init(t *testing.T) { u: u1, }, }, + { + "h3 without type", + &UpstreamConfig{ + Name: "doh3", + Endpoint: "h3://example.com", + BootstrapIP: "", + Domain: "", + Timeout: 0, + }, + &UpstreamConfig{ + Name: "doh3", + Type: "doh3", + Endpoint: "https://example.com", + BootstrapIP: "", + Domain: "example.com", + Timeout: 0, + IPStack: IpStackBoth, + u: u1, + }, + }, { "sdns -> doh", &UpstreamConfig{ @@ -285,6 +305,26 @@ func TestUpstreamConfig_Init(t *testing.T) { IPStack: IpStackBoth, }, }, + { + "sdns without type", + &UpstreamConfig{ + Name: "sdns", + Endpoint: "sdns://AAcAAAAAAAAACjc2Ljc2LjIuMTE", + BootstrapIP: "", + Domain: "", + Timeout: 0, + IPStack: IpStackBoth, + }, + &UpstreamConfig{ + Name: "sdns", + Type: "legacy", + Endpoint: "76.76.2.11:53", + BootstrapIP: "76.76.2.11", + Domain: "76.76.2.11", + Timeout: 0, + IPStack: IpStackBoth, + }, + }, } for _, tc := range tests { diff --git a/config_test.go b/config_test.go index c1ffeb4..a20b33c 100644 --- a/config_test.go +++ b/config_test.go @@ -107,6 +107,9 @@ func TestConfigValidation(t *testing.T) { {"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true}, {"invalid client id pref", configWithInvalidClientIDPref(t), true}, {"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false}, + {"doh endpoint without type", dohUpstreamEndpointWithoutType(t), true}, + {"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false}, + {"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false}, {"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true}, } @@ -194,6 +197,27 @@ func dohUpstreamEndpointWithoutScheme(t *testing.T) *ctrld.Config { return cfg } +func dohUpstreamEndpointWithoutType(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Upstream["0"].Endpoint = "https://freedns.controld.com/p1" + cfg.Upstream["0"].Type = "" + return cfg +} + +func doh3UpstreamEndpointWithoutType(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Upstream["0"].Endpoint = "h3://freedns.controld.com/p1" + cfg.Upstream["0"].Type = "" + return cfg +} + +func sdnsUpstreamEndpointWithoutType(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Upstream["0"].Endpoint = "sdns://AgMAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29tAy9wMQ" + cfg.Upstream["0"].Type = "" + return cfg +} + func invalidUpstreamTimeout(t *testing.T) *ctrld.Config { cfg := defaultConfig(t) cfg.Upstream["0"].Timeout = -1