mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
all: use parallel dialer for connecting upstream/api
So we don't have to depend on network stack probing to decide whether ipv4 or ipv6 will be used. While at it, also prevent a race report when doing the same parallel resolving for os resolver, even though this race is harmless.
This commit is contained in:
committed by
Cuong Manh Le
parent
d3d08022cc
commit
d52cd11322
@@ -78,7 +78,7 @@ var rootCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func curVersion() string {
|
||||
if version != "dev" {
|
||||
if version != "dev" && !strings.HasPrefix(version, "v") {
|
||||
version = "v" + version
|
||||
}
|
||||
if len(commit) > 7 {
|
||||
|
||||
69
config.go
69
config.go
@@ -177,71 +177,20 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
bootstrapIP := func(record dns.RR) string {
|
||||
switch ar := record.(type) {
|
||||
case *dns.A:
|
||||
return ar.A.String()
|
||||
case *dns.AAAA:
|
||||
return ar.AAAA.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
||||
for _, ip := range uc.bootstrapIPs {
|
||||
if uc.BootstrapIP == "" {
|
||||
// Remember what's the current IP in bootstrap IPs list,
|
||||
// so we can select next one upon re-bootstrapping.
|
||||
uc.nextBootstrapIP.Add(1)
|
||||
|
||||
resolver := &osResolver{nameservers: availableNameservers()}
|
||||
if withBootstrapDNS {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
}
|
||||
ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", uc.Domain, resolver.nameservers)
|
||||
timeoutMs := 2000
|
||||
if uc.Timeout > 0 && uc.Timeout < timeoutMs {
|
||||
timeoutMs = uc.Timeout
|
||||
}
|
||||
do := func(dnsType uint16) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(uc.Domain+".", dnsType)
|
||||
m.RecursionDesired = true
|
||||
|
||||
r, err := resolver.Resolve(ctx, m)
|
||||
if err != nil {
|
||||
ProxyLog.Error().Err(err).Str("type", dns.TypeToString[dnsType]).Msgf("could not resolve domain %s for upstream", uc.Domain)
|
||||
return
|
||||
}
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
ProxyLog.Error().Msgf("could not resolve domain %q, return code: %s", uc.Domain, dns.RcodeToString[r.Rcode])
|
||||
return
|
||||
}
|
||||
if len(r.Answer) == 0 {
|
||||
ProxyLog.Error().Msg("no answer from bootstrap DNS server")
|
||||
return
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
ip := bootstrapIP(a)
|
||||
if ip == "" {
|
||||
// If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip.
|
||||
if !ctrldnet.SupportsIPv6() && ctrldnet.IsIPv6(ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Storing the ip to uc.bootstrapIPs list, so it can be selected later
|
||||
// when retrying failed request due to network stack changed.
|
||||
uc.bootstrapIPs = append(uc.bootstrapIPs, ip)
|
||||
if uc.BootstrapIP == "" {
|
||||
// Remember what's the current IP in bootstrap IPs list,
|
||||
// so we can select next one upon re-bootstrapping.
|
||||
uc.nextBootstrapIP.Add(1)
|
||||
|
||||
// If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip.
|
||||
if !ctrldnet.SupportsIPv6() && ctrldnet.IsIPv6(ip) {
|
||||
continue
|
||||
}
|
||||
uc.BootstrapIP = ip
|
||||
}
|
||||
uc.BootstrapIP = ip
|
||||
}
|
||||
}
|
||||
// Find all A, AAAA records of the upstream.
|
||||
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
|
||||
do(dnsType)
|
||||
}
|
||||
ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,11 +8,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/certs"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
@@ -25,11 +22,6 @@ const (
|
||||
InvalidConfigCode = 40401
|
||||
)
|
||||
|
||||
var (
|
||||
resolveAPIDomainOnce sync.Once
|
||||
apiDomainIP string
|
||||
)
|
||||
|
||||
// ResolverConfig represents Control D resolver data.
|
||||
type ResolverConfig struct {
|
||||
DOH string `json:"doh"`
|
||||
@@ -71,51 +63,19 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) {
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// We experiment hanging in TLS handshake when connecting to ControlD API
|
||||
// with ipv6. So prefer ipv4 if available.
|
||||
proto := "tcp6"
|
||||
if ctrldnet.SupportsIPv4() {
|
||||
proto = "tcp4"
|
||||
ips := ctrld.LookupIP(apiDomain)
|
||||
if len(ips) == 0 {
|
||||
ctrld.ProxyLog.Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr)
|
||||
return ctrldnet.Dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
resolveAPIDomainOnce.Do(func() {
|
||||
r, err := ctrld.NewResolver(&ctrld.UpstreamConfig{Type: ctrld.ResolverTypeOS})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
msg := new(dns.Msg)
|
||||
dnsType := dns.TypeAAAA
|
||||
if proto == "tcp4" {
|
||||
dnsType = dns.TypeA
|
||||
}
|
||||
msg.SetQuestion(apiDomain+".", dnsType)
|
||||
msg.RecursionDesired = true
|
||||
answer, err := r.Resolve(ctx, msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess || len(answer.Answer) == 0 {
|
||||
return
|
||||
}
|
||||
for _, record := range answer.Answer {
|
||||
switch ar := record.(type) {
|
||||
case *dns.A:
|
||||
apiDomainIP = ar.A.String()
|
||||
return
|
||||
case *dns.AAAA:
|
||||
apiDomainIP = ar.AAAA.String()
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
if apiDomainIP != "" {
|
||||
if _, port, _ := net.SplitHostPort(addr); port != "" {
|
||||
return ctrldnet.Dialer.DialContext(ctx, proto, net.JoinHostPort(apiDomainIP, port))
|
||||
}
|
||||
ctrld.ProxyLog.Debug().Msgf("API IPs: %v", ips)
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
addrs := make([]string, len(ips))
|
||||
for i := range ips {
|
||||
addrs[i] = net.JoinHostPort(ips[i], port)
|
||||
}
|
||||
return ctrldnet.Dialer.DialContext(ctx, proto, addr)
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
return d.DialContext(ctx, network, addrs)
|
||||
}
|
||||
|
||||
if router.Name() == router.DDWrt {
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const utilityURL = "https://api.controld.com/utility"
|
||||
|
||||
func TestFetchResolverConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -2,6 +2,7 @@ package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -37,7 +38,6 @@ var probeStackDialer = &net.Dialer{
|
||||
|
||||
var (
|
||||
stackOnce atomic.Pointer[sync.Once]
|
||||
ipv4Enabled bool
|
||||
ipv6Enabled bool
|
||||
canListenIPv6Local bool
|
||||
hasNetworkUp bool
|
||||
@@ -75,7 +75,6 @@ func probeStack() {
|
||||
b.BackOff(context.Background(), err)
|
||||
}
|
||||
}
|
||||
ipv4Enabled = supportIPv4()
|
||||
ipv6Enabled = supportIPv6(context.Background())
|
||||
canListenIPv6Local = supportListenIPv6Local()
|
||||
}
|
||||
@@ -85,11 +84,6 @@ func Up() bool {
|
||||
return hasNetworkUp
|
||||
}
|
||||
|
||||
func SupportsIPv4() bool {
|
||||
stackOnce.Load().Do(probeStack)
|
||||
return ipv4Enabled
|
||||
}
|
||||
|
||||
func SupportsIPv6() bool {
|
||||
stackOnce.Load().Do(probeStack)
|
||||
return ipv6Enabled
|
||||
@@ -112,3 +106,47 @@ func IsIPv6(ip string) bool {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
return parsedIP != nil && parsedIP.To4() == nil && parsedIP.To16() != nil
|
||||
}
|
||||
|
||||
type parallelDialerResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
type ParallelDialer struct {
|
||||
net.Dialer
|
||||
}
|
||||
|
||||
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
ch := make(chan *parallelDialerResult, len(addrs))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(addrs))
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
for _, addr := range addrs {
|
||||
go func(addr string) {
|
||||
defer wg.Done()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, addr)
|
||||
ch <- ¶llelDialerResult{conn: conn, err: err}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
errs := make([]error, 0, len(addrs))
|
||||
for res := range ch {
|
||||
if res.err == nil {
|
||||
cancel()
|
||||
return res.conn, res.err
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
62
resolver.go
62
resolver.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@@ -79,7 +80,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
for _, server := range o.nameservers {
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg, server)
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||
ch <- &osResolverResult{answer: answer, err: err}
|
||||
}(server)
|
||||
}
|
||||
@@ -122,3 +123,62 @@ func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, e
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint)
|
||||
return answer, err
|
||||
}
|
||||
|
||||
// LookupIP looks up host using OS resolver.
|
||||
// It returns a slice of that host's IPv4 and IPv6 addresses.
|
||||
func LookupIP(domain string) []string {
|
||||
return lookupIP(domain, -1, true)
|
||||
}
|
||||
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
resolver := &osResolver{nameservers: availableNameservers()}
|
||||
if withBootstrapDNS {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
}
|
||||
ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
|
||||
timeoutMs := 2000
|
||||
if timeout > 0 && timeout < timeoutMs {
|
||||
timeoutMs = timeoutMs
|
||||
}
|
||||
ipFromRecord := func(record dns.RR) string {
|
||||
switch ar := record.(type) {
|
||||
case *dns.A:
|
||||
return ar.A.String()
|
||||
case *dns.AAAA:
|
||||
return ar.AAAA.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
lookup := func(dnsType uint16) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(domain+".", dnsType)
|
||||
m.RecursionDesired = true
|
||||
|
||||
r, err := resolver.Resolve(ctx, m)
|
||||
if err != nil {
|
||||
ProxyLog.Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain)
|
||||
return
|
||||
}
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
ProxyLog.Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode])
|
||||
return
|
||||
}
|
||||
if len(r.Answer) == 0 {
|
||||
ProxyLog.Error().Msg("no answer from OS resolver")
|
||||
return
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
if ip := ipFromRecord(a); ip != "" {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Find all A, AAAA records of the domain.
|
||||
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
|
||||
lookup(dnsType)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user