diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index f431500..0dcadec 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -7,10 +7,13 @@ import ( "os/exec" "runtime" + "github.com/go-playground/validator/v10" "github.com/kardianos/service" "github.com/pelletier/go-toml" "github.com/spf13/cobra" "github.com/spf13/viper" + + "github.com/Control-D-Inc/ctrld" ) var ( @@ -52,6 +55,9 @@ func initCLI() { if err := v.Unmarshal(&cfg); err != nil { log.Fatalf("failed to unmarshal config: %v", err) } + if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil { + log.Fatalf("invalid config: %v", err) + } initLogging() if daemon { exe, err := os.Executable() diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index 06378bb..9d14c66 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -22,7 +22,10 @@ func (p *prog) serveUDP(listenerNum string) error { mainLog.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") return allocErr } - + var failoverRcodes []int + if listenerConfig.Policy != nil { + failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers + } handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { domain := canonicalName(m.Question[0].Name) reqId := requestID() @@ -37,7 +40,7 @@ func (p *prog) serveUDP(listenerNum string) error { answer.SetRcode(m, dns.RcodeRefused) } else { - answer = p.proxy(ctx, upstreams, m) + answer = p.proxy(ctx, upstreams, failoverRcodes, m) rtt := time.Since(t) ctrld.Log(ctx, proxyLog.Debug(), "received response of %d bytes in %s", answer.Len(), rtt) } @@ -119,7 +122,7 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c return upstreams, matched } -func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns.Msg { +func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg { upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg { ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name) @@ -128,12 +131,14 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns ctrld.Log(ctx, proxyLog.Error().Err(err), "failed to create resolver") return nil } + resolveCtx, cancel := context.WithCancel(ctx) + defer cancel() if upstreamConfig.Timeout > 0 { - timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*time.Duration(upstreamConfig.Timeout)) + timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(upstreamConfig.Timeout)) defer cancel() - ctx = timeoutCtx + resolveCtx = timeoutCtx } - answer, err := dnsResolver.Resolve(ctx, msg) + answer, err := dnsResolver.Resolve(resolveCtx, msg) if err != nil { ctrld.Log(ctx, proxyLog.Error().Err(err), "failed to resolve query") return nil @@ -141,9 +146,15 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns return answer } for n, upstreamConfig := range upstreamConfigs { - if answer := resolve(n, upstreamConfig, msg); answer != nil { - return answer + answer := resolve(n, upstreamConfig, msg) + if answer == nil { + continue } + if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) { + ctrld.Log(ctx, proxyLog.Debug(), "failover rcode matched, process to next upstream") + continue + } + return answer } ctrld.Log(ctx, proxyLog.Error(), "all upstreams failed") answer := new(dns.Msg) @@ -151,6 +162,18 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns return answer } +func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { + upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) + for _, upstream := range upstreams { + upstreamNum := strings.TrimPrefix(upstream, "upstream.") + upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum]) + } + if len(upstreamConfigs) == 0 { + upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} + } + return upstreamConfigs +} + // canonicalName returns canonical name from FQDN with "." trimmed. func canonicalName(fqdn string) string { q := strings.TrimSpace(fqdn) @@ -189,18 +212,6 @@ func fmtRemoteToLocal(listenerNum, remote, local string) string { return fmt.Sprintf("%s -> listener.%s: %s:", remote, listenerNum, local) } -func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { - upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) - for _, upstream := range upstreams { - upstreamNum := strings.TrimPrefix(upstream, "upstream.") - upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum]) - } - if len(upstreamConfigs) == 0 { - upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - } - return upstreamConfigs -} - func requestID() string { b := make([]byte, 3) // 6 chars if _, err := rand.Read(b); err != nil { @@ -209,6 +220,15 @@ func requestID() string { return hex.EncodeToString(b) } +func containRcode(rcodes []int, rcode int) bool { + for i := range rcodes { + if rcodes[i] == rcode { + return true + } + } + return false +} + var osUpstreamConfig = &ctrld.UpstreamConfig{ Name: "OS resolver", Type: "os", diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index a016ea5..e67e02c 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -74,6 +74,7 @@ func (p *prog) run() { } for listenerNum := range p.cfg.Listener { + p.cfg.Listener[listenerNum].Init() go func(listenerNum string) { defer wg.Done() listenerConfig := p.cfg.Listener[listenerNum] diff --git a/config.go b/config.go index 2ec92e8..f660006 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,7 @@ import ( "net/url" "strings" + "github.com/Control-D-Inc/ctrld/internal/dnsrcode" "github.com/go-playground/validator/v10" "github.com/spf13/viper" ) @@ -92,9 +93,11 @@ type ListenerConfig struct { // ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests. type ListenerPolicyConfig struct { - Name string `mapstructure:"name" toml:"name"` - Networks []Rule `mapstructure:"networks" toml:"networks" validate:"dive,len=1"` - Rules []Rule `mapstructure:"rules" toml:"rules" validate:"dive,len=1"` + Name string `mapstructure:"name" toml:"name"` + Networks []Rule `mapstructure:"networks" toml:"networks" validate:"dive,len=1"` + Rules []Rule `mapstructure:"rules" toml:"rules" validate:"dive,len=1"` + FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes" validate:"dive,dnsrcode"` + FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"` } // Rule is a map from source to list of upstreams. @@ -122,11 +125,26 @@ func (uc *UpstreamConfig) Init() { } } +// Init initialized necessary values for an ListenerConfig. +func (lc *ListenerConfig) Init() { + if lc.Policy != nil { + lc.Policy.FailoverRcodeNumbers = make([]int, len(lc.Policy.FailoverRcodes)) + for i, rcode := range lc.Policy.FailoverRcodes { + lc.Policy.FailoverRcodeNumbers[i] = dnsrcode.FromString(rcode) + } + } +} + // ValidateConfig validates the given config. func ValidateConfig(validate *validator.Validate, cfg *Config) error { + _ = validate.RegisterValidation("dnsrcode", validateDnsRcode) return validate.Struct(cfg) } +func validateDnsRcode(fl validator.FieldLevel) bool { + return dnsrcode.FromString(fl.Field().String()) != -1 +} + func defaultPortFor(typ string) string { switch typ { case resolverTypeDOH, resolverTypeDOH3: diff --git a/config_test.go b/config_test.go index ef038c3..f315c92 100644 --- a/config_test.go +++ b/config_test.go @@ -69,6 +69,7 @@ func TestConfigValidation(t *testing.T) { {"invalid listener port", invalidListenerPort(t), true}, {"os upstream", configWithOsUpstream(t), false}, {"invalid rules", configWithInvalidRules(t), true}, + {"invalid dns rcodes", configWithInvalidRcodes(t), true}, } for _, tc := range tests { @@ -155,6 +156,16 @@ func configWithInvalidRules(t *testing.T) *ctrld.Config { return cfg } +func configWithInvalidRcodes(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + cfg.Listener["0"].Policy = &ctrld.ListenerPolicyConfig{ + Name: "Policy with invalid Rcodes", + Networks: []ctrld.Rule{{"*.com": []string{"upstream.0"}}}, + FailoverRcodes: []string{"foo"}, + } + return cfg +} + func TestUpstreamConfig_Init(t *testing.T) { tests := []struct { name string diff --git a/docs/config.md b/docs/config.md index 3b24144..44c125e 100644 --- a/docs/config.md +++ b/docs/config.md @@ -300,4 +300,21 @@ Above policy will: - type: array of rule +### failover_rcodes +For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`. For example: + +```toml +[listener.0.policy] +name = "My Policy" +failover_rcodes = ["NXDOMAIN", "SERVFAIL"] +networks = [ + {"network.0" = ["upstream.0", "upstream.1"]}, +] +``` + +If `upstream.0` returns a NXDOMAIN response, the request will be forwarded to `upstream.1` instead of returning immediately to the client. + +See all available DNS Rcodes value [here](rcode_link). + [toml_link]: https://toml.io/en +[rcode_link]: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6 diff --git a/internal/dnsrcode/rcode.go b/internal/dnsrcode/rcode.go new file mode 100644 index 0000000..140911a --- /dev/null +++ b/internal/dnsrcode/rcode.go @@ -0,0 +1,39 @@ +package dnsrcode + +import "strings" + +// https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6 +var dnsRcode = map[string]int{ + "NOERROR": 0, // NoError - No Error + "FORMERR": 1, // FormErr - Format Error + "SERVFAIL": 2, // ServFail - Server Failure + "NXDOMAIN": 3, // NXDomain - Non-Existent Domain + "NOTIMP": 4, // NotImp - Not Implemented + "REFUSED": 5, // Refused - Query Refused + "YXDOMAIN": 6, // YXDomain - Name Exists when it should not + "YXRRSET": 7, // YXRRSet - RR Set Exists when it should not + "NXRRSET": 8, // NXRRSet - RR Set that should exist does not + "NOTAUTH": 9, // NotAuth - Server Not Authoritative for zone + "NOTZONE": 10, // NotZone - Name not contained in zone + "BADSIG": 16, // BADSIG - TSIG Signature Failure + "BADVERS": 16, // BADVERS - Bad OPT Version + "BADKEY": 17, // BADKEY - Key not recognized + "BADTIME": 18, // BADTIME - Signature out of time window + "BADMODE": 19, // BADMODE - Bad TKEY Mode + "BADNAME": 20, // BADNAME - Duplicate key name + "BADALG": 21, // BADALG - Algorithm not supported + "BADTRUNC": 22, // BADTRUNC - Bad Truncation + "BADCOOKIE": 23, // BADCOOKIE - Bad/missing Server Cookie +} + +// FromString returns the DNS Rcode number from given DNS Rcode string. +// The string value is treated as case-insensitive. If the input string +// is an invalid DNS Rcode, -1 is returned. +func FromString(rcode string) int { + rcode = strings.ToUpper(rcode) + val, ok := dnsRcode[rcode] + if !ok { + return -1 + } + return val +} diff --git a/internal/dnsrcode/rcode_test.go b/internal/dnsrcode/rcode_test.go new file mode 100644 index 0000000..f36e850 --- /dev/null +++ b/internal/dnsrcode/rcode_test.go @@ -0,0 +1,29 @@ +package dnsrcode + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFromString(t *testing.T) { + tests := []struct { + name string + rcode string + expectedRcode int + }{ + {"valid", "NoError", 0}, + {"upper", "NOERROR", 0}, + {"lower", "noerror", 0}, + {"mix", "nOeRrOr", 0}, + {"invalid", "foo", -1}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expectedRcode, FromString(tc.rcode)) + }) + } +} diff --git a/testhelper/config.go b/testhelper/config.go index fe548fc..06f1bbb 100644 --- a/testhelper/config.go +++ b/testhelper/config.go @@ -59,6 +59,7 @@ port = 1337 [listener.0.policy] name = "My Policy" +failover_rcodes = ["NXDOMAIN", "SERVFAIL"] networks = [ {"network.0" = ["upstream.1", "upstream.0"]}, {"network.1" = ["upstream.0"]},