From b267572b3885ee7b763dfb28b1f1d86c15ced266 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 28 Apr 2023 01:12:59 +0700 Subject: [PATCH] all: implement split upstreams This commit introduces split upstreams feature, allowing to configure what ip stack that ctrld will use to connect to upstream. --- cmd/ctrld/dns_proxy.go | 3 +- config.go | 206 ++++++++++++++++++++++++++++++++-------- config_internal_test.go | 9 ++ config_quic.go | 56 +++++++++-- config_quic_free.go | 5 +- docs/config.md | 20 +++- doh.go | 16 ++-- doq.go | 14 ++- dot.go | 10 +- resolver.go | 15 ++- 10 files changed, 286 insertions(+), 68 deletions(-) diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index 845c2e2..f473634 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -242,7 +242,8 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i } } answer, err := resolve1(n, upstreamConfig, msg) - if err != nil { + // Only do re-bootstrapping if bootstrap ip is not explicitly set by user. + if err != nil && upstreamConfig.BootstrapIP == "" { ctrld.Log(ctx, mainLog.Debug().Err(err), "could not resolve query on first attempt, retrying...") // If any error occurred, re-bootstrap transport/ip, retry the request. upstreamConfig.ReBootstrap() diff --git a/config.go b/config.go index 4eb4e56..f3e28fc 100644 --- a/config.go +++ b/config.go @@ -4,13 +4,13 @@ import ( "context" "crypto/tls" "crypto/x509" + "math/rand" "net" "net/http" "net/url" "os" "strings" "sync" - "sync/atomic" "time" "github.com/go-playground/validator/v10" @@ -22,6 +22,15 @@ import ( ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) +const ( + IpStackBoth = "both" + IpStackV4 = "v4" + IpStackV6 = "v6" + IpStackSplit = "split" +) + +var controldParentDomains = []string{"controld.com", "controld.net", "controld.dev"} + // SetConfigName set the config name that ctrld will look for. // DEPRECATED: use SetConfigNameWithPath instead. func SetConfigName(v *viper.Viper, name string) { @@ -118,19 +127,25 @@ type UpstreamConfig struct { Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"` BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"` Domain string `mapstructure:"-" toml:"-"` + IPStack string `mapstructure:"ip_stack" toml:"ip_stack,omitempty" validate:"ipstack"` Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"` // The caller should not access this field directly. // Use UpstreamSendClientInfo instead. - SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"` - transport *http.Transport `mapstructure:"-" toml:"-"` - http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"` - certPool *x509.CertPool `mapstructure:"-" toml:"-"` - u *url.URL `mapstructure:"-" toml:"-"` + SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"` - g singleflight.Group - mu sync.Mutex - bootstrapIPs []string - nextBootstrapIP atomic.Uint32 + g singleflight.Group + mu sync.Mutex + bootstrapIPs []string + bootstrapIPs4 []string + bootstrapIPs6 []string + transport *http.Transport + transport4 *http.Transport + transport6 *http.Transport + http3RoundTripper http.RoundTripper + http3RoundTripper4 http.RoundTripper + http3RoundTripper6 http.RoundTripper + certPool *x509.CertPool + u *url.URL } // ListenerConfig specifies the networks configuration that ctrld will run on. @@ -164,18 +179,23 @@ func (uc *UpstreamConfig) Init() { uc.u = u } } - if uc.Domain != "" { - return + if uc.Domain == "" { + if !strings.Contains(uc.Endpoint, ":") { + uc.Domain = uc.Endpoint + uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type)) + } + host, _, _ := net.SplitHostPort(uc.Endpoint) + uc.Domain = host + if net.ParseIP(uc.Domain) != nil { + uc.BootstrapIP = uc.Domain + } } - - if !strings.Contains(uc.Endpoint, ":") { - uc.Domain = uc.Endpoint - uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type)) - } - host, _, _ := net.SplitHostPort(uc.Endpoint) - uc.Domain = host - if net.ParseIP(uc.Domain) != nil { - uc.BootstrapIP = uc.Domain + if uc.IPStack == "" { + if uc.isControlD() { + uc.IPStack = IpStackSplit + } else { + uc.IPStack = IpStackBoth + } } } @@ -195,13 +215,8 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool { } switch uc.Type { case ResolverTypeDOH, ResolverTypeDOH3: - if u, err := url.Parse(uc.Endpoint); err == nil { - domain := u.Hostname() - for _, parent := range []string{"controld.com", "controld.net"} { - if dns.IsSubDomain(parent, domain) { - return true - } - } + if uc.isControlD() { + return true } } return false @@ -226,6 +241,13 @@ func (uc *UpstreamConfig) SetupBootstrapIP() { // The first usable IP will be used as bootstrap IP of the upstream. func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) { uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS) + for _, ip := range uc.bootstrapIPs { + if ctrldnet.IsIPv6(ip) { + uc.bootstrapIPs6 = append(uc.bootstrapIPs6, ip) + } else { + uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip) + } + } ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs) } @@ -238,7 +260,6 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { ProxyLog.Debug().Msg("re-bootstrapping upstream ip") - uc.BootstrapIP = "" uc.setupTransportWithoutPingUpstream() return true, nil }) @@ -269,19 +290,17 @@ func (uc *UpstreamConfig) setupDOHTransport() { uc.pingUpstream() } -func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() { - uc.mu.Lock() - defer uc.mu.Unlock() - uc.transport = http.DefaultTransport.(*http.Transport).Clone() - uc.transport.IdleConnTimeout = 5 * time.Second - uc.transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} +func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.IdleConnTimeout = 5 * time.Second + transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} dialerTimeoutMs := 2000 if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs { dialerTimeoutMs = uc.Timeout } dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond - uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { _, port, _ := net.SplitHostPort(addr) if uc.BootstrapIP != "" { dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout} @@ -292,17 +311,42 @@ func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() { pd := &ctrldnet.ParallelDialer{} pd.Timeout = dialerTimeout pd.KeepAlive = dialerTimeout - addrs := make([]string, len(uc.bootstrapIPs)) - for i := range uc.bootstrapIPs { - addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port) + dialAddrs := make([]string, len(addrs)) + for i := range addrs { + dialAddrs[i] = net.JoinHostPort(addrs[i], port) } - conn, err := pd.DialContext(ctx, network, addrs) + conn, err := pd.DialContext(ctx, network, dialAddrs) if err != nil { return nil, err } Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", conn.RemoteAddr()) return conn, nil } + return transport +} + +func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() { + uc.mu.Lock() + defer uc.mu.Unlock() + switch uc.IPStack { + case IpStackBoth, "": + uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + case IpStackV4: + uc.transport = uc.newDOHTransport(uc.bootstrapIPs4) + case IpStackV6: + uc.transport = uc.newDOHTransport(uc.bootstrapIPs6) + case IpStackSplit: + uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if ctrldnet.IPv6Available(ctx) { + uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6) + } else { + uc.transport6 = uc.transport4 + } + + uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + } } func (uc *UpstreamConfig) pingUpstream() { @@ -320,6 +364,74 @@ func (uc *UpstreamConfig) pingUpstream() { _, _ = dnsResolver.Resolve(ctx, msg) } +func (uc *UpstreamConfig) isControlD() bool { + domain := uc.Domain + if domain == "" { + if u, err := url.Parse(uc.Endpoint); err == nil { + domain = u.Hostname() + } + } + for _, parent := range controldParentDomains { + if dns.IsSubDomain(parent, domain) { + return true + } + } + return false +} + +func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { + switch uc.IPStack { + case IpStackBoth, IpStackV4, IpStackV6: + return uc.transport + case IpStackSplit: + switch dnsType { + case dns.TypeA: + return uc.transport4 + default: + return uc.transport6 + } + } + return uc.transport +} + +func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string { + switch uc.IPStack { + case IpStackBoth: + return pick(uc.bootstrapIPs) + case IpStackV4: + return pick(uc.bootstrapIPs4) + case IpStackV6: + return pick(uc.bootstrapIPs6) + case IpStackSplit: + switch dnsType { + case dns.TypeA: + return pick(uc.bootstrapIPs4) + default: + return pick(uc.bootstrapIPs6) + } + } + return pick(uc.bootstrapIPs) +} + +func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) { + switch uc.IPStack { + case IpStackBoth: + return "tcp-tls", "udp" + case IpStackV4: + return "tcp4-tls", "udp4" + case IpStackV6: + return "tcp6-tls", "udp6" + case IpStackSplit: + switch dnsType { + case dns.TypeA: + return "tcp4-tls", "udp4" + default: + return "tcp6-tls", "udp6" + } + } + return "tcp-tls", "udp" +} + // Init initialized necessary values for an ListenerConfig. func (lc *ListenerConfig) Init() { if lc.Policy != nil { @@ -333,6 +445,7 @@ func (lc *ListenerConfig) Init() { // ValidateConfig validates the given config. func ValidateConfig(validate *validator.Validate, cfg *Config) error { _ = validate.RegisterValidation("dnsrcode", validateDnsRcode) + _ = validate.RegisterValidation("ipstack", validateIpStack) _ = validate.RegisterValidation("iporempty", validateIpOrEmpty) return validate.Struct(cfg) } @@ -341,6 +454,15 @@ func validateDnsRcode(fl validator.FieldLevel) bool { return dnsrcode.FromString(fl.Field().String()) != -1 } +func validateIpStack(fl validator.FieldLevel) bool { + switch fl.Field().String() { + case IpStackBoth, IpStackV4, IpStackV6, IpStackSplit, "": + return true + default: + return false + } +} + func validateIpOrEmpty(fl validator.FieldLevel) bool { val := fl.Field().String() if val == "" { @@ -384,3 +506,7 @@ func ResolverTypeFromEndpoint(endpoint string) string { } return ResolverTypeDOT } + +func pick(s []string) string { + return s[rand.Intn(len(s))] +} diff --git a/config_internal_test.go b/config_internal_test.go index 4c67826..bf310f9 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -48,6 +48,7 @@ func TestUpstreamConfig_Init(t *testing.T) { BootstrapIP: "", Domain: "example.com", Timeout: 0, + IPStack: IpStackBoth, u: u1, }, }, @@ -68,6 +69,7 @@ func TestUpstreamConfig_Init(t *testing.T) { BootstrapIP: "", Domain: "example.com", Timeout: 0, + IPStack: IpStackBoth, u: u2, }, }, @@ -88,6 +90,7 @@ func TestUpstreamConfig_Init(t *testing.T) { BootstrapIP: "", Domain: "freedns.controld.com", Timeout: 0, + IPStack: IpStackSplit, }, }, { @@ -99,6 +102,7 @@ func TestUpstreamConfig_Init(t *testing.T) { BootstrapIP: "", Domain: "", Timeout: 0, + IPStack: IpStackSplit, }, &UpstreamConfig{ Name: "dot", @@ -107,6 +111,7 @@ func TestUpstreamConfig_Init(t *testing.T) { BootstrapIP: "", Domain: "freedns.controld.com", Timeout: 0, + IPStack: IpStackSplit, }, }, { @@ -126,6 +131,7 @@ func TestUpstreamConfig_Init(t *testing.T) { BootstrapIP: "1.2.3.4", Domain: "1.2.3.4", Timeout: 0, + IPStack: IpStackBoth, }, }, { @@ -145,6 +151,7 @@ func TestUpstreamConfig_Init(t *testing.T) { BootstrapIP: "1.2.3.4", Domain: "1.2.3.4", Timeout: 0, + IPStack: IpStackBoth, }, }, { @@ -157,6 +164,7 @@ func TestUpstreamConfig_Init(t *testing.T) { Domain: "", Timeout: 0, SendClientInfo: ptrBool(false), + IPStack: IpStackBoth, }, &UpstreamConfig{ Name: "doh", @@ -166,6 +174,7 @@ func TestUpstreamConfig_Init(t *testing.T) { Domain: "example.com", Timeout: 0, SendClientInfo: ptrBool(false), + IPStack: IpStackBoth, u: u2, }, }, diff --git a/config_quic.go b/config_quic.go index 9c6d668..8c0fb97 100644 --- a/config_quic.go +++ b/config_quic.go @@ -7,10 +7,15 @@ import ( "crypto/tls" "errors" "net" + "net/http" "sync" + "time" + "github.com/miekg/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" + + ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) func (uc *UpstreamConfig) setupDOH3Transport() { @@ -18,9 +23,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() { uc.pingUpstream() } -func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() { - uc.mu.Lock() - defer uc.mu.Unlock() +func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { rt := &http3.RoundTripper{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { @@ -40,20 +43,57 @@ func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() { } return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg) } - addrs := make([]string, len(uc.bootstrapIPs)) - for i := range uc.bootstrapIPs { - addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port) + dialAddrs := make([]string, len(addrs)) + for i := range addrs { + dialAddrs[i] = net.JoinHostPort(addrs[i], port) } pd := &quicParallelDialer{} - conn, err := pd.Dial(ctx, domain, addrs, tlsCfg, cfg) + conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg) if err != nil { return nil, err } ProxyLog.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr()) return conn, err } + return rt +} - uc.http3RoundTripper = rt +func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() { + uc.mu.Lock() + defer uc.mu.Unlock() + switch uc.IPStack { + case IpStackBoth, "": + uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + case IpStackV4: + uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4) + case IpStackV6: + uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) + case IpStackSplit: + uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if ctrldnet.IPv6Available(ctx) { + uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) + } else { + uc.http3RoundTripper6 = uc.http3RoundTripper4 + } + uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + } +} + +func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { + switch uc.IPStack { + case IpStackBoth, IpStackV4, IpStackV6: + return uc.http3RoundTripper + case IpStackSplit: + switch dnsType { + case dns.TypeA: + return uc.http3RoundTripper4 + default: + return uc.http3RoundTripper6 + } + } + return uc.http3RoundTripper } // Putting the code for quic parallel dialer here: diff --git a/config_quic_free.go b/config_quic_free.go index 3817e51..a4b1bdd 100644 --- a/config_quic_free.go +++ b/config_quic_free.go @@ -2,6 +2,9 @@ package ctrld +import "net/http" + func (uc *UpstreamConfig) setupDOH3Transport() {} -func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {} +func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {} +func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { return nil } diff --git a/docs/config.md b/docs/config.md index e8cec53..fc78e98 100644 --- a/docs/config.md +++ b/docs/config.md @@ -227,9 +227,27 @@ Value `0` means no timeout. The protocol that `ctrld` will use to send DNS requests to upstream. - Type: string - - required: yes + - Required: yes - Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os` +### ip_stack +Specifying what kind of ip stack that `ctrld` will use to connect to upstream. + + - Type: string + - Required: no + - Valid values: + - `both`: using either ipv4 or ipv6. + - `v4`: only dial upstream via IPv4, never dial IPv6. + - `v6`: only dial upstream via IPv6, never dial IPv4. + - `split`: + - If `A` record is requested -> dial via ipv4. + - If `AAAA` or any other record is requested -> dial ipv6 (if available, otherwise ipv4) + +If `ip_stack` is empty, or undefined: + + - Default value is `both` for non-Control D resolvers. + - Default value is `split` for Control D resolvers. + ## Network The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs. diff --git a/doh.go b/doh.go index 000cc49..e831feb 100644 --- a/doh.go +++ b/doh.go @@ -36,12 +36,12 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver { } type dohResolver struct { + uc *UpstreamConfig endpoint *url.URL isDoH3 bool transport *http.Transport http3RoundTripper http.RoundTripper sendClientInfo bool - uc *UpstreamConfig } func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { @@ -61,18 +61,22 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return nil, fmt.Errorf("could not create request: %w", err) } addHeader(ctx, req, r.sendClientInfo) - - c := http.Client{Transport: r.transport} + dnsTyp := uint16(0) + if len(msg.Question) > 0 { + dnsTyp = msg.Question[0].Qtype + } + c := http.Client{Transport: r.uc.dohTransport(dnsTyp)} if r.isDoH3 { - if r.http3RoundTripper == nil { + transport := r.uc.doh3Transport(dnsTyp) + if transport == nil { return nil, errors.New("DoH3 is not supported") } - c.Transport = r.http3RoundTripper + c.Transport = transport } resp, err := c.Do(req) if err != nil { if r.isDoH3 { - if closer, ok := r.http3RoundTripper.(io.Closer); ok { + if closer, ok := c.Transport.(io.Closer); ok { closer.Close() } } diff --git a/doq.go b/doq.go index 20919e3..f0a7bc1 100644 --- a/doq.go +++ b/doq.go @@ -20,11 +20,17 @@ type doqResolver struct { func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { endpoint := r.uc.Endpoint tlsConfig := &tls.Config{NextProtos: []string{"doq"}} - if r.uc.BootstrapIP != "" { - tlsConfig.ServerName = r.uc.Domain - _, port, _ := net.SplitHostPort(endpoint) - endpoint = net.JoinHostPort(r.uc.BootstrapIP, port) + ip := r.uc.BootstrapIP + if ip == "" { + dnsTyp := uint16(0) + if len(msg.Question) > 0 { + dnsTyp = msg.Question[0].Qtype + } + ip = r.uc.bootstrapIPForDNSType(dnsTyp) } + tlsConfig.ServerName = r.uc.Domain + _, port, _ := net.SplitHostPort(endpoint) + endpoint = net.JoinHostPort(ip, port) return resolve(ctx, msg, endpoint, tlsConfig) } diff --git a/dot.go b/dot.go index 11befb7..2e11a8f 100644 --- a/dot.go +++ b/dot.go @@ -14,13 +14,19 @@ type dotResolver struct { func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { // The dialer is used to prevent bootstrapping cycle. - // If r.endpoing is set to dns.controld.dev, we need to resolve + // If r.endpoint is set to dns.controld.dev, we need to resolve // dns.controld.dev first. By using a dialer with custom resolver, // we ensure that we can always resolve the bootstrap domain // regardless of the machine DNS status. dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53")) + dnsTyp := uint16(0) + if len(msg.Question) > 0 { + dnsTyp = msg.Question[0].Qtype + } + + tcpNet, _ := r.uc.netForDNSType(dnsTyp) dnsClient := &dns.Client{ - Net: "tcp-tls", + Net: tcpNet, Dialer: dialer, TLSConfig: &tls.Config{RootCAs: r.uc.certPool}, } diff --git a/resolver.go b/resolver.go index f2a8424..89b98b6 100644 --- a/resolver.go +++ b/resolver.go @@ -34,7 +34,7 @@ var errUnknownResolver = errors.New("unknown resolver") // NewResolver creates a Resolver based on the given upstream config. func NewResolver(uc *UpstreamConfig) (Resolver, error) { - typ, endpoint := uc.Type, uc.Endpoint + typ := uc.Type switch typ { case ResolverTypeDOH, ResolverTypeDOH3: return newDohResolver(uc), nil @@ -45,7 +45,7 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) { case ResolverTypeOS: return or, nil case ResolverTypeLegacy: - return &legacyResolver{endpoint: endpoint}, nil + return &legacyResolver{uc: uc}, nil } return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ) } @@ -110,17 +110,22 @@ func newDialer(dnsAddress string) *net.Dialer { } type legacyResolver struct { - endpoint string + uc *UpstreamConfig } func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { // See comment in (*dotResolver).resolve method. dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53")) + dnsTyp := uint16(0) + if len(msg.Question) > 0 { + dnsTyp = msg.Question[0].Qtype + } + _, udpNet := r.uc.netForDNSType(dnsTyp) dnsClient := &dns.Client{ - Net: "udp", + Net: udpNet, Dialer: dialer, } - answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint) + answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.uc.Endpoint) return answer, err }