all: fix bug that causes ctrld stop working if bootstrap failed

The bootstrap process has two issues that can make ctrld stop resolving
after restarting machine host.

ctrld uses bootstrap DNS and os nameservers for resolving upstream. On
unix, /etc/resolv.conf content is used to get available nameservers.
This works well when installing ctrld. However, after being installed,
ctrld may modify the content of /etc/resolv.conf itself, to make other
apps use its listener as DNS resolver. So when ctrld starts after OS
restart, it ends up using [bootstrap DNS + ctrld's listener], for
resolving upstream. At this moment, if ctrld could not contact bootstrap
DNS for any reason, upstream domain will not be resolved.

For above reason, an upstream may not have bootstrap IPs after ctrld
starts. When re-bootstrapping, if there's no bootstrap IPs, ctrld should
call the setup bootstrap process again. Currently, it does not, causing
all queries failed.

This commit fixes above issue by adding mechanism for retrieving OS
nameservers properly, by querying routing table information:

 - Parsing /proc/net subsystem on Linux.
 - For BSD variants, just fetching routing information base from OS.
 - On Windows, just include the gateway information when reading iface.

The fixing for second issue is trivial, just kickoff a bootstrap process
if there's no bootstrap IPs when re-boostrapping.

While at it, also ensure that fetching resolver information from
ControlD API is also used the same approach.

Fixes #34
This commit is contained in:
Cuong Manh Le
2023-03-30 18:25:36 +07:00
committed by Cuong Manh Le
parent ba48ff5965
commit b65a5ac283
10 changed files with 280 additions and 64 deletions

View File

@@ -155,6 +155,12 @@ func (uc *UpstreamConfig) Init() {
// SetupBootstrapIP manually find all available IPs of the upstream.
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) SetupBootstrapIP() {
uc.setupBootstrapIP(true)
}
// SetupBootstrapIP manually find all available IPs of the upstream.
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
bootstrapIP := func(record dns.RR) string {
switch ar := record.(type) {
case *dns.A:
@@ -166,7 +172,9 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
}
resolver := &osResolver{nameservers: availableNameservers()}
if withBootstrapDNS {
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
}
ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", uc.Domain, resolver.nameservers)
timeoutMs := 2000
if uc.Timeout > 0 && uc.Timeout < timeoutMs {
@@ -228,9 +236,13 @@ func (uc *UpstreamConfig) ReBootstrap() {
default:
return
}
_, _, _ = uc.g.Do("rebootstrap", func() (any, error) {
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
ProxyLog.Debug().Msg("re-bootstrapping upstream ip")
n := uint32(len(uc.bootstrapIPs))
if n == 0 {
uc.SetupBootstrapIP()
uc.setupTransportWithoutPingUpstream()
}
timeoutMs := 1000
if uc.Timeout > 0 && uc.Timeout < timeoutMs {

20
config_internal_test.go Normal file
View File

@@ -0,0 +1,20 @@
package ctrld
import (
"testing"
)
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
uc := &UpstreamConfig{
Name: "test",
Type: ResolverTypeDOH,
Endpoint: "https://freedns.controld.com/p2",
Timeout: 5000,
}
uc.Init()
uc.setupBootstrapIP(false)
if uc.BootstrapIP == "" {
t.Fatal("could not bootstrap ip without bootstrap DNS")
}
t.Log(uc)
}

2
go.mod
View File

@@ -20,6 +20,7 @@ require (
github.com/spf13/cobra v1.4.0
github.com/spf13/viper v1.14.0
github.com/stretchr/testify v1.8.1
golang.org/x/net v0.7.0
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -68,7 +69,6 @@ require (
golang.org/x/crypto v0.4.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.6.0 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/text v0.7.0 // indirect
golang.org/x/tools v0.2.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect

2
go.sum
View File

@@ -54,8 +54,6 @@ github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534 h1:rtAn27
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cuonglm/osinfo v0.0.0-20230329052356-117e0ee9d353 h1:PFKlvMrKAUendoPEiJxSMkYeGG/G/5k7vu2ldGBnq3I=
github.com/cuonglm/osinfo v0.0.0-20230329052356-117e0ee9d353/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs=
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View File

@@ -7,16 +7,26 @@ import (
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/miekg/dns"
"github.com/Control-D-Inc/ctrld"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
const (
apiDomain = "api.controld.com"
resolverDataURL = "https://api.controld.com/utility"
InvalidConfigCode = 40401
)
var (
resolveAPIDomainOnce sync.Once
apiDomainIP string
)
// ResolverConfig represents Control D resolver data.
type ResolverConfig struct {
DOH string `json:"doh"`
@@ -64,6 +74,44 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) {
if ctrldnet.SupportsIPv4() {
proto = "tcp4"
}
resolveAPIDomainOnce.Do(func() {
r, err := ctrld.NewResolver(&ctrld.UpstreamConfig{Type: ctrld.ResolverTypeOS})
if err != nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
msg := new(dns.Msg)
dnsType := dns.TypeAAAA
if proto == "tcp4" {
dnsType = dns.TypeA
}
msg.SetQuestion(apiDomain+".", dnsType)
msg.RecursionDesired = true
answer, err := r.Resolve(ctx, msg)
if err != nil {
return
}
if answer.Rcode != dns.RcodeSuccess || len(answer.Answer) == 0 {
return
}
for _, record := range answer.Answer {
switch ar := record.(type) {
case *dns.A:
apiDomainIP = ar.A.String()
return
case *dns.AAAA:
apiDomainIP = ar.AAAA.String()
return
}
}
})
if apiDomainIP != "" {
if _, port, _ := net.SplitHostPort(addr); port != "" {
return ctrldnet.Dialer.DialContext(ctx, proto, net.JoinHostPort(apiDomainIP, port))
}
}
return ctrldnet.Dialer.DialContext(ctx, proto, addr)
}
client := http.Client{

58
nameservers_bsd.go Normal file
View File

@@ -0,0 +1,58 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package ctrld
import (
"net"
"syscall"
"golang.org/x/net/route"
)
func osNameservers() []string {
var dns []string
seen := make(map[string]bool)
rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
if err != nil {
return nil
}
messages, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return nil
}
for _, message := range messages {
message, ok := message.(*route.RouteMessage)
if !ok {
continue
}
addresses := message.Addrs
if len(addresses) < 2 {
continue
}
dst, gw := toNetIP(addresses[0]), toNetIP(addresses[1])
if dst == nil || gw == nil {
continue
}
if gw.IsLoopback() || seen[gw.String()] {
continue
}
if dst.Equal(net.IPv4zero) || dst.Equal(net.IPv6zero) {
seen[gw.String()] = true
dns = append(dns, net.JoinHostPort(gw.String(), "53"))
}
}
return dns
}
func toNetIP(addr route.Addr) net.IP {
switch t := addr.(type) {
case *route.Inet4Addr:
return net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
case *route.Inet6Addr:
ip := make(net.IP, net.IPv6len)
copy(ip, t.IP[:])
return ip
default:
return nil
}
}

88
nameservers_linux.go Normal file
View File

@@ -0,0 +1,88 @@
package ctrld
import (
"bufio"
"bytes"
"encoding/hex"
"net"
"os"
)
const (
v4RouteFile = "/proc/net/route"
v6RouteFile = "/proc/net/ipv6_route"
)
func osNameservers() []string {
ns4 := dns4()
ns6 := dns6()
ns := make([]string, len(ns4)+len(ns6))
ns = append(ns, ns4...)
ns = append(ns, ns6...)
return ns
}
func dns4() []string {
f, err := os.Open(v4RouteFile)
if err != nil {
return nil
}
defer f.Close()
var dns []string
seen := make(map[string]bool)
s := bufio.NewScanner(f)
first := true
for s.Scan() {
if first {
first = false
continue
}
fields := bytes.Fields(s.Bytes())
if len(fields) < 2 {
continue
}
gw := make([]byte, net.IPv4len)
// Third fields is gateway.
if _, err := hex.Decode(gw, fields[2]); err != nil {
continue
}
ip := net.IPv4(gw[3], gw[2], gw[1], gw[0])
if ip.Equal(net.IPv4zero) || seen[ip.String()] {
continue
}
seen[ip.String()] = true
dns = append(dns, net.JoinHostPort(ip.String(), "53"))
}
return dns
}
func dns6() []string {
f, err := os.Open(v6RouteFile)
if err != nil {
return nil
}
defer f.Close()
var dns []string
s := bufio.NewScanner(f)
for s.Scan() {
fields := bytes.Fields(s.Bytes())
if len(fields) < 4 {
continue
}
gw := make([]byte, net.IPv6len)
// Fifth fields is gateway.
if _, err := hex.Decode(gw, fields[4]); err != nil {
continue
}
ip := net.IP(gw)
if ip.Equal(net.IPv6zero) {
continue
}
dns = append(dns, net.JoinHostPort(ip.String(), "53"))
}
return dns
}

11
nameservers_test.go Normal file
View File

@@ -0,0 +1,11 @@
package ctrld
import "testing"
func TestNameservers(t *testing.T) {
ns := nameservers()
if len(ns) == 0 {
t.Fatal("failed to get nameservers")
}
t.Log(ns)
}

View File

@@ -1,11 +1,7 @@
//go:build !js && !windows
//go:build unix
package ctrld
import (
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
func nameservers() []string {
return resolvconffile.NameServersWithPort()
return osNameservers()
}

View File

@@ -2,24 +2,24 @@ package ctrld
import (
"net"
"os"
"syscall"
"unsafe"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"golang.org/x/sys/windows"
)
func nameservers() []string {
aas, err := adapterAddresses()
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix)
if err != nil {
return nil
}
ns := make([]string, 0, len(aas))
for _, aa := range aas {
for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next {
sa, err := dns.Address.Sockaddr.Sockaddr()
ns := make([]string, 0, len(aas)*2)
seen := make(map[string]bool)
do := func(addr windows.SocketAddress) {
sa, err := addr.Sockaddr.Sockaddr()
if err != nil {
continue
return
}
var ip net.IP
switch sa := sa.(type) {
@@ -32,40 +32,25 @@ func nameservers() []string {
// Ignore these fec0/10 ones. Windows seems to
// populate them as defaults on its misc rando
// interfaces.
continue
return
}
default:
// Unexpected type.
continue
return
}
if ip.IsLoopback() || seen[ip.String()] {
return
}
seen[ip.String()] = true
ns = append(ns, net.JoinHostPort(ip.String(), "53"))
}
for _, aa := range aas {
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
do(dns.Address)
}
for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next {
do(gw.Address)
}
}
return ns
}
func adapterAddresses() ([]*windows.IpAdapterAddresses, error) {
var b []byte
l := uint32(15000) // recommended initial size
for {
b = make([]byte, l)
err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, windows.GAA_FLAG_INCLUDE_PREFIX, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l)
if err == nil {
if l == 0 {
return nil, nil
}
break
}
if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW {
return nil, os.NewSyscallError("getadaptersaddresses", err)
}
if l <= uint32(len(b)) {
return nil, os.NewSyscallError("getadaptersaddresses", err)
}
}
var aas []*windows.IpAdapterAddresses
for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next {
aas = append(aas, aa)
}
return aas, nil
}