mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
In practice, on a fresh VM test, we saw a DNS server return answers containing records that were not for the queried domain. To work around this, filter out answers that are not for the queried domain.
192 lines
4.8 KiB
Go
192 lines
4.8 KiB
Go
package ctrld
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
)
|
|
|
|
// Upstream resolver types understood by NewResolver, matched against
// UpstreamConfig.Type.
const (
	// ResolverTypeDOH selects DNS-over-HTTPS; built by newDohResolver.
	ResolverTypeDOH = "doh"
	// ResolverTypeDOH3 selects the HTTP/3 flavour of DNS-over-HTTPS; also
	// built by newDohResolver.
	ResolverTypeDOH3 = "doh3"
	// ResolverTypeDOT selects DNS-over-TLS (dotResolver).
	ResolverTypeDOT = "dot"
	// ResolverTypeDOQ selects DNS-over-QUIC (doqResolver).
	ResolverTypeDOQ = "doq"
	// ResolverTypeOS selects the shared osResolver, which queries the OS
	// nameservers.
	ResolverTypeOS = "os"
	// ResolverTypeLegacy selects legacyResolver: plain DNS over UDP.
	ResolverTypeLegacy = "legacy"
)
|
|
|
|
var bootstrapDNS = "76.76.2.0"
|
|
var or = &osResolver{nameservers: nameservers()}
|
|
|
|
// Resolver is the interface that wraps the basic DNS operations.
//
// Resolve resolves the DNS query msg and returns the response message, or the
// error that prevented a response from being obtained.
type Resolver interface {
	Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error)
}

// errUnknownResolver is returned (wrapped, together with the offending type
// name) by NewResolver when UpstreamConfig.Type is not one of the
// ResolverType* constants; match it with errors.Is.
var errUnknownResolver = errors.New("unknown resolver")
|
|
|
|
// NewResolver creates a Resolver based on the given upstream config.
|
|
func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
|
typ, endpoint := uc.Type, uc.Endpoint
|
|
switch typ {
|
|
case ResolverTypeDOH, ResolverTypeDOH3:
|
|
return newDohResolver(uc), nil
|
|
case ResolverTypeDOT:
|
|
return &dotResolver{uc: uc}, nil
|
|
case ResolverTypeDOQ:
|
|
return &doqResolver{uc: uc}, nil
|
|
case ResolverTypeOS:
|
|
return or, nil
|
|
case ResolverTypeLegacy:
|
|
return &legacyResolver{endpoint: endpoint}, nil
|
|
}
|
|
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
|
|
}
|
|
|
|
// osResolver answers queries by racing a fixed list of nameservers; see its
// Resolve method. The package-level instance is seeded from nameservers()
// (presumably the OS-configured servers — confirm in that helper).
type osResolver struct {
	// nameservers are the addresses handed to dns.Client.ExchangeContext,
	// presumably in host:port form (lookupIP prepends "bootstrapDNS:53").
	nameservers []string
}

// osResolverResult carries the outcome of a single per-nameserver exchange
// from osResolver.Resolve's worker goroutines back to the collecting loop.
type osResolverResult struct {
	answer *dns.Msg
	err    error
}
|
|
|
|
// Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from
|
|
// available nameservers with a roundrobin algorithm.
|
|
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
|
numServers := len(o.nameservers)
|
|
if numServers == 0 {
|
|
return nil, errors.New("no nameservers available")
|
|
}
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
dnsClient := &dns.Client{Net: "udp"}
|
|
ch := make(chan *osResolverResult, numServers)
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(o.nameservers))
|
|
go func() {
|
|
wg.Wait()
|
|
close(ch)
|
|
}()
|
|
for _, server := range o.nameservers {
|
|
go func(server string) {
|
|
defer wg.Done()
|
|
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
|
ch <- &osResolverResult{answer: answer, err: err}
|
|
}(server)
|
|
}
|
|
|
|
errs := make([]error, 0, numServers)
|
|
for res := range ch {
|
|
if res.err == nil {
|
|
cancel()
|
|
return res.answer, res.err
|
|
}
|
|
errs = append(errs, res.err)
|
|
}
|
|
|
|
return nil, errors.Join(errs...)
|
|
}
|
|
|
|
// newDialer returns a net.Dialer whose hostname lookups go through Go's
// built-in resolver. Every DNS connection that resolver opens is sent to
// dnsAddress, whatever server it was asked to contact.
func newDialer(dnsAddress string) *net.Dialer {
	dialDNS := func(ctx context.Context, network, _ string) (net.Conn, error) {
		var inner net.Dialer
		return inner.DialContext(ctx, network, dnsAddress)
	}
	resolver := &net.Resolver{PreferGo: true, Dial: dialDNS}
	return &net.Dialer{Resolver: resolver}
}
|
|
|
|
type legacyResolver struct {
|
|
endpoint string
|
|
}
|
|
|
|
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
|
// See comment in (*dotResolver).resolve method.
|
|
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
|
dnsClient := &dns.Client{
|
|
Net: "udp",
|
|
Dialer: dialer,
|
|
}
|
|
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint)
|
|
return answer, err
|
|
}
|
|
|
|
// LookupIP looks up host using OS resolver.
|
|
// It returns a slice of that host's IPv4 and IPv6 addresses.
|
|
func LookupIP(domain string) []string {
|
|
return lookupIP(domain, -1, true)
|
|
}
|
|
|
|
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
|
resolver := &osResolver{nameservers: nameservers()}
|
|
if withBootstrapDNS {
|
|
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
|
}
|
|
ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
|
|
timeoutMs := 2000
|
|
if timeout > 0 && timeout < timeoutMs {
|
|
timeoutMs = timeoutMs
|
|
}
|
|
questionDomain := dns.Fqdn(domain)
|
|
ipFromRecord := func(record dns.RR) string {
|
|
switch ar := record.(type) {
|
|
case *dns.A:
|
|
if ar.Hdr.Name != questionDomain {
|
|
return ""
|
|
}
|
|
return ar.A.String()
|
|
case *dns.AAAA:
|
|
if ar.Hdr.Name != questionDomain {
|
|
return ""
|
|
}
|
|
return ar.AAAA.String()
|
|
}
|
|
return ""
|
|
}
|
|
|
|
lookup := func(dnsType uint16) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
|
defer cancel()
|
|
m := new(dns.Msg)
|
|
m.SetQuestion(questionDomain, dnsType)
|
|
m.RecursionDesired = true
|
|
|
|
r, err := resolver.Resolve(ctx, m)
|
|
if err != nil {
|
|
ProxyLog.Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain)
|
|
return
|
|
}
|
|
if r.Rcode != dns.RcodeSuccess {
|
|
ProxyLog.Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode])
|
|
return
|
|
}
|
|
if len(r.Answer) == 0 {
|
|
ProxyLog.Error().Msg("no answer from OS resolver")
|
|
return
|
|
}
|
|
for _, a := range r.Answer {
|
|
if ip := ipFromRecord(a); ip != "" {
|
|
ips = append(ips, ip)
|
|
}
|
|
}
|
|
}
|
|
// Find all A, AAAA records of the domain.
|
|
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
|
|
lookup(dnsType)
|
|
}
|
|
return ips
|
|
}
|