Merge pull request #5 from Control-D-Inc/cuonglm/upstream-failover-rcode

all: implement policy failover rcodes
This commit is contained in:
Yegor S
2022-12-14 12:53:26 -05:00
committed by GitHub
9 changed files with 165 additions and 23 deletions

View File

@@ -7,10 +7,13 @@ import (
"os/exec"
"runtime"
"github.com/go-playground/validator/v10"
"github.com/kardianos/service"
"github.com/pelletier/go-toml"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/Control-D-Inc/ctrld"
)
var (
@@ -52,6 +55,9 @@ func initCLI() {
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
}
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
log.Fatalf("invalid config: %v", err)
}
initLogging()
if daemon {
exe, err := os.Executable()

View File

@@ -22,7 +22,10 @@ func (p *prog) serveUDP(listenerNum string) error {
mainLog.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
return allocErr
}
var failoverRcodes []int
if listenerConfig.Policy != nil {
failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers
}
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
domain := canonicalName(m.Question[0].Name)
reqId := requestID()
@@ -37,7 +40,7 @@ func (p *prog) serveUDP(listenerNum string) error {
answer.SetRcode(m, dns.RcodeRefused)
} else {
answer = p.proxy(ctx, upstreams, m)
answer = p.proxy(ctx, upstreams, failoverRcodes, m)
rtt := time.Since(t)
ctrld.Log(ctx, proxyLog.Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
}
@@ -119,7 +122,7 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
return upstreams, matched
}
func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns.Msg {
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg {
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
@@ -128,12 +131,14 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns
ctrld.Log(ctx, proxyLog.Error().Err(err), "failed to create resolver")
return nil
}
resolveCtx, cancel := context.WithCancel(ctx)
defer cancel()
if upstreamConfig.Timeout > 0 {
timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*time.Duration(upstreamConfig.Timeout))
timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(upstreamConfig.Timeout))
defer cancel()
ctx = timeoutCtx
resolveCtx = timeoutCtx
}
answer, err := dnsResolver.Resolve(ctx, msg)
answer, err := dnsResolver.Resolve(resolveCtx, msg)
if err != nil {
ctrld.Log(ctx, proxyLog.Error().Err(err), "failed to resolve query")
return nil
@@ -141,9 +146,15 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns
return answer
}
for n, upstreamConfig := range upstreamConfigs {
if answer := resolve(n, upstreamConfig, msg); answer != nil {
return answer
answer := resolve(n, upstreamConfig, msg)
if answer == nil {
continue
}
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
ctrld.Log(ctx, proxyLog.Debug(), "failover rcode matched, process to next upstream")
continue
}
return answer
}
ctrld.Log(ctx, proxyLog.Error(), "all upstreams failed")
answer := new(dns.Msg)
@@ -151,6 +162,18 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns
return answer
}
// upstreamConfigsFromUpstreamNumbers maps upstream references such as
// "upstream.0" to the matching entries of p.cfg.Upstream, keeping the order
// of upstreams. The "upstream." prefix is stripped to obtain the lookup key.
//
// When upstreams is empty, the result is a single-element slice holding
// osUpstreamConfig.
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
	if len(upstreams) == 0 {
		return []*ctrld.UpstreamConfig{osUpstreamConfig}
	}
	configs := make([]*ctrld.UpstreamConfig, len(upstreams))
	for i, ref := range upstreams {
		// NOTE(review): a reference with no entry in p.cfg.Upstream is stored
		// as-is (presumably nil); callers index this slice in lockstep with
		// upstreams, so entries must not be dropped here.
		configs[i] = p.cfg.Upstream[strings.TrimPrefix(ref, "upstream.")]
	}
	return configs
}
// canonicalName returns canonical name from FQDN with "." trimmed.
func canonicalName(fqdn string) string {
q := strings.TrimSpace(fqdn)
@@ -189,18 +212,6 @@ func fmtRemoteToLocal(listenerNum, remote, local string) string {
return fmt.Sprintf("%s -> listener.%s: %s:", remote, listenerNum, local)
}
// upstreamConfigsFromUpstreamNumbers maps upstream references such as
// "upstream.0" to the matching entries of p.cfg.Upstream, keeping the order
// of upstreams. The "upstream." prefix is stripped to obtain the lookup key.
//
// When upstreams is empty, the result is a single-element slice holding
// osUpstreamConfig.
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
	if len(upstreams) == 0 {
		return []*ctrld.UpstreamConfig{osUpstreamConfig}
	}
	configs := make([]*ctrld.UpstreamConfig, len(upstreams))
	for i, ref := range upstreams {
		// NOTE(review): a reference with no entry in p.cfg.Upstream is stored
		// as-is (presumably nil); callers index this slice in lockstep with
		// upstreams, so entries must not be dropped here.
		configs[i] = p.cfg.Upstream[strings.TrimPrefix(ref, "upstream.")]
	}
	return configs
}
func requestID() string {
b := make([]byte, 3) // 6 chars
if _, err := rand.Read(b); err != nil {
@@ -209,6 +220,15 @@ func requestID() string {
return hex.EncodeToString(b)
}
// containRcode reports whether rcode occurs in rcodes.
// A nil or empty rcodes never matches.
func containRcode(rcodes []int, rcode int) bool {
	found := false
	for _, candidate := range rcodes {
		if candidate == rcode {
			found = true
			break
		}
	}
	return found
}
var osUpstreamConfig = &ctrld.UpstreamConfig{
Name: "OS resolver",
Type: "os",

View File

@@ -74,6 +74,7 @@ func (p *prog) run() {
}
for listenerNum := range p.cfg.Listener {
p.cfg.Listener[listenerNum].Init()
go func(listenerNum string) {
defer wg.Done()
listenerConfig := p.cfg.Listener[listenerNum]

View File

@@ -5,6 +5,7 @@ import (
"net/url"
"strings"
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
"github.com/go-playground/validator/v10"
"github.com/spf13/viper"
)
@@ -92,9 +93,11 @@ type ListenerConfig struct {
// ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests.
type ListenerPolicyConfig struct {
Name string `mapstructure:"name" toml:"name"`
Networks []Rule `mapstructure:"networks" toml:"networks" validate:"dive,len=1"`
Rules []Rule `mapstructure:"rules" toml:"rules" validate:"dive,len=1"`
Name string `mapstructure:"name" toml:"name"`
Networks []Rule `mapstructure:"networks" toml:"networks" validate:"dive,len=1"`
Rules []Rule `mapstructure:"rules" toml:"rules" validate:"dive,len=1"`
// FailoverRcodes lists DNS RCODE names (the "failover_rcodes" config key).
// Each element is checked with the custom "dnsrcode" validation tag.
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes" validate:"dive,dnsrcode"`
// FailoverRcodeNumbers holds the numeric form of FailoverRcodes, index for
// index. It is filled in by ListenerConfig.Init; the "-" tags presumably
// keep it out of config decoding and encoding.
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
}
// Rule is a map from source to list of upstreams.
@@ -122,11 +125,26 @@ func (uc *UpstreamConfig) Init() {
}
}
// Init initializes the derived values of a ListenerConfig.
//
// When a policy is set, every name in Policy.FailoverRcodes is converted with
// dnsrcode.FromString and the results are stored, in the same order, in
// Policy.FailoverRcodeNumbers. A listener without a policy is left untouched.
func (lc *ListenerConfig) Init() {
	policy := lc.Policy
	if policy == nil {
		return
	}
	numbers := make([]int, 0, len(policy.FailoverRcodes))
	for _, name := range policy.FailoverRcodes {
		numbers = append(numbers, dnsrcode.FromString(name))
	}
	policy.FailoverRcodeNumbers = numbers
}
// ValidateConfig validates the given config.
//
// It registers the custom "dnsrcode" validation tag (backed by
// validateDnsRcode) on validate and then runs struct validation on cfg.
//
// It returns the registration error, if any, or otherwise the result of
// validate.Struct(cfg). The registration error is no longer discarded, so a
// failure to install the tag is reported to the caller instead of being
// silently ignored.
func ValidateConfig(validate *validator.Validate, cfg *Config) error {
	if err := validate.RegisterValidation("dnsrcode", validateDnsRcode); err != nil {
		return err
	}
	return validate.Struct(cfg)
}
// validateDnsRcode is the validation function registered under the
// "dnsrcode" tag. It accepts a field whose string value is a name that
// dnsrcode.FromString recognizes, i.e. one that does not map to -1.
func validateDnsRcode(fl validator.FieldLevel) bool {
	name := fl.Field().String()
	code := dnsrcode.FromString(name)
	return code != -1
}
func defaultPortFor(typ string) string {
switch typ {
case resolverTypeDOH, resolverTypeDOH3:

View File

@@ -69,6 +69,7 @@ func TestConfigValidation(t *testing.T) {
{"invalid listener port", invalidListenerPort(t), true},
{"os upstream", configWithOsUpstream(t), false},
{"invalid rules", configWithInvalidRules(t), true},
{"invalid dns rcodes", configWithInvalidRcodes(t), true},
}
for _, tc := range tests {
@@ -155,6 +156,16 @@ func configWithInvalidRules(t *testing.T) *ctrld.Config {
return cfg
}
// configWithInvalidRcodes returns the default test config with listener "0"
// given a policy whose FailoverRcodes contains the unknown name "foo".
func configWithInvalidRcodes(t *testing.T) *ctrld.Config {
	cfg := defaultConfig(t)
	policy := &ctrld.ListenerPolicyConfig{
		Name:           "Policy with invalid Rcodes",
		Networks:       []ctrld.Rule{{"*.com": []string{"upstream.0"}}},
		FailoverRcodes: []string{"foo"},
	}
	cfg.Listener["0"].Policy = policy
	return cfg
}
func TestUpstreamConfig_Init(t *testing.T) {
tests := []struct {
name string

View File

@@ -300,4 +300,21 @@ Above policy will:
- type: array of rule
### failover_rcodes
For non-success responses, `failover_rcodes` allows the request to be forwarded to the next upstream if the response `RCODE` matches any value defined in `failover_rcodes`. For example:
```toml
[listener.0.policy]
name = "My Policy"
failover_rcodes = ["NXDOMAIN", "SERVFAIL"]
networks = [
{"network.0" = ["upstream.0", "upstream.1"]},
]
```
If `upstream.0` returns an NXDOMAIN response, the request will be forwarded to `upstream.1` instead of returning immediately to the client.
See all available DNS Rcode values [here][rcode_link].
[toml_link]: https://toml.io/en
[rcode_link]: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6

View File

@@ -0,0 +1,39 @@
package dnsrcode
import "strings"
// dnsRcode maps upper-case DNS RCODE mnemonics to their numeric values, as
// listed in the IANA registry:
// https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6
//
// BADSIG and BADVERS both map to 16.
var dnsRcode = map[string]int{
	"NOERROR":   0,  // No Error
	"FORMERR":   1,  // Format Error
	"SERVFAIL":  2,  // Server Failure
	"NXDOMAIN":  3,  // Non-Existent Domain
	"NOTIMP":    4,  // Not Implemented
	"REFUSED":   5,  // Query Refused
	"YXDOMAIN":  6,  // Name Exists when it should not
	"YXRRSET":   7,  // RR Set Exists when it should not
	"NXRRSET":   8,  // RR Set that should exist does not
	"NOTAUTH":   9,  // Server Not Authoritative for zone
	"NOTZONE":   10, // Name not contained in zone
	"BADSIG":    16, // TSIG Signature Failure
	"BADVERS":   16, // Bad OPT Version
	"BADKEY":    17, // Key not recognized
	"BADTIME":   18, // Signature out of time window
	"BADMODE":   19, // Bad TKEY Mode
	"BADNAME":   20, // Duplicate key name
	"BADALG":    21, // Algorithm not supported
	"BADTRUNC":  22, // Bad Truncation
	"BADCOOKIE": 23, // Bad/missing Server Cookie
}

// FromString converts a DNS Rcode name to its number.
// Matching ignores letter case; surrounding whitespace is not trimmed.
// It returns -1 when the name is not a known DNS Rcode.
func FromString(rcode string) int {
	if code, found := dnsRcode[strings.ToUpper(rcode)]; found {
		return code
	}
	return -1
}

View File

@@ -0,0 +1,29 @@
package dnsrcode
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFromString(t *testing.T) {
tests := []struct {
name string
rcode string
expectedRcode int
}{
{"valid", "NoError", 0},
{"upper", "NOERROR", 0},
{"lower", "noerror", 0},
{"mix", "nOeRrOr", 0},
{"invalid", "foo", -1},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expectedRcode, FromString(tc.rcode))
})
}
}

View File

@@ -59,6 +59,7 @@ port = 1337
[listener.0.policy]
name = "My Policy"
failover_rcodes = ["NXDOMAIN", "SERVFAIL"]
networks = [
{"network.0" = ["upstream.1", "upstream.0"]},
{"network.1" = ["upstream.0"]},