Merge pull request #209 from Control-D-Inc/release-branch-v1.4.0

Release branch v1.4.0
This commit is contained in:
Cuong Manh Le
2025-02-12 14:55:47 +07:00
committed by GitHub
42 changed files with 4517 additions and 1566 deletions

View File

@@ -8,3 +8,8 @@ import (
// addExtraSplitDnsRule adds split DNS rule if present.
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
// getActiveDirectoryDomain returns AD domain name of this computer.
func getActiveDirectoryDomain() (string, error) {
return "", nil
}

View File

@@ -1,9 +1,14 @@
package cli
import (
"fmt"
"io"
"log"
"os"
"strings"
"github.com/microsoft/wmi/pkg/base/host"
hh "github.com/microsoft/wmi/pkg/hardware/host"
"github.com/Control-D-Inc/ctrld"
)
@@ -21,29 +26,48 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
// Network rules are lowercase during toml config marshaling,
// lowercase the domain here too for consistency.
domain = strings.ToLower(domain)
domainRuleAdded := addSplitDnsRule(cfg, domain)
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
return domainRuleAdded || wildcardDomainRuleRuleAdded
}
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
// The return value indicates whether the split-rule was added or not.
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
for n, lc := range cfg.Listener {
if lc.Policy == nil {
lc.Policy = &ctrld.ListenerPolicyConfig{}
}
domainRule := "*." + strings.TrimPrefix(domain, ".")
for _, rule := range lc.Policy.Rules {
if _, ok := rule[domainRule]; ok {
mainLog.Load().Debug().Msgf("domain rule already exist for listener.%s", n)
if _, ok := rule[domain]; ok {
mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n)
return false
}
}
mainLog.Load().Debug().Msgf("adding active directory domain for listener.%s", n)
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}})
mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n)
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
}
return true
}
// getActiveDirectoryDomain returns AD domain name of this computer.
func getActiveDirectoryDomain() (string, error) {
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
output, err := powershell(cmd)
if err != nil {
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
whost := host.NewWmiLocalHost()
cs, err := hh.GetComputerSystem(whost)
if cs != nil {
defer cs.Close()
}
return string(output), nil
if err != nil {
return "", err
}
pod, err := cs.GetPropertyPartOfDomain()
if err != nil {
return "", err
}
if pod {
return cs.GetPropertyDomain()
}
return "", nil
}

View File

@@ -0,0 +1,71 @@
package cli
import (
"fmt"
"testing"
"time"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/testhelper"
"github.com/stretchr/testify/assert"
)
func Test_getActiveDirectoryDomain(t *testing.T) {
start := time.Now()
domain, err := getActiveDirectoryDomain()
if err != nil {
t.Fatal(err)
}
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
start = time.Now()
domainPowershell, err := getActiveDirectoryDomainPowershell()
if err != nil {
t.Fatal(err)
}
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
if domain != domainPowershell {
t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain)
}
}
func getActiveDirectoryDomainPowershell() (string, error) {
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
output, err := powershell(cmd)
if err != nil {
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
}
return string(output), nil
}
func Test_addSplitDnsRule(t *testing.T) {
newCfg := func(domains ...string) *ctrld.Config {
cfg := testhelper.SampleConfig(t)
lc := cfg.Listener["0"]
for _, domain := range domains {
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
}
return cfg
}
tests := []struct {
name string
cfg *ctrld.Config
domain string
added bool
}{
{"added", newCfg(), "example.com", true},
{"TLD existed", newCfg("example.com"), "*.example.com", true},
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
added := addSplitDnsRule(tc.cfg, tc.domain)
assert.Equal(t, tc.added, added)
})
}
}

File diff suppressed because it is too large Load Diff

1362
cmd/cli/commands.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -25,6 +25,10 @@ func newControlClient(addr string) *controlClient {
}
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
// for log/send, set the timeout to 5 minutes
if path == sendLogsPath {
c.c.Timeout = time.Minute * 5
}
return c.c.Post("http://unix"+path, contentTypeJson, data)
}

View File

@@ -3,6 +3,8 @@ package cli
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
@@ -25,8 +27,16 @@ const (
deactivationPath = "/deactivation"
cdPath = "/cd"
ifacePath = "/iface"
viewLogsPath = "/log/view"
sendLogsPath = "/log/send"
)
type ifaceResponse struct {
Name string `json:"name"`
All bool `json:"all"`
OK bool `json:"ok"`
}
type controlServer struct {
server *http.Server
mux *http.ServeMux
@@ -201,15 +211,76 @@ func (p *prog) registerControlServerHandler() {
w.WriteHeader(http.StatusBadRequest)
}))
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
res := &ifaceResponse{Name: iface}
// p.setDNS is only called when running as a service
if !service.Interactive() {
<-p.csSetDnsDone
if p.csSetDnsOk {
w.Write([]byte(iface))
return
res.Name = p.runningIface
res.All = p.requiredMultiNICsConfig
res.OK = true
}
}
w.WriteHeader(http.StatusBadRequest)
if err := json.NewEncoder(w).Encode(res); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
return
}
}))
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
lr, err := p.logReader()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer lr.r.Close()
if lr.size == 0 {
w.WriteHeader(http.StatusMovedPermanently)
return
}
data, err := io.ReadAll(lr.r)
if err != nil {
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
return
}
}))
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
if time.Since(p.internalLogSent) < logSentInterval {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
r, err := p.logReader()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.size == 0 {
w.WriteHeader(http.StatusMovedPermanently)
return
}
req := &controld.LogsRequest{
UID: cdUID,
Data: r.r,
}
mainLog.Load().Debug().Msg("sending log file to ControlD server")
resp := logSentResponse{Size: r.size}
if err := controld.SendLogs(req, cdDev); err != nil {
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
resp.Error = err.Error()
w.WriteHeader(http.StatusInternalServerError)
} else {
mainLog.Load().Debug().Msg("sending log file successfully")
w.WriteHeader(http.StatusOK)
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
p.internalLogSent = time.Now()
}))
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/netip"
"os/exec"
"runtime"
"slices"
"strconv"
@@ -19,6 +20,7 @@ import (
"golang.org/x/sync/errgroup"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/controld"
@@ -41,7 +43,7 @@ const (
var osUpstreamConfig = &ctrld.UpstreamConfig{
Name: "OS resolver",
Type: ctrld.ResolverTypeOS,
Timeout: 2000,
Timeout: 3000,
}
var privateUpstreamConfig = &ctrld.UpstreamConfig{
@@ -50,6 +52,12 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{
Timeout: 2000,
}
var localUpstreamConfig = &ctrld.UpstreamConfig{
Name: "Local resolver",
Type: ctrld.ResolverTypeLocal,
Timeout: 2000,
}
// proxyRequest contains data for proxying a DNS query to upstream.
type proxyRequest struct {
msg *dns.Msg
@@ -76,7 +84,13 @@ type upstreamForResult struct {
srcAddr string
}
func (p *prog) serveDNS(listenerNum string) error {
func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
// Start network monitoring
if err := p.monitorNetworkChanges(mainCtx); err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
// Don't return here as we still want DNS service to run
}
listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
@@ -106,11 +120,18 @@ func (p *prog) serveDNS(listenerNum string) error {
go p.detectLoop(m)
q := m.Question[0]
domain := canonicalName(q.Name)
if domain == selfCheckInternalTestDomain {
switch {
case domain == "":
answer := new(dns.Msg)
answer.SetRcode(m, dns.RcodeFormatError)
_ = w.WriteMsg(answer)
return
case domain == selfCheckInternalTestDomain:
answer := resolveInternalDomainTestQuery(ctx, domain, m)
_ = w.WriteMsg(answer)
return
}
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
p.cache.Purge()
ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain)
@@ -411,23 +432,19 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
leaked := false
// If ctrld is going to leak query to OS resolver, check remote upstream in background,
// so ctrld could be back to normal operation as long as the network is back online.
if len(upstreamConfigs) > 0 && p.leakingQuery.Load() {
for n, uc := range upstreamConfigs {
go p.checkUpstream(upstreams[n], uc)
}
upstreamConfigs = nil
leaked = true
ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
}
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
upstreams = []string{upstreamOS}
}
if p.isAdDomainQuery(req.msg) {
ctrld.Log(ctx, mainLog.Load().Debug(),
"AD domain query detected for %s in domain %s",
req.msg.Question[0].Name, p.adDomain)
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
upstreams = []string{upstreamOS}
}
res := &proxyResponse{}
// LAN/PTR lookup flow:
@@ -438,13 +455,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
// 4. Try remote upstream.
isLanOrPtrQuery := false
if req.ufr.matched {
if leaked {
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v (leaked)", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
} else {
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
}
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
} else {
switch {
case isSrvLookup(req.msg):
upstreams = []string{upstreamOS}
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
ctx = ctrld.LanQueryCtx(ctx)
ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams)
case isPrivatePtrLookup(req.msg):
isLanOrPtrQuery = true
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
@@ -452,7 +470,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
res.clientInfo = true
return res
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs)
ctx = ctrld.LanQueryCtx(ctx)
ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams)
case isLanHostnameQuery(req.msg):
isLanOrPtrQuery = true
@@ -461,7 +480,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
res.clientInfo = true
return res
}
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
upstreams = []string{upstreamOS}
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
ctx = ctrld.LanQueryCtx(ctx)
ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
default:
ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams)
@@ -488,8 +509,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
staleAnswer = answer
}
}
resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
if err != nil {
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver")
@@ -504,43 +525,53 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
}
return dnsResolver.Resolve(resolveCtx, msg)
}
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
}
answer, err := resolve1(n, upstreamConfig, msg)
answer, err := resolve1(upstream, upstreamConfig, msg)
// if we have an answer, we should reset the failure count
// we dont use reset here since we dont want to prevent failure counts from being incremented
if answer != nil {
p.um.mu.Lock()
p.um.failureReq[upstream] = 0
p.um.down[upstream] = false
p.um.mu.Unlock()
return answer
}
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
// increase failure count when there is no answer
// rehardless of what kind of error we get
p.um.increaseFailureCount(upstream)
if err != nil {
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
isNetworkErr := errNetworkError(err)
if isNetworkErr {
p.um.increaseFailureCount(upstreams[n])
if p.um.isDown(upstreams[n]) {
go p.checkUpstream(upstreams[n], upstreamConfig)
}
}
// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
var e net.Error
if errors.As(err, &e) && e.Timeout() {
upstreamConfig.ReBootstrap()
}
return nil
}
return answer
return nil
}
for n, upstreamConfig := range upstreamConfigs {
if upstreamConfig == nil {
continue
}
logger := mainLog.Load().Debug().
Str("upstream", upstreamConfig.String()).
Str("query", req.msg.Question[0].Name).
Bool("is_ad_query", p.isAdDomainQuery(req.msg)).
Bool("is_lan_query", isLanOrPtrQuery)
if p.isLoop(upstreamConfig) {
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint)
ctrld.Log(ctx, logger, "DNS loop detected")
continue
}
if p.um.isDown(upstreams[n]) {
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
continue
}
answer := resolve(n, upstreamConfig, req.msg)
answer := resolve(upstreams[n], upstreamConfig, req.msg)
if answer == nil {
if serveStaleCache && staleAnswer != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
@@ -587,21 +618,49 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
return res
}
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
if cdUID != "" && p.leakOnUpstreamFailure() {
p.leakingQueryMu.Lock()
if !p.leakingQueryWasRun {
p.leakingQueryWasRun = true
go p.performLeakingQuery()
// if we have no healthy upstreams, trigger recovery flow
if p.recoverOnUpstreamFailure() {
if p.um.countHealthy(upstreams) == 0 {
p.recoveryCancelMu.Lock()
if p.recoveryCancel == nil {
var reason RecoveryReason
if upstreams[0] == upstreamOS {
reason = RecoveryReasonOSFailure
} else {
reason = RecoveryReasonRegularFailure
}
mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
go p.handleRecovery(reason)
} else {
mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
}
p.recoveryCancelMu.Unlock()
} else {
mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
}
p.leakingQueryMu.Unlock()
}
// attempt query to OS resolver while as a retry catch all
if upstreams[0] != upstreamOS {
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
if answer != nil {
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
res.answer = answer
res.upstream = osUpstreamConfig.Endpoint
return res
}
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
}
answer := new(dns.Msg)
answer.SetRcode(req.msg, dns.RcodeServerFailure)
res.answer = answer
return res
}
func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
if len(p.localUpstreams) > 0 {
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
tmp = append(tmp, p.localUpstreams...)
@@ -620,6 +679,14 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U
return upstreamConfigs
}
func (p *prog) isAdDomainQuery(msg *dns.Msg) bool {
if p.adDomain == "" {
return false
}
cDomainName := canonicalName(msg.Question[0].Name)
return dns.IsSubDomain(p.adDomain, cDomainName)
}
// canonicalName returns canonical name from FQDN with "." trimmed.
func canonicalName(fqdn string) string {
q := strings.TrimSpace(fqdn)
@@ -916,18 +983,6 @@ func (p *prog) selfUninstallCoolOfPeriod() {
p.selfUninstallMu.Unlock()
}
// performLeakingQuery performs necessary works to leak queries to OS resolver.
func (p *prog) performLeakingQuery() {
mainLog.Load().Warn().Msg("leaking query to OS resolver")
// Signal dns watchers to stop, so changes made below won't be reverted.
p.leakingQuery.Store(true)
p.resetDNS()
ns := ctrld.InitializeOsResolver()
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
p.dnsWg.Wait()
p.setDNS()
}
// forceFetchingAPI sends signal to force syncing API config if run in cd mode,
// and the domain == "cdUID.verify.controld.com"
func (p *prog) forceFetchingAPI(domain string) {
@@ -1056,7 +1111,16 @@ func isLanHostnameQuery(m *dns.Msg) bool {
name := strings.TrimSuffix(q.Name, ".")
return !strings.Contains(name, ".") ||
strings.HasSuffix(name, ".domain") ||
strings.HasSuffix(name, ".lan")
strings.HasSuffix(name, ".lan") ||
strings.HasSuffix(name, ".local")
}
// isSrvLookup reports whether DNS message is a SRV query.
func isSrvLookup(m *dns.Msg) bool {
if m == nil || len(m.Question) == 0 {
return false
}
return m.Question[0].Qtype == dns.TypeSRV
}
// isWanClient reports whether the input is a WAN address.
@@ -1089,3 +1153,406 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
answer.SetReply(m)
return answer
}
// FlushDNSCache flushes the DNS cache on macOS.
func FlushDNSCache() error {
// if not macOS, return
if runtime.GOOS != "darwin" {
return nil
}
// Flush the DNS cache via mDNSResponder.
// This is typically needed on modern macOS systems.
if out, err := exec.Command("killall", "-HUP", "mDNSResponder").CombinedOutput(); err != nil {
return fmt.Errorf("failed to flush mDNSResponder: %w, output: %s", err, string(out))
}
// Optionally, flush the directory services cache.
if out, err := exec.Command("dscacheutil", "-flushcache").CombinedOutput(); err != nil {
return fmt.Errorf("failed to flush dscacheutil: %w, output: %s", err, string(out))
}
return nil
}
// monitorNetworkChanges starts monitoring for network interface changes
func (p *prog) monitorNetworkChanges(ctx context.Context) error {
mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
if err != nil {
return fmt.Errorf("creating network monitor: %w", err)
}
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
// Get map of valid interfaces
validIfaces := validInterfacesMap()
isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New)
mainLog.Load().Debug().
Interface("old_state", delta.Old).
Interface("new_state", delta.New).
Bool("is_major_change", isMajorChange).
Msg("Network change detected")
changed := false
activeInterfaceExists := false
var changeIPs []netip.Prefix
// Check each valid interface for changes
for ifaceName := range validIfaces {
oldIface, oldExists := delta.Old.Interface[ifaceName]
newIface, newExists := delta.New.Interface[ifaceName]
if !newExists {
continue
}
oldIPs := delta.Old.InterfaceIPs[ifaceName]
newIPs := delta.New.InterfaceIPs[ifaceName]
// if a valid interface did not exist in old
// check that its up and has usable IPs
if !oldExists {
// The interface is new (was not present in the old state).
usableNewIPs := filterUsableIPs(newIPs)
if newIface.IsUp() && len(usableNewIPs) > 0 {
changed = true
changeIPs = usableNewIPs
mainLog.Load().Debug().
Str("interface", ifaceName).
Interface("new_ips", usableNewIPs).
Msg("Interface newly appeared (was not present in old state)")
break
}
continue
}
// Filter new IPs to only those that are usable.
usableNewIPs := filterUsableIPs(newIPs)
// Check if interface is up and has usable IPs.
if newIface.IsUp() && len(usableNewIPs) > 0 {
activeInterfaceExists = true
}
// Compare interface states and IPs (interfaceIPsEqual will itself filter the IPs).
if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) {
if newIface.IsUp() && len(usableNewIPs) > 0 {
changed = true
changeIPs = usableNewIPs
mainLog.Load().Debug().
Str("interface", ifaceName).
Interface("old_ips", oldIPs).
Interface("new_ips", usableNewIPs).
Msg("Interface state or IPs changed")
break
}
}
}
if !changed {
mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected")
return
}
if !activeInterfaceExists {
mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization")
return
}
// Get IPs from default route interface in new state
selfIP := defaultRouteIP()
var ipv6 string
if delta.New.DefaultRouteInterface != "" {
mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface])
for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] {
ipAddr, _ := netip.ParsePrefix(ip.String())
addr := ipAddr.Addr()
if selfIP == "" && addr.Is4() {
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
selfIP = addr.String()
}
}
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
ipv6 = addr.String()
}
}
} else {
// If no default route interface is set yet, use the changed IPs
mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs)
for _, ip := range changeIPs {
ipAddr, _ := netip.ParsePrefix(ip.String())
addr := ipAddr.Addr()
if selfIP == "" && addr.Is4() {
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
selfIP = addr.String()
}
}
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
ipv6 = addr.String()
}
}
}
if ip := net.ParseIP(selfIP); ip != nil {
ctrld.SetDefaultLocalIPv4(ip)
if !isMobile() && p.ciTable != nil {
p.ciTable.SetSelfIP(selfIP)
}
}
if ip := net.ParseIP(ipv6); ip != nil {
ctrld.SetDefaultLocalIPv6(ip)
}
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
if p.recoverOnUpstreamFailure() {
p.handleRecovery(RecoveryReasonNetworkChange)
}
})
mon.Start()
mainLog.Load().Debug().Msg("Network monitor started")
return nil
}
// interfaceStatesEqual compares two interface states
func interfaceStatesEqual(a, b *netmon.Interface) bool {
if a == nil || b == nil {
return a == b
}
return a.IsUp() == b.IsUp()
}
// filterUsableIPs is a helper that returns only "usable" IP prefixes,
// filtering out link-local, loopback, multicast, unspecified, broadcast, or CGNAT addresses.
func filterUsableIPs(prefixes []netip.Prefix) []netip.Prefix {
var usable []netip.Prefix
for _, p := range prefixes {
addr := p.Addr()
if addr.IsLinkLocalUnicast() ||
addr.IsLoopback() ||
addr.IsMulticast() ||
addr.IsUnspecified() ||
addr.IsLinkLocalMulticast() ||
(addr.Is4() && addr.String() == "255.255.255.255") ||
tsaddr.CGNATRange().Contains(addr) {
continue
}
usable = append(usable, p)
}
return usable
}
// Modified interfaceIPsEqual compares only the usable (non-link local, non-loopback, etc.) IP addresses.
func interfaceIPsEqual(a, b []netip.Prefix) bool {
aUsable := filterUsableIPs(a)
bUsable := filterUsableIPs(b)
if len(aUsable) != len(bUsable) {
return false
}
aMap := make(map[string]bool)
for _, ip := range aUsable {
aMap[ip.String()] = true
}
for _, ip := range bUsable {
if !aMap[ip.String()] {
return false
}
}
return true
}
// checkUpstreamOnce sends a test query to the specified upstream.
// Returns nil if the upstream responds successfully.
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream)
resolver, err := ctrld.NewResolver(uc)
if err != nil {
mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
return err
}
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
timeout := 1000 * time.Millisecond
if uc.Timeout > 0 {
timeout = time.Millisecond * time.Duration(uc.Timeout)
}
mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
uc.ReBootstrap()
mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
start := time.Now()
_, err = resolver.Resolve(ctx, msg)
duration := time.Since(start)
if err != nil {
mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
} else {
mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
}
return err
}
// handleRecovery performs a unified recovery by removing DNS settings,
// canceling existing recovery checks for network changes, but coalescing duplicate
// upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout),
// and then re-applying the DNS settings.
func (p *prog) handleRecovery(reason RecoveryReason) {
mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings")
// For network changes, cancel any existing recovery check because the network state has changed.
if reason == RecoveryReasonNetworkChange {
p.recoveryCancelMu.Lock()
if p.recoveryCancel != nil {
mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)")
p.recoveryCancel()
p.recoveryCancel = nil
}
p.recoveryCancelMu.Unlock()
} else {
// For upstream failures, if a recovery is already in progress, do nothing new.
p.recoveryCancelMu.Lock()
if p.recoveryCancel != nil {
mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
p.recoveryCancelMu.Unlock()
return
}
p.recoveryCancelMu.Unlock()
}
// Create a new recovery context without a fixed timeout.
p.recoveryCancelMu.Lock()
recoveryCtx, cancel := context.WithCancel(context.Background())
p.recoveryCancel = cancel
p.recoveryCancelMu.Unlock()
// Immediately remove our DNS settings from the interface.
// set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
p.recoveryRunning.Store(true)
p.resetDNS()
// For an OS failure, reinitialize OS resolver nameservers immediately.
if reason == RecoveryReasonOSFailure {
mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
ns := ctrld.InitializeOsResolver(true)
if len(ns) == 0 {
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
} else {
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
}
}
// Build upstream map based on the recovery reason.
upstreams := p.buildRecoveryUpstreams(reason)
// Wait indefinitely until one of the upstreams recovers.
recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
if err != nil {
mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
p.recoveryCancelMu.Lock()
p.recoveryCancel = nil
p.recoveryCancelMu.Unlock()
return
}
mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
// reset the upstream failure count and down state
p.um.reset(recovered)
// For network changes we also reinitialize the OS resolver.
if reason == RecoveryReasonNetworkChange {
ns := ctrld.InitializeOsResolver(true)
if len(ns) == 0 {
mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
} else {
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
}
}
// Apply our DNS settings back and log the interface state.
p.setDNS()
p.logInterfacesState()
// allow watchdogs to put the listener back on the interface if its changed for any reason
p.recoveryRunning.Store(false)
// Clear the recovery cancellation for a clean slate.
p.recoveryCancelMu.Lock()
p.recoveryCancel = nil
p.recoveryCancelMu.Unlock()
}
// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
// It returns the name of the recovered upstream or an error if the check times out.
func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) {
recoveredCh := make(chan string, 1)
var wg sync.WaitGroup
mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))
for name, uc := range upstreams {
wg.Add(1)
go func(name string, uc *ctrld.UpstreamConfig) {
defer wg.Done()
mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
for {
select {
case <-ctx.Done():
mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
return
default:
// checkUpstreamOnce will reset any failure counters on success.
if err := p.checkUpstreamOnce(name, uc); err == nil {
mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
select {
case recoveredCh <- name:
mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name)
default:
mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered")
}
return
}
mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
time.Sleep(checkUpstreamBackoffSleep)
}
}
}(name, uc)
}
var recovered string
select {
case recovered = <-recoveredCh:
case <-ctx.Done():
return "", ctx.Err()
}
wg.Wait()
return recovered, nil
}
// buildRecoveryUpstreams constructs the map of upstream configurations to test.
// For OS failures we supply the manual OS resolver upstream configuration.
// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS).
func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig {
upstreams := make(map[string]*ctrld.UpstreamConfig)
switch reason {
case RecoveryReasonOSFailure:
upstreams[upstreamOS] = osUpstreamConfig
case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure:
// Use all configured upstreams except any OS type.
for k, uc := range p.cfg.Upstream {
if uc.Type != ctrld.ResolverTypeOS {
upstreams[upstreamPrefix+k] = uc
}
}
}
return upstreams
}

View File

@@ -75,6 +75,7 @@ func Test_canonicalName(t *testing.T) {
func Test_prog_upstreamFor(t *testing.T) {
cfg := testhelper.SampleConfig(t)
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
p := &prog{cfg: cfg}
p.um = newUpstreamMonitor(p.cfg)
p.lanLoopGuard = newLoopGuard()
@@ -365,6 +366,9 @@ func Test_isLanHostnameQuery(t *testing.T) {
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
}
for _, tc := range tests {
tc := tc
@@ -414,6 +418,26 @@ func Test_isPrivatePtrLookup(t *testing.T) {
}
}
func Test_isSrvLookup(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
isSrvLookup bool
}{
{"SRV", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isSrvLookup(tc.msg); tc.isSrvLookup != got {
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
}
})
}
}
func Test_isWanClient(t *testing.T) {
tests := []struct {
name string

186
cmd/cli/log_writer.go Normal file
View File

@@ -0,0 +1,186 @@
package cli
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/Control-D-Inc/ctrld"
)
const (
logWriterSize = 1024 * 1024 * 5 // 5 MB
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
logWriterInitialSize = 32 * 1024 // 32 KB
logSentInterval = time.Minute
logStartEndMarker = "\n\n=== INIT_END ===\n\n"
logLogEndMarker = "\n\n=== LOG_END ===\n\n"
logWarnEndMarker = "\n\n=== WARN_END ===\n\n"
)
type logViewResponse struct {
Data string `json:"data"`
}
type logSentResponse struct {
Size int64 `json:"size"`
Error string `json:"error"`
}
type logReader struct {
r io.ReadCloser
size int64
}
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
type logWriter struct {
mu sync.Mutex
buf bytes.Buffer
size int
}
// newLogWriter creates an internal log writer.
func newLogWriter() *logWriter {
return newLogWriterWithSize(logWriterSize)
}
// newSmallLogWriter creates an internal log writer with small buffer size.
func newSmallLogWriter() *logWriter {
return newLogWriterWithSize(logWriterSmallSize)
}
// newLogWriterWithSize creates an internal log writer with a given buffer size.
func newLogWriterWithSize(size int) *logWriter {
lw := &logWriter{size: size}
return lw
}
func (lw *logWriter) Write(p []byte) (int, error) {
lw.mu.Lock()
defer lw.mu.Unlock()
// If writing p causes overflows, discard old data.
if lw.buf.Len()+len(p) > lw.size {
buf := lw.buf.Bytes()
buf = buf[:logWriterInitialSize]
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
buf = buf[:idx]
}
lw.buf.Reset()
lw.buf.Write(buf)
lw.buf.WriteString(logStartEndMarker) // indicate that the log was truncated.
}
// If p is bigger than buffer size, truncate p by half until its size is smaller.
for len(p)+lw.buf.Len() > lw.size {
p = p[len(p)/2:]
}
return lw.buf.Write(p)
}
// initInternalLogging performs internal logging if there's no log enabled.
func (p *prog) initInternalLogging(writers []io.Writer) {
if !p.needInternalLogging() {
return
}
p.initInternalLogWriterOnce.Do(func() {
mainLog.Load().Notice().Msg("internal logging enabled")
p.internalLogWriter = newLogWriter()
p.internalLogSent = time.Now().Add(-logSentInterval)
p.internalWarnLogWriter = newSmallLogWriter()
})
p.mu.Lock()
lw := p.internalLogWriter
wlw := p.internalWarnLogWriter
p.mu.Unlock()
// If ctrld was run without explicit verbose level,
// run the internal logging at debug level, so we could
// have enough information for troubleshooting.
if verbose == 0 {
for i := range writers {
w := &zerolog.FilteredLevelWriter{
Writer: zerolog.LevelWriterAdapter{Writer: writers[i]},
Level: zerolog.NoticeLevel,
}
writers[i] = w
}
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
writers = append(writers, lw)
writers = append(writers, &zerolog.FilteredLevelWriter{
Writer: zerolog.LevelWriterAdapter{Writer: wlw},
Level: zerolog.WarnLevel,
})
multi := zerolog.MultiLevelWriter(writers...)
l := mainLog.Load().Output(multi).With().Logger()
mainLog.Store(&l)
ctrld.ProxyLogger.Store(&l)
}
// needInternalLogging reports whether prog needs to run internal logging.
func (p *prog) needInternalLogging() bool {
// Do not run in non-cd mode.
if cdUID == "" {
return false
}
// Do not run if there's already log file.
if p.cfg.Service.LogPath != "" {
return false
}
return true
}
func (p *prog) logReader() (*logReader, error) {
if p.needInternalLogging() {
p.mu.Lock()
lw := p.internalLogWriter
wlw := p.internalWarnLogWriter
p.mu.Unlock()
if lw == nil {
return nil, errors.New("nil internal log writer")
}
if wlw == nil {
return nil, errors.New("nil internal warn log writer")
}
// Normal log content.
lw.mu.Lock()
lwReader := bytes.NewReader(lw.buf.Bytes())
lwSize := lw.buf.Len()
lw.mu.Unlock()
// Warn log content.
wlw.mu.Lock()
wlwReader := bytes.NewReader(wlw.buf.Bytes())
wlwSize := wlw.buf.Len()
wlw.mu.Unlock()
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logLogEndMarker)), wlwReader)
lr := &logReader{r: io.NopCloser(reader)}
lr.size = int64(lwSize + wlwSize)
if lr.size == 0 {
return nil, errors.New("internal log is empty")
}
return lr, nil
}
if p.cfg.Service.LogPath == "" {
return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil
}
f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath))
if err != nil {
return nil, err
}
lr := &logReader{r: f}
if st, err := f.Stat(); err == nil {
lr.size = st.Size()
} else {
return nil, fmt.Errorf("f.Stat: %w", err)
}
if lr.size == 0 {
return nil, errors.New("log file is empty")
}
return lr, nil
}

View File

@@ -0,0 +1,49 @@
package cli
import (
"strings"
"sync"
"testing"
)
func Test_logWriter_Write(t *testing.T) {
size := 64 * 1024
lw := &logWriter{size: size}
lw.buf.Grow(lw.size)
data := strings.Repeat("A", size)
lw.Write([]byte(data))
if lw.buf.String() != data {
t.Fatalf("unexpected buf content: %v", lw.buf.String())
}
newData := "B"
halfData := strings.Repeat("A", len(data)/2) + logStartEndMarker
lw.Write([]byte(newData))
if lw.buf.String() != halfData+newData {
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
}
bigData := strings.Repeat("B", 256*1024)
expected := halfData + strings.Repeat("B", 16*1024)
lw.Write([]byte(bigData))
if lw.buf.String() != expected {
t.Fatalf("unexpected big buf content: %v", lw.buf.String())
}
}
func Test_logWriter_ConcurrentWrite(t *testing.T) {
size := 64 * 1024
lw := &logWriter{size: size}
n := 10
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
lw.Write([]byte(strings.Repeat("A", i)))
}()
}
wg.Wait()
if lw.buf.Len() > lw.size {
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
}
}

View File

@@ -101,9 +101,23 @@ func initConsoleLogging() {
}
// initLogging initializes global logging setup.
func initLogging() {
func initLogging() []io.Writer {
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
initLoggingWithBackup(true)
return initLoggingWithBackup(true)
}
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
// to be used for all interactive commands.
//
// Current log file config will also be ignored.
func initInteractiveLogging() {
old := cfg.Service.LogPath
cfg.Service.LogPath = ""
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
initLoggingWithBackup(false)
cfg.Service.LogPath = old
l := zerolog.New(io.Discard)
ctrld.ProxyLogger.Store(&l)
}
// initLoggingWithBackup initializes log setup base on current config.
@@ -112,8 +126,8 @@ func initLogging() {
// This is only used in runCmd for special handling in case of logging config
// change in cd mode. Without special reason, the caller should use initLogging
// wrapper instead of calling this function directly.
func initLoggingWithBackup(doBackup bool) {
writers := []io.Writer{io.Discard}
func initLoggingWithBackup(doBackup bool) []io.Writer {
var writers []io.Writer
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
// Create parent directory if necessary.
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
@@ -151,21 +165,22 @@ func initLoggingWithBackup(doBackup bool) {
switch {
case silent:
zerolog.SetGlobalLevel(zerolog.NoLevel)
return
return writers
case verbose == 1:
logLevel = "info"
case verbose > 1:
logLevel = "debug"
}
if logLevel == "" {
return
return writers
}
level, err := zerolog.ParseLevel(logLevel)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not set log level")
return
return writers
}
zerolog.SetGlobalLevel(level)
return writers
}
func initCache() {

View File

@@ -9,17 +9,18 @@ import (
"strings"
)
func patchNetIfaceName(iface *net.Interface) error {
func patchNetIfaceName(iface *net.Interface) (bool, error) {
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
if err != nil {
return err
return false, err
}
patched := false
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
patched = true
iface.Name = name
mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface")
}
return nil
return patched, nil
}
func networkServiceName(ifaceName string, r io.Reader) string {

52
cmd/cli/net_linux.go Normal file
View File

@@ -0,0 +1,52 @@
package cli
import (
"net"
"net/netip"
"os"
"strings"
"tailscale.com/net/netmon"
)
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
// validInterface reports whether the *net.Interface is a valid one.
// Only non-virtual interfaces are considered valid.
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
_, ok := validIfacesMap[iface.Name]
return ok
}
// validInterfacesMap returns a set containing non virtual interfaces.
func validInterfacesMap() map[string]struct{} {
m := make(map[string]struct{})
vis := virtualInterfaces()
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
if _, existed := vis[i.Name]; existed {
return
}
m[i.Name] = struct{}{}
})
// Fallback to default route interface if found nothing.
if len(m) == 0 {
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
return m
}
m[defaultRoute.InterfaceName] = struct{}{}
}
return m
}
// virtualInterfaces returns a map of virtual interfaces on current machine.
func virtualInterfaces() map[string]struct{} {
s := make(map[string]struct{})
entries, _ := os.ReadDir("/sys/devices/virtual/net")
for _, entry := range entries {
if entry.IsDir() {
s[strings.TrimSpace(entry.Name())] = struct{}{}
}
}
return s
}

View File

@@ -1,11 +1,22 @@
//go:build !darwin && !windows
//go:build !darwin && !windows && !linux
package cli
import "net"
import (
"net"
func patchNetIfaceName(iface *net.Interface) error { return nil }
"tailscale.com/net/netmon"
)
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
func validInterfacesMap() map[string]struct{} { return nil }
// validInterfacesMap returns a set containing only default route interfaces.
func validInterfacesMap() map[string]struct{} {
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
return nil
}
return map[string]struct{}{defaultRoute.InterfaceName: {}}
}

View File

@@ -1,14 +1,20 @@
package cli
import (
"bufio"
"bytes"
"io"
"log"
"net"
"strings"
"os"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
)
func patchNetIfaceName(iface *net.Interface) error {
return nil
func patchNetIfaceName(iface *net.Interface) (bool, error) {
return true, nil
}
// validInterface reports whether the *net.Interface is a valid one.
@@ -20,15 +26,68 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
// validInterfacesMap returns a set of all physical interfaces.
func validInterfacesMap() map[string]struct{} {
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
if err != nil {
return nil
}
m := make(map[string]struct{})
scanner := bufio.NewScanner(bytes.NewReader(out))
for scanner.Scan() {
ifaceName := strings.TrimSpace(scanner.Text())
for _, ifaceName := range validInterfaces() {
m[ifaceName] = struct{}{}
}
return m
}
// validInterfaces returns a list of all physical interfaces.
func validInterfaces() []string {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("MSFT_NetAdapter")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
return nil
}
var adapters []string
for _, i := range instances {
adapter, err := netadapter.NewNetworkAdapter(i)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
continue
}
name, err := adapter.GetPropertyName()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
continue
}
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
//
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
// if this is a physical adapter or FALSE if this is not a physical adapter."
physical, err := adapter.GetPropertyConnectorPresent()
if err != nil {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
continue
}
if !physical {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
continue
}
// Check if it's a hardware interface. Checking only for connector present is not enough
// because some interfaces are not physical but have a connector.
hardware, err := adapter.GetPropertyHardwareInterface()
if err != nil {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
continue
}
if !hardware {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
continue
}
adapters = append(adapters, name)
}
return adapters
}

View File

@@ -0,0 +1,42 @@
package cli
import (
"bufio"
"bytes"
"slices"
"strings"
"testing"
"time"
)
func Test_validInterfaces(t *testing.T) {
verbose = 3
initConsoleLogging()
start := time.Now()
ifaces := validInterfaces()
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
start = time.Now()
ifacesPowershell := validInterfacesPowershell()
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
slices.Sort(ifaces)
slices.Sort(ifacesPowershell)
if !slices.Equal(ifaces, ifacesPowershell) {
t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces)
}
}
func validInterfacesPowershell() []string {
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
if err != nil {
return nil
}
var res []string
scanner := bufio.NewScanner(bytes.NewReader(out))
for scanner.Scan() {
ifaceName := strings.TrimSpace(scanner.Text())
res = append(res, ifaceName)
}
return res
}

View File

@@ -70,11 +70,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
// TODO(cuonglm): use system API
func resetDNS(iface *net.Interface) error {
if ns := savedStaticNameservers(iface); len(ns) > 0 {
if err := setDNS(iface, ns); err == nil {
return nil
}
}
cmd := "networksetup"
args := []string{"-setdnsservers", iface.Name, "empty"}
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
@@ -83,6 +78,15 @@ func resetDNS(iface *net.Interface) error {
return nil
}
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
if ns := savedStaticNameservers(iface); len(ns) > 0 {
err = setDNS(iface, ns)
}
return err
}
func currentDNS(_ *net.Interface) []string {
return resolvconffile.NameServers("")
}

View File

@@ -76,6 +76,12 @@ func resetDNS(iface *net.Interface) error {
return nil
}
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
return err
}
func currentDNS(_ *net.Interface) []string {
return resolvconffile.NameServers("")
}

View File

@@ -195,6 +195,12 @@ func resetDNS(iface *net.Interface) (err error) {
})
}
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
return err
}
func currentDNS(iface *net.Interface) []string {
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
if ns := fn(iface.Name); len(ns) > 0 {

View File

@@ -1,23 +1,27 @@
package cli
import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"slices"
"strconv"
"strings"
"sync"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
const (
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
)
var (
@@ -30,14 +34,6 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
return setDNS(iface, nameservers)
}
func setDnsPowershellCmd(iface *net.Interface, nameservers []string) string {
nss := make([]string, 0, len(nameservers))
for _, ns := range nameservers {
nss = append(nss, strconv.Quote(ns))
}
return fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ServerAddresses (%s)", iface.Index, strings.Join(nss, ","))
}
// setDNS sets the dns server for the provided network interface
func setDNS(iface *net.Interface, nameservers []string) error {
if len(nameservers) == 0 {
@@ -46,7 +42,7 @@ func setDNS(iface *net.Interface, nameservers []string) error {
setDNSOnce.Do(func() {
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
// Configuring the Dns server to forward queries to ctrld instead.
if windowsHasLocalDnsServerRunning() {
if hasLocalDnsServerRunning() {
file := absHomeDir(windowsForwardersFilename)
oldForwardersContent, _ := os.ReadFile(file)
hasLocalIPv6Listener := needLocalIPv6Listener()
@@ -65,9 +61,36 @@ func setDNS(iface *net.Interface, nameservers []string) error {
}
}
})
out, err := powershell(setDnsPowershellCmd(iface, nameservers))
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return fmt.Errorf("%w: %s", err, string(out))
return fmt.Errorf("setDNS: %w", err)
}
var (
serversV4 []netip.Addr
serversV6 []netip.Addr
)
for _, ns := range nameservers {
if addr, err := netip.ParseAddr(ns); err == nil {
if addr.Is4() {
serversV4 = append(serversV4, addr)
} else {
serversV6 = append(serversV6, addr)
}
}
}
if len(serversV4) == 0 && len(serversV6) == 0 {
return errors.New("invalid DNS nameservers")
}
if len(serversV4) > 0 {
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
return fmt.Errorf("could not set DNS ipv4: %w", err)
}
}
if len(serversV6) > 0 {
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
return fmt.Errorf("could not set DNS ipv6: %w", err)
}
}
return nil
}
@@ -81,7 +104,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
func resetDNS(iface *net.Interface) error {
resetDNSOnce.Do(func() {
// See corresponding comment in setDNS.
if windowsHasLocalDnsServerRunning() {
if hasLocalDnsServerRunning() {
file := absHomeDir(windowsForwardersFilename)
content, err := os.ReadFile(file)
if err != nil {
@@ -96,14 +119,23 @@ func resetDNS(iface *net.Interface) error {
}
})
// Restoring DHCP settings.
cmd := fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ResetServerAddresses", iface.Index)
out, err := powershell(cmd)
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return fmt.Errorf("%w: %s", err, string(out))
return fmt.Errorf("resetDNS: %w", err)
}
// Restoring DHCP settings.
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
return fmt.Errorf("could not reset DNS ipv4: %w", err)
}
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
return fmt.Errorf("could not reset DNS ipv6: %w", err)
}
return nil
}
// If there's static DNS saved, restoring it.
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
if nss := savedStaticNameservers(iface); len(nss) > 0 {
v4ns := make([]string, 0, 2)
v6ns := make([]string, 0, 2)
@@ -120,12 +152,14 @@ func resetDNS(iface *net.Interface) error {
continue
}
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
if err := setDNS(iface, ns); err != nil {
err = setDNS(iface, ns)
if err != nil {
return err
}
}
}
return nil
return err
}
func currentDNS(iface *net.Interface) []string {
@@ -150,25 +184,31 @@ func currentDNS(iface *net.Interface) []string {
func currentStaticDNS(iface *net.Interface) ([]string, error) {
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return nil, err
return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err)
}
guid, err := luid.GUID()
if err != nil {
return nil, err
return nil, fmt.Errorf("luid.GUID: %w", err)
}
var ns []string
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
interfaceKeyPath := path + guid.String()
found := false
interfaceKeyPath := path + guid.String()
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
if err != nil {
return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err)
}
for _, key := range []string{"NameServer", "ProfileNameServer"} {
if found {
continue
}
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
out, err := powershell(cmd)
if err == nil && len(out) > 0 {
value, _, err := k.GetStringValue(key)
if err != nil && !errors.Is(err, registry.ErrNotExist) {
return nil, fmt.Errorf("%s: %w", key, err)
}
if len(value) > 0 {
found = true
for _, e := range strings.Split(string(out), ",") {
for _, e := range strings.Split(value, ",") {
ns = append(ns, strings.TrimRight(e, "\x00"))
}
}
@@ -216,3 +256,9 @@ func removeDnsServerForwarders(nameservers []string) error {
}
return nil
}
// powershell runs the given powershell command.
func powershell(cmd string) ([]byte, error) {
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
return bytes.TrimSpace(out), err
}

View File

@@ -0,0 +1,68 @@
package cli
import (
"fmt"
"net"
"slices"
"strings"
"testing"
"time"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
func Test_currentStaticDNS(t *testing.T) {
iface, err := net.InterfaceByName(defaultIfaceName())
if err != nil {
t.Fatal(err)
}
start := time.Now()
staticDns, err := currentStaticDNS(iface)
if err != nil {
t.Fatal(err)
}
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
start = time.Now()
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
if err != nil {
t.Fatal(err)
}
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
slices.Sort(staticDns)
slices.Sort(staticDnsPowershell)
if !slices.Equal(staticDns, staticDnsPowershell) {
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
}
}
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return nil, err
}
guid, err := luid.GUID()
if err != nil {
return nil, err
}
var ns []string
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
interfaceKeyPath := path + guid.String()
found := false
for _, key := range []string{"NameServer", "ProfileNameServer"} {
if found {
continue
}
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
out, err := powershell(cmd)
if err == nil && len(out) > 0 {
found = true
for _, e := range strings.Split(string(out), ",") {
ns = append(ns, strings.TrimRight(e, "\x00"))
}
}
}
}
return ns, nil
}

View File

@@ -45,6 +45,18 @@ const (
upstreamOS = upstreamPrefix + "os"
upstreamPrivate = upstreamPrefix + "private"
dnsWatchdogDefaultInterval = 20 * time.Second
ctrldServiceName = "ctrld"
)
// RecoveryReason provides context for why we are waiting for recovery.
// recovery involves removing the listener IP from the interface and
// waiting for the upstreams to work before returning
type RecoveryReason int
const (
RecoveryReasonNetworkChange RecoveryReason = iota
RecoveryReasonRegularFailure
RecoveryReasonOSFailure
)
// ControlSocketName returns name for control unix socket.
@@ -61,8 +73,9 @@ var logf = func(format string, args ...any) {
}
var svcConfig = &service.Config{
Name: "ctrld",
Name: ctrldServiceName,
DisplayName: "Control-D Helper Service",
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
Option: service.KeyValue{},
}
@@ -84,21 +97,29 @@ type prog struct {
dnsWg sync.WaitGroup
dnsWatcherClosedOnce sync.Once
dnsWatcherStopCh chan struct{}
rc *controld.ResolverConfig
cfg *ctrld.Config
localUpstreams []string
ptrNameservers []string
appCallback *AppCallback
cache dnscache.Cacher
cacheFlushDomainsMap map[string]struct{}
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
ptrLoopGuard *loopGuard
lanLoopGuard *loopGuard
metricsQueryStats atomic.Bool
queryFromSelfMap sync.Map
cfg *ctrld.Config
localUpstreams []string
ptrNameservers []string
appCallback *AppCallback
cache dnscache.Cacher
cacheFlushDomainsMap map[string]struct{}
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
ptrLoopGuard *loopGuard
lanLoopGuard *loopGuard
metricsQueryStats atomic.Bool
queryFromSelfMap sync.Map
initInternalLogWriterOnce sync.Once
internalLogWriter *logWriter
internalWarnLogWriter *logWriter
internalLogSent time.Time
runningIface string
requiredMultiNICsConfig bool
adDomain string
selfUninstallMu sync.Mutex
refusedQueryCount int
@@ -108,9 +129,9 @@ type prog struct {
loopMu sync.Mutex
loop map[string]bool
leakingQueryMu sync.Mutex
leakingQueryWasRun bool
leakingQuery atomic.Bool
recoveryCancelMu sync.Mutex
recoveryCancel context.CancelFunc
recoveryRunning atomic.Bool
started chan struct{}
onStartedDone chan struct{}
@@ -162,11 +183,13 @@ func (p *prog) runWait() {
if newCfg == nil {
newCfg = &ctrld.Config{}
confFile := v.ConfigFileUsed()
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
ctrld.InitConfig(v, "ctrld")
if configPath != "" {
v.SetConfigFile(configPath)
confFile = configPath
}
v.SetConfigFile(confFile)
if err := v.ReadInConfig(); err != nil {
logger.Err(err).Msg("could not read new config")
waitOldRunDone()
@@ -178,10 +201,14 @@ func (p *prog) runWait() {
continue
}
if cdUID != "" {
if err := processCDFlags(newCfg); err != nil {
if rc, err := processCDFlags(newCfg); err != nil {
logger.Err(err).Msg("could not fetch ControlD config")
waitOldRunDone()
continue
} else {
p.mu.Lock()
p.rc = rc
p.mu.Unlock()
}
}
}
@@ -233,6 +260,11 @@ func (p *prog) runWait() {
}
func (p *prog) preRun() {
if iface == "auto" {
iface = defaultIfaceName()
p.requiredMultiNICsConfig = requiredMultiNICsConfig()
}
p.runningIface = iface
if runtime.GOOS == "darwin" {
p.onStopped = append(p.onStopped, func() {
if !service.Interactive() {
@@ -245,11 +277,12 @@ func (p *prog) preRun() {
func (p *prog) postRun() {
if !service.Interactive() {
p.resetDNS()
ns := ctrld.InitializeOsResolver()
ns := ctrld.InitializeOsResolver(false)
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
p.setDNS()
p.csSetDnsDone <- struct{}{}
close(p.csSetDnsDone)
p.logInterfacesState()
}
}
@@ -288,7 +321,24 @@ func (p *prog) apiConfigReload() {
cdDeactivationPin.Store(defaultDeactivationPin)
}
if resolverConfig.Ctrld.CustomConfig == "" {
p.mu.Lock()
rc := p.rc
p.rc = resolverConfig
p.mu.Unlock()
noCustomConfig := resolverConfig.Ctrld.CustomConfig == ""
noExcludeListChanged := true
if rc != nil {
slices.Sort(rc.Exclude)
slices.Sort(resolverConfig.Exclude)
noExcludeListChanged = slices.Equal(rc.Exclude, resolverConfig.Exclude)
}
if noCustomConfig && noExcludeListChanged {
return
}
if noCustomConfig && !noExcludeListChanged {
logger.Debug().Msg("exclude list changes detected, reloading...")
p.apiReloadCh <- nil
return
}
@@ -401,6 +451,10 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
}
}
}
if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() {
mainLog.Load().Debug().Msgf("active directory domain: %s", domain)
p.adDomain = domain
}
var wg sync.WaitGroup
wg.Add(len(p.cfg.Listener))
@@ -429,12 +483,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
}
}
p.setupUpstream(p.cfg)
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
p.setupClientInfoDiscover(defaultRouteIP())
}
// context for managing spawn goroutines.
@@ -446,8 +495,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
wg.Add(1)
go func() {
defer wg.Done()
p.ciTable.Init()
p.ciTable.RefreshLoop(ctx)
p.runClientInfoDiscover(ctx)
}()
go p.watchLinkState(ctx)
}
@@ -463,9 +511,10 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
if err := p.serveDNS(listenerNum); err != nil {
if err := p.serveDNS(ctx, listenerNum); err != nil {
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
}
mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
}(listenerNum)
}
go func() {
@@ -511,16 +560,33 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
if !reload {
// Stop writing log to unix socket.
consoleWriter.Out = os.Stdout
initLoggingWithBackup(false)
logWriters := initLoggingWithBackup(false)
if p.logConn != nil {
_ = p.logConn.Close()
}
go p.apiConfigReload()
p.postRun()
p.initInternalLogging(logWriters)
}
wg.Wait()
}
// setupClientInfoDiscover performs necessary works for running client info discover.
func (p *prog) setupClientInfoDiscover(selfIP string) {
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers)
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
p.ciTable.AddLeaseFile(leaseFile, format)
}
}
// runClientInfoDiscover runs the client info discover.
func (p *prog) runClientInfoDiscover(ctx context.Context) {
p.ciTable.Init()
p.ciTable.RefreshLoop(ctx)
}
// metricsEnabled reports whether prometheus exporter is enabled/disabled.
func (p *prog) metricsEnabled() bool {
return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != ""
@@ -529,7 +595,9 @@ func (p *prog) metricsEnabled() bool {
func (p *prog) Stop(s service.Service) error {
p.stopDnsWatchers()
mainLog.Load().Debug().Msg("dns watchers stopped")
mainLog.Load().Info().Msg("Service stopped")
defer func() {
mainLog.Load().Info().Msg("Service stopped")
}()
close(p.stopCh)
if err := p.deAllocateIP(); err != nil {
mainLog.Load().Error().Err(err).Msg("de-allocate ip failed")
@@ -579,27 +647,42 @@ func (p *prog) setDNS() {
if cfg.Listener == nil {
return
}
if iface == "" {
if p.runningIface == "" {
return
}
runningIface := iface
// allIfaces tracks whether we should set DNS for all physical interfaces.
allIfaces := false
if runningIface == "auto" {
runningIface = defaultIfaceName()
// If runningIface is "auto", it means user does not specify "--iface" flag.
// In this case, ctrld has to set DNS for all physical interfaces, so
// thing will still work when user switch from one to the other.
allIfaces = requiredMultiNICsConfig()
}
allIfaces := p.requiredMultiNICsConfig
lc := cfg.FirstListener()
if lc == nil {
return
}
logger := mainLog.Load().With().Str("iface", runningIface).Logger()
netIface, err := netInterface(runningIface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
const maxDNSRetryAttempts = 3
const retryDelay = 1 * time.Second
var netIface *net.Interface
var err error
for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ {
netIface, err = netInterface(p.runningIface)
if err == nil {
break
}
if attempt < maxDNSRetryAttempts {
// Try to find a different working interface
newIface := findWorkingInterface(p.runningIface)
if newIface != p.runningIface {
p.runningIface = newIface
logger = mainLog.Load().With().Str("iface", p.runningIface).Logger()
logger.Info().Msg("switched to new interface")
continue
}
logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...")
time.Sleep(retryDelay)
continue
}
logger.Error().Err(err).Msg("could not get interface after all attempts")
return
}
if err := setupNetworkManager(); err != nil {
@@ -686,12 +769,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
if !requiredMultiNICsConfig() {
return
}
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
logger.Debug().Msg("start DNS settings watchdog")
mainLog.Load().Debug().Msg("start DNS settings watchdog")
ns := nameservers
slices.Sort(ns)
ticker := time.NewTicker(p.dnsWatchdogDuration())
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
for {
select {
case <-p.dnsWatcherStopCh:
@@ -700,7 +784,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
mainLog.Load().Debug().Msg("stop dns watchdog")
return
case <-ticker.C:
if p.leakingQuery.Load() {
if p.recoveryRunning.Load() {
return
}
if dnsChanged(iface, ns) {
@@ -726,22 +810,19 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
}
func (p *prog) resetDNS() {
if iface == "" {
if p.runningIface == "" {
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
return
}
runningIface := iface
allIfaces := false
if runningIface == "auto" {
runningIface = defaultIfaceName()
// See corresponding comments in (*prog).setDNS function.
allIfaces = requiredMultiNICsConfig()
}
logger := mainLog.Load().With().Str("iface", runningIface).Logger()
netIface, err := netInterface(runningIface)
// See corresponding comments in (*prog).setDNS function.
allIfaces := p.requiredMultiNICsConfig
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
netIface, err := netInterface(p.runningIface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
return
}
if err := restoreNetworkManager(); err != nil {
logger.Error().Err(err).Msg("could not restore NetworkManager")
return
@@ -757,16 +838,157 @@ func (p *prog) resetDNS() {
}
}
// leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams.
func (p *prog) leakOnUpstreamFailure() bool {
if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
return *ptr
func (p *prog) logInterfacesState() {
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
addrs, err := i.Addrs()
if err != nil {
mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses")
}
nss, err := currentStaticDNS(i)
if err != nil {
mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS")
}
if len(nss) == 0 {
nss = currentDNS(i)
}
mainLog.Load().Debug().
Any("addrs", addrs).
Strs("nameservers", nss).
Int("index", i.Index).
Msgf("interface state: %s", i.Name)
return nil
})
}
// findWorkingInterface looks for a network interface with a valid IP configuration
func findWorkingInterface(currentIface string) string {
// Helper to check if IP is valid (not link-local)
isValidIP := func(ip net.IP) bool {
return ip != nil &&
!ip.IsLinkLocalUnicast() &&
!ip.IsLinkLocalMulticast() &&
!ip.IsLoopback() &&
!ip.IsUnspecified()
}
// Default is false on routers, since this leaking is only useful for devices that move between networks.
if router.Name() != "" {
// Helper to check if interface has valid IP configuration
hasValidIPConfig := func(iface *net.Interface) bool {
if iface == nil || iface.Flags&net.FlagUp == 0 {
return false
}
addrs, err := iface.Addrs()
if err != nil {
mainLog.Load().Debug().
Str("interface", iface.Name).
Err(err).
Msg("failed to get interface addresses")
return false
}
for _, addr := range addrs {
// Check for IP network
if ipNet, ok := addr.(*net.IPNet); ok {
if isValidIP(ipNet.IP) {
return true
}
}
}
return false
}
return true
// Get default route interface
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
mainLog.Load().Debug().
Err(err).
Msg("failed to get default route")
} else {
mainLog.Load().Debug().
Str("default_route_iface", defaultRoute.InterfaceName).
Msg("found default route")
}
// Get all interfaces
ifaces, err := net.Interfaces()
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to list network interfaces")
return currentIface // Return current interface as fallback
}
var firstWorkingIface string
var currentIfaceValid bool
// Single pass through interfaces
for _, iface := range ifaces {
// Must be physical (has MAC address)
if len(iface.HardwareAddr) == 0 {
continue
}
// Skip interfaces that are:
// - Loopback
// - Not up
// - Point-to-point (like VPN tunnels)
if iface.Flags&net.FlagLoopback != 0 ||
iface.Flags&net.FlagUp == 0 ||
iface.Flags&net.FlagPointToPoint != 0 {
continue
}
if !hasValidIPConfig(&iface) {
continue
}
// Found working physical interface
if err == nil && defaultRoute.InterfaceName == iface.Name {
// Found interface with default route - use it immediately
mainLog.Load().Info().
Str("old_iface", currentIface).
Str("new_iface", iface.Name).
Msg("switching to interface with default route")
return iface.Name
}
// Keep track of first working interface as fallback
if firstWorkingIface == "" {
firstWorkingIface = iface.Name
}
// Check if this is our current interface
if iface.Name == currentIface {
currentIfaceValid = true
}
}
// Return interfaces in order of preference:
// 1. Current interface if it's still valid
if currentIfaceValid {
mainLog.Load().Debug().
Str("interface", currentIface).
Msg("keeping current interface")
return currentIface
}
// 2. First working interface found
if firstWorkingIface != "" {
mainLog.Load().Info().
Str("old_iface", currentIface).
Str("new_iface", firstWorkingIface).
Msg("switching to first working physical interface")
return firstWorkingIface
}
// 3. Fall back to current interface if nothing else works
mainLog.Load().Warn().
Str("current_iface", currentIface).
Msg("no working physical interface found, keeping current")
return currentIface
}
// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure.
func (p *prog) recoverOnUpstreamFailure() bool {
// Default is false on routers, since this recovery flow is only useful for devices that move between networks.
return router.Name() == ""
}
func randomLocalIP() string {
@@ -947,7 +1169,7 @@ func canBeLocalUpstream(addr string) bool {
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
validIfacesMap := validInterfacesMap()
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
// Skip loopback/virtual interface.
// Skip loopback/virtual/down interface.
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
return
}
@@ -956,9 +1178,12 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.
return
}
netIface := i.Interface
if err := patchNetIfaceName(netIface); err != nil {
if patched, err := patchNetIfaceName(netIface); err != nil {
mainLog.Load().Debug().Err(err).Msg("failed to patch net interface name")
return
} else if !patched {
// The interface is not functional, skipping.
return
}
// Skip excluded interface.
if netIface.Name == excludeIfaceName {
@@ -1025,7 +1250,16 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
func savedStaticNameservers(iface *net.Interface) []string {
file := savedStaticDnsSettingsFilePath(iface)
if data, _ := os.ReadFile(file); len(data) > 0 {
return strings.Split(string(data), ",")
saveValues := strings.Split(string(data), ",")
returnValues := []string{}
// check each one, if its in loopback range, remove it
for _, v := range saveValues {
if net.ParseIP(v).IsLoopback() {
continue
}
returnValues = append(returnValues, v)
}
return returnValues
}
return nil
}
@@ -1044,7 +1278,7 @@ func dnsChanged(iface *net.Interface, nameservers []string) bool {
// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then.
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
var uer *controld.UtilityErrorResponse
var uer *controld.ErrorResponse
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
p.stopDnsWatchers()

View File

@@ -3,11 +3,38 @@ package cli
import (
"net"
"net/netip"
"os"
"path/filepath"
"strings"
"time"
"github.com/fsnotify/fsnotify"
)
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
// Returns nil if no nameservers are found.
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
content, err := os.ReadFile(path)
if err != nil {
return nil, err
}
// Parse the file for "nameserver" lines
var currentNS []string
lines := strings.Split(string(content), "\n")
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "nameserver") {
parts := strings.Fields(trimmed)
if len(parts) >= 2 {
currentNS = append(currentNS, parts[1])
}
}
}
return currentNS, nil
}
// watchResolvConf watches any changes to /etc/resolv.conf file,
// and reverting to the original config set by ctrld.
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
@@ -40,7 +67,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
return
case event, ok := <-watcher.Events:
if p.leakingQuery.Load() {
if p.recoveryRunning.Load() {
return
}
if !ok {
@@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
continue
}
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
if err := watcher.Remove(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
continue
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
// Convert expected nameservers to strings for comparison
expectedNS := make([]string, len(ns))
for i, addr := range ns {
expectedNS[i] = addr.String()
}
if err := setDnsFn(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
var foundNS []string
var err error
maxRetries := 1
for retry := 0; retry < maxRetries; retry++ {
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
break
}
// If we found nameservers, break out of retry loop
if len(foundNS) > 0 {
break
}
// Only retry if we found no nameservers
if retry < maxRetries-1 {
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
select {
case <-p.stopCh:
return
case <-p.dnsWatcherStopCh:
return
case <-time.After(2 * time.Second):
continue
}
} else {
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
}
}
if err := watcher.Add(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
return
// If we found nameservers, check if they match what we expect
if len(foundNS) > 0 {
// Check if the nameservers match exactly what we expect
matches := len(foundNS) == len(expectedNS)
if matches {
for i := range foundNS {
if foundNS[i] != expectedNS[i] {
matches = false
break
}
}
}
mainLog.Load().Debug().
Strs("found", foundNS).
Strs("expected", expectedNS).
Bool("matches", matches).
Msg("checking nameservers")
// Only revert if the nameservers don't match
if !matches {
if err := watcher.Remove(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
continue
}
if err := setDnsFn(iface, ns); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
}
if err := watcher.Add(watchDir); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
return
}
}
}
}
case err, ok := <-watcher.Errors:

View File

@@ -156,17 +156,18 @@ func (l *launchd) Status() (service.Status, error) {
type task struct {
f func() error
abortOnError bool
Name string
}
func doTasks(tasks []task) bool {
var prevErr error
for _, task := range tasks {
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
if err := task.f(); err != nil {
if task.abortOnError {
mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error())
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
return false
}
prevErr = err
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
}
}
return true

View File

@@ -13,3 +13,8 @@ func hasElevatedPrivilege() (bool, error) {
func openLogFile(path string, flags int) (*os.File, error) {
return os.OpenFile(path, flags, os.FileMode(0o600))
}
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
func hasLocalDnsServerRunning() bool { return false }
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }

View File

@@ -2,9 +2,14 @@ package cli
import (
"os"
"runtime"
"strings"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc/mgr"
)
func hasElevatedPrivilege() (bool, error) {
@@ -28,6 +33,67 @@ func hasElevatedPrivilege() (bool, error) {
return token.IsMember(sid)
}
// ConfigureWindowsServiceFailureActions checks if the given service
// has the correct failure actions configured, and updates them if not.
func ConfigureWindowsServiceFailureActions(serviceName string) error {
if runtime.GOOS != "windows" {
return nil // no-op on non-Windows
}
m, err := mgr.Connect()
if err != nil {
return err
}
defer m.Disconnect()
s, err := m.OpenService(serviceName)
if err != nil {
return err
}
defer s.Close()
// 1. Retrieve the current config
cfg, err := s.Config()
if err != nil {
return err
}
// 2. Update the Description
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
// 3. Apply the updated config
if err := s.UpdateConfig(cfg); err != nil {
return err
}
// Then proceed with existing actions, e.g. setting failure actions
actions := []mgr.RecoveryAction{
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
}
// Set the recovery actions (3 restarts, reset period = 120).
err = s.SetRecoveryActions(actions, 120)
if err != nil {
return err
}
// Ensure that failure actions are NOT triggered on user-initiated stops.
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
if err := windows.ChangeServiceConfig2(
s.Handle,
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
(*byte)(unsafe.Pointer(&failureActionsFlag)),
); err != nil {
return err
}
return nil
}
func openLogFile(path string, mode int) (*os.File, error) {
if len(path) == 0 {
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
@@ -79,3 +145,23 @@ func openLogFile(path string, mode int) (*os.File, error) {
return os.NewFile(uintptr(handle), path), nil
}
const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
func hasLocalDnsServerRunning() bool {
h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
if e != nil {
return false
}
p := windows.ProcessEntry32{Size: processEntrySize}
for {
e := windows.Process32Next(h, &p)
if e != nil {
return false
}
if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" {
return true
}
}
}

View File

@@ -0,0 +1,25 @@
package cli
import (
"testing"
"time"
)
func Test_hasLocalDnsServerRunning(t *testing.T) {
start := time.Now()
hasDns := hasLocalDnsServerRunning()
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
start = time.Now()
hasDnsPowershell := hasLocalDnsServerRunningPowershell()
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
if hasDns != hasDnsPowershell {
t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns)
}
}
func hasLocalDnsServerRunningPowershell() bool {
_, err := powershell("Get-Process -Name DNS")
return err == nil
}

View File

@@ -1,18 +1,15 @@
package cli
import (
"context"
"sync"
"time"
"github.com/miekg/dns"
"github.com/Control-D-Inc/ctrld"
)
const (
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
maxFailureRequest = 100
maxFailureRequest = 50
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
checkUpstreamBackoffSleep = 2 * time.Second
)
@@ -21,18 +18,24 @@ const (
type upstreamMonitor struct {
cfg *ctrld.Config
mu sync.Mutex
mu sync.RWMutex
checking map[string]bool
down map[string]bool
failureReq map[string]uint64
recovered map[string]bool
// failureTimerActive tracks if a timer is already running for a given upstream.
failureTimerActive map[string]bool
}
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
um := &upstreamMonitor{
cfg: cfg,
checking: make(map[string]bool),
down: make(map[string]bool),
failureReq: make(map[string]uint64),
cfg: cfg,
checking: make(map[string]bool),
down: make(map[string]bool),
failureReq: make(map[string]uint64),
recovered: make(map[string]bool),
failureTimerActive: make(map[string]bool),
}
for n := range cfg.Upstream {
upstream := upstreamPrefix + n
@@ -42,14 +45,47 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
return um
}
// increaseFailureCount increase failed queries count for an upstream by 1.
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down
// within 10 seconds if failures persist, without spawning duplicate goroutines.
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
um.mu.Lock()
defer um.mu.Unlock()
if um.recovered[upstream] {
mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream)
return
}
um.failureReq[upstream] += 1
failedCount := um.failureReq[upstream]
um.down[upstream] = failedCount >= maxFailureRequest
// Log the updated failure count.
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
// If this is the first failure and no timer is running, start a 10-second timer.
if failedCount == 1 && !um.failureTimerActive[upstream] {
um.failureTimerActive[upstream] = true
go func(upstream string) {
time.Sleep(10 * time.Second)
um.mu.Lock()
defer um.mu.Unlock()
// If no success occurred during the 10-second window (i.e. counter remains > 0)
// and the upstream is not in a recovered state, mark it as down.
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
um.down[upstream] = true
mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
}
// Reset the timer flag so that a new timer can be spawned if needed.
um.failureTimerActive[upstream] = false
}(upstream)
}
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
if failedCount >= maxFailureRequest {
um.down[upstream] = true
mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
}
}
// isDown reports whether the given upstream is being marked as down.
@@ -63,56 +99,28 @@ func (um *upstreamMonitor) isDown(upstream string) bool {
// reset marks an upstream as up and set failed queries counter to zero.
func (um *upstreamMonitor) reset(upstream string) {
um.mu.Lock()
defer um.mu.Unlock()
um.failureReq[upstream] = 0
um.down[upstream] = false
}
// checkUpstream checks the given upstream status, periodically sending query to upstream
// until successfully. An upstream status/counter will be reset once it becomes reachable.
func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
p.um.mu.Lock()
isChecking := p.um.checking[upstream]
if isChecking {
p.um.mu.Unlock()
return
}
p.um.checking[upstream] = true
p.um.mu.Unlock()
defer func() {
p.um.mu.Lock()
p.um.checking[upstream] = false
p.um.mu.Unlock()
um.recovered[upstream] = true
um.mu.Unlock()
go func() {
// debounce the recovery to avoid incrementing failure counts already in flight
time.Sleep(1 * time.Second)
um.mu.Lock()
um.recovered[upstream] = false
um.mu.Unlock()
}()
resolver, err := ctrld.NewResolver(uc)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
return
}
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
check := func() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
uc.ReBootstrap()
_, err := resolver.Resolve(ctx, msg)
return err
}
for {
if err := check(); err == nil {
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
p.um.reset(upstream)
if p.leakingQuery.CompareAndSwap(true, false) {
p.leakingQueryMu.Lock()
p.leakingQueryWasRun = false
p.leakingQueryMu.Unlock()
mainLog.Load().Warn().Msg("stop leaking query")
}
return
}
time.Sleep(checkUpstreamBackoffSleep)
}
}
// countHealthy returns the number of upstreams in the provided map that are considered healthy.
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
var count int
um.mu.RLock()
for _, upstream := range upstreams {
if !um.down[upstream] {
count++
}
}
um.mu.RUnlock()
return count
}

View File

@@ -205,7 +205,7 @@ type ServiceConfig struct {
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp kea-dhcp4"`
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
@@ -384,7 +384,7 @@ func (uc *UpstreamConfig) IsDiscoverable() bool {
return *uc.Discoverable
}
switch uc.Type {
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate:
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate, ResolverTypeLocal:
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
}
@@ -458,7 +458,7 @@ func (uc *UpstreamConfig) ReBootstrap() {
}
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
if uc.rebootstrap.CompareAndSwap(false, true) {
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
}
return true, nil
})
@@ -886,3 +886,12 @@ func upstreamUID() string {
return hex.EncodeToString(b)
}
}
// String returns a string representation of the UpstreamConfig for logging.
func (uc *UpstreamConfig) String() string {
if uc == nil {
return "<nil>"
}
return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}",
uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack)
}

View File

@@ -2,12 +2,16 @@ package ctrld
import (
"net/url"
"os"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
l := zerolog.New(os.Stdout)
ProxyLogger.Store(&l)
uc := &UpstreamConfig{
Name: "test",
Type: ResolverTypeDOH,

View File

@@ -34,7 +34,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
}
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
rt := &http3.RoundTripper{}
rt := &http3.Transport{}
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
_, port, _ := net.SplitHostPort(addr)
@@ -64,7 +64,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
return conn, err
}
runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) {
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
rt.CloseIdleConnections()
})
return rt

View File

@@ -111,6 +111,7 @@ func TestConfigValidation(t *testing.T) {
{"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false},
{"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false},
{"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true},
{"kea dhcp4 format", configWithDhcp4KeaFormat(t), false},
}
for _, tc := range tests {
@@ -307,6 +308,12 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
return cfg
}
func configWithDhcp4KeaFormat(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Service.DHCPLeaseFileFormat = "kea-dhcp4"
return cfg
}
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Upstream["0"].Endpoint = "/1.1.1.1"

23
go.mod
View File

@@ -9,6 +9,7 @@ require (
github.com/ameshkov/dnsstamps v1.0.3
github.com/coreos/go-systemd/v22 v22.5.0
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf
github.com/docker/go-units v0.5.0
github.com/frankban/quicktest v1.14.6
github.com/fsnotify/fsnotify v1.7.0
github.com/go-playground/validator/v10 v10.11.1
@@ -20,6 +21,7 @@ require (
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
github.com/kardianos/service v1.2.1
github.com/mdlayher/ndp v1.0.1
github.com/microsoft/wmi v0.24.5
github.com/miekg/dns v1.1.58
github.com/minio/selfupdate v0.6.0
github.com/olekukonko/tablewriter v0.0.5
@@ -27,16 +29,16 @@ require (
github.com/prometheus/client_golang v1.19.1
github.com/prometheus/client_model v0.5.0
github.com/prometheus/prom2json v1.3.3
github.com/quic-go/quic-go v0.42.0
github.com/quic-go/quic-go v0.48.2
github.com/rs/zerolog v1.28.0
github.com/spf13/cobra v1.8.1
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.16.0
github.com/stretchr/testify v1.9.0
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/net v0.27.0
golang.org/x/sync v0.7.0
golang.org/x/sys v0.22.0
golang.org/x/net v0.33.0
golang.org/x/sync v0.10.0
golang.org/x/sys v0.29.0
golang.zx2c4.com/wireguard/windows v0.5.3
tailscale.com v1.74.0
)
@@ -49,12 +51,14 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
@@ -72,10 +76,11 @@ require (
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/quic-go/qpack v0.4.0 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/spf13/afero v1.9.5 // indirect
@@ -87,10 +92,10 @@ require (
go.uber.org/mock v0.4.0 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect
golang.org/x/crypto v0.31.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/mod v0.19.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/tools v0.23.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
@@ -99,4 +104,4 @@ require (
replace github.com/mr-karan/doggo => github.com/Windscribe/doggo v0.0.0-20220919152748-2c118fc391f8
replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be
replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c

54
go.sum
View File

@@ -42,8 +42,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik=
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w=
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
@@ -74,6 +74,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa h1:h8TfIT1xc8FWbwwpmHn1J5i43Y0uZP97GqasGCzSRJk=
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
@@ -91,6 +93,8 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
@@ -162,6 +166,8 @@ github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlG
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
@@ -207,11 +213,10 @@ github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@@ -227,6 +232,8 @@ github.com/mdlayher/packet v1.1.2 h1:3Up1NG6LZrsgDVn6X4L9Ge/iyRyxFEFD9o6Pr3Q1nQY
github.com/mdlayher/packet v1.1.2/go.mod h1:GEu1+n9sG5VtiRE4SydOmX5GTwyyYlteZiFU+x0kew4=
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
github.com/microsoft/wmi v0.24.5 h1:NT+WqhjKbEcg3ldmDsRMarWgHGkpeW+gMopSCfON0kM=
github.com/microsoft/wmi v0.24.5/go.mod h1:1zbdSF0A+5OwTUII5p3hN7/K6KF2m3o27pSG6Y51VU8=
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU=
@@ -245,6 +252,7 @@ github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFu
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -261,10 +269,10 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/prometheus/prom2json v1.3.3 h1:IYfSMiZ7sSOfliBoo89PcufjWO4eAR0gznGcETyaUgo=
github.com/prometheus/prom2json v1.3.3/go.mod h1:Pv4yIPktEkK7btWsrUTWDDDrnpUrAELaOCj+oFwlgmc=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
@@ -274,7 +282,7 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM=
github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ=
@@ -338,8 +346,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -350,8 +358,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA=
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -409,8 +417,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -430,8 +438,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -472,16 +480,16 @@ golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -492,8 +500,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@@ -77,6 +77,7 @@ type Table struct {
hostnameResolvers []HostnameResolver
refreshers []refresher
initOnce sync.Once
stopOnce sync.Once
refreshInterval int
dhcp *dhcp
@@ -90,7 +91,9 @@ type Table struct {
vni *virtualNetworkIface
svcCfg ctrld.ServiceConfig
quitCh chan struct{}
stopCh chan struct{}
selfIP string
selfIPLock sync.RWMutex
cdUID string
ptrNameservers []string
}
@@ -103,6 +106,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
return &Table{
svcCfg: cfg.Service,
quitCh: make(chan struct{}),
stopCh: make(chan struct{}),
selfIP: selfIP,
cdUID: cdUID,
ptrNameservers: ns,
@@ -120,24 +124,59 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) {
// RefreshLoop runs all the refresher to update new client info data.
func (t *Table) RefreshLoop(ctx context.Context) {
timer := time.NewTicker(time.Second * time.Duration(t.refreshInterval))
defer timer.Stop()
defer func() {
timer.Stop()
close(t.quitCh)
}()
for {
select {
case <-timer.C:
for _, r := range t.refreshers {
_ = r.refresh()
}
t.Refresh()
case <-t.stopCh:
return
case <-ctx.Done():
close(t.quitCh)
return
}
}
}
// Init initializes all client info discovers.
func (t *Table) Init() {
t.initOnce.Do(t.init)
}
// Refresh forces all discovers to retrieve new data.
func (t *Table) Refresh() {
for _, r := range t.refreshers {
_ = r.refresh()
}
}
// Stop stops all the discovers.
// It blocks until all the discovers done.
func (t *Table) Stop() {
t.stopOnce.Do(func() {
close(t.stopCh)
})
<-t.quitCh
}
// SelfIP returns the selfIP value of the Table in a thread-safe manner.
func (t *Table) SelfIP() string {
t.selfIPLock.RLock()
defer t.selfIPLock.RUnlock()
return t.selfIP
}
// SetSelfIP sets the selfIP value of the Table in a thread-safe manner.
func (t *Table) SetSelfIP(ip string) {
t.selfIPLock.Lock()
defer t.selfIPLock.Unlock()
t.selfIP = ip
t.dhcp.selfIP = t.selfIP
t.dhcp.addSelf()
}
func (t *Table) init() {
// Custom client ID presents, use it as the only source.
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
@@ -381,9 +420,7 @@ func (t *Table) lookupHostnameAll(ip, mac string) []*hostnameEntry {
// ListClients returns list of clients discovered by ctrld.
func (t *Table) ListClients() []*Client {
for _, r := range t.refreshers {
_ = r.refresh()
}
t.Refresh()
ipMap := make(map[string]*Client)
il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni}
for _, ir := range il {

View File

@@ -25,8 +25,12 @@ import (
const (
apiDomainCom = "api.controld.com"
apiDomainDev = "api.controld.dev"
resolverDataURLCom = "https://api.controld.com/utility"
resolverDataURLDev = "https://api.controld.dev/utility"
apiURLCom = "https://api.controld.com"
apiURLDev = "https://api.controld.dev"
resolverDataURLCom = apiURLCom + "/utility"
resolverDataURLDev = apiURLDev + "/utility"
logURLCom = apiURLCom + "/logs"
logURLDev = apiURLDev + "/logs"
InvalidConfigCode = 40402
)
@@ -49,14 +53,14 @@ type utilityResponse struct {
} `json:"body"`
}
type UtilityErrorResponse struct {
type ErrorResponse struct {
ErrorField struct {
Message string `json:"message"`
Code int `json:"code"`
} `json:"error"`
}
func (u UtilityErrorResponse) Error() string {
func (u ErrorResponse) Error() string {
return u.ErrorField.Message
}
@@ -71,6 +75,12 @@ type UtilityOrgRequest struct {
Hostname string `json:"hostname"`
}
// LogsRequest contains request data for sending runtime logs to API.
type LogsRequest struct {
UID string `json:"uid"`
Data io.ReadCloser `json:"-"`
}
// FetchResolverConfig fetch Control D config for given uid.
func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) {
uid, clientID := ParseRawUID(rawUID)
@@ -123,6 +133,81 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
}
req.URL.RawQuery = q.Encode()
req.Header.Add("Content-Type", "application/json")
transport := apiTransport(cdDev)
client := http.Client{
Timeout: 10 * time.Second,
Transport: transport,
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err)
}
defer resp.Body.Close()
d := json.NewDecoder(resp.Body)
if resp.StatusCode != http.StatusOK {
errResp := &ErrorResponse{}
if err := d.Decode(errResp); err != nil {
return nil, err
}
return nil, errResp
}
ur := &utilityResponse{}
if err := d.Decode(ur); err != nil {
return nil, err
}
return &ur.Body.Resolver, nil
}
// SendLogs sends runtime log to ControlD API.
func SendLogs(lr *LogsRequest, cdDev bool) error {
defer lr.Data.Close()
apiUrl := logURLCom
if cdDev {
apiUrl = logURLDev
}
req, err := http.NewRequest("POST", apiUrl, lr.Data)
if err != nil {
return fmt.Errorf("http.NewRequest: %w", err)
}
q := req.URL.Query()
q.Set("uid", lr.UID)
req.URL.RawQuery = q.Encode()
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
transport := apiTransport(cdDev)
client := http.Client{
Timeout: 300 * time.Second,
Transport: transport,
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("SendLogs client.Do: %w", err)
}
defer resp.Body.Close()
d := json.NewDecoder(resp.Body)
if resp.StatusCode != http.StatusOK {
errResp := &ErrorResponse{}
if err := d.Decode(errResp); err != nil {
return err
}
return errResp
}
_, _ = io.Copy(io.Discard, resp.Body)
return nil
}
// ParseRawUID parse the input raw UID, returning real UID and ClientID.
// The raw UID can have 2 forms:
//
// - <uid>
// - <uid>/<client_id>
func ParseRawUID(rawUID string) (string, string) {
uid, clientID, _ := strings.Cut(rawUID, "/")
return uid, clientID
}
// apiTransport returns an HTTP transport for connecting to ControlD API endpoint.
func apiTransport(cdDev bool) *http.Transport {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
apiDomain := apiDomainCom
@@ -143,41 +228,8 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
d := &ctrldnet.ParallelDialer{}
return d.DialContext(ctx, network, addrs)
}
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
}
client := http.Client{
Timeout: 10 * time.Second,
Transport: transport,
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("client.Do: %w", err)
}
defer resp.Body.Close()
d := json.NewDecoder(resp.Body)
if resp.StatusCode != http.StatusOK {
errResp := &UtilityErrorResponse{}
if err := d.Decode(errResp); err != nil {
return nil, err
}
return nil, errResp
}
ur := &utilityResponse{}
if err := d.Decode(ur); err != nil {
return nil, err
}
return &ur.Body.Resolver, nil
}
// ParseRawUID parse the input raw UID, returning real UID and ClientID.
// The raw UID can have 2 forms:
//
// - <uid>
// - <uid>/<client_id>
func ParseRawUID(rawUID string) (string, string) {
uid, clientID, _ := strings.Cut(rawUID, "/")
return uid, clientID
return transport
}

View File

@@ -1,19 +1,16 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
//go:build dragonfly || freebsd || netbsd || openbsd
package ctrld
import (
"net"
"os/exec"
"runtime"
"strings"
"syscall"
"golang.org/x/net/route"
)
func dnsFns() []dnsFn {
return []dnsFn{dnsFromRIB, dnsFromIPConfig}
return []dnsFn{dnsFromRIB}
}
func dnsFromRIB() []string {
@@ -49,18 +46,6 @@ func dnsFromRIB() []string {
return dns
}
func dnsFromIPConfig() []string {
if runtime.GOOS != "darwin" {
return nil
}
cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server")
out, _ := cmd.Output()
if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil {
return []string{ip.String()}
}
return nil
}
func toNetIP(addr route.Addr) net.IP {
switch t := addr.(type) {
case *route.Inet4Addr:

217
nameservers_darwin.go Normal file
View File

@@ -0,0 +1,217 @@
//go:build darwin
package ctrld
import (
"bufio"
"bytes"
"context"
"fmt"
"net"
"os/exec"
"regexp"
"slices"
"strings"
"time"
"tailscale.com/net/netmon"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
func dnsFns() []dnsFn {
return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers}
}
// dnsFromResolvConf reads nameservers from /etc/resolv.conf
func dnsFromResolvConf() []string {
const (
maxRetries = 10
retryInterval = 100 * time.Millisecond
)
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
var dns []string
for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 {
time.Sleep(retryInterval)
}
nss := resolvconffile.NameServers("")
var localDNS []string
seen := make(map[string]bool)
for _, ns := range nss {
if ip := net.ParseIP(ns); ip != nil {
// skip loopback IPs
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
ipStr := v.String()
if ip.String() == ipStr {
continue
}
}
if !seen[ip.String()] {
seen[ip.String()] = true
localDNS = append(localDNS, ip.String())
}
}
}
// If we successfully read the file and found nameservers, return them
if len(localDNS) > 0 {
return localDNS
}
}
return dns
}
func getDNSFromScutil() []string {
logger := *ProxyLogger.Load()
const (
maxRetries = 10
retryInterval = 100 * time.Millisecond
)
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
var nameservers []string
for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 {
time.Sleep(retryInterval)
}
cmd := exec.Command("scutil", "--dns")
output, err := cmd.Output()
if err != nil {
Log(context.Background(), logger.Error(), "failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err)
continue
}
var localDNS []string
seen := make(map[string]bool)
scanner := bufio.NewScanner(bytes.NewReader(output))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "nameserver[") {
parts := strings.Split(line, ":")
if len(parts) == 2 {
ns := strings.TrimSpace(parts[1])
if ip := net.ParseIP(ns); ip != nil {
// skip loopback IPs
isLocal := false
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
ipStr := v.String()
if ip.String() == ipStr {
isLocal = true
break
}
}
if !isLocal && !seen[ip.String()] {
seen[ip.String()] = true
localDNS = append(localDNS, ip.String())
}
}
}
}
}
if err := scanner.Err(); err != nil {
Log(context.Background(), logger.Error(), "error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err)
continue
}
// If we successfully read the output and found nameservers, return them
if len(localDNS) > 0 {
return localDNS
}
}
return nameservers
}
func getDHCPNameservers(iface string) ([]string, error) {
// Run the ipconfig command for the given interface.
cmd := exec.Command("ipconfig", "getpacket", iface)
output, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("error running ipconfig: %v", err)
}
// Look for a line like:
// domain_name_servers = 192.168.1.1 8.8.8.8;
re := regexp.MustCompile(`domain_name_servers\s*=\s*(.*);`)
matches := re.FindStringSubmatch(string(output))
if len(matches) < 2 {
return nil, fmt.Errorf("no DHCP nameservers found")
}
// Split the nameservers by whitespace.
nameservers := strings.Fields(matches[1])
return nameservers, nil
}
func getAllDHCPNameservers() []string {
interfaces, err := net.Interfaces()
if err != nil {
return nil
}
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
var allNameservers []string
seen := make(map[string]bool)
for _, iface := range interfaces {
// Skip interfaces that are:
// - down
// - loopback
// - not physical (virtual)
// - point-to-point (like VPN interfaces)
// - without MAC address (non-physical)
if iface.Flags&net.FlagUp == 0 ||
iface.Flags&net.FlagLoopback != 0 ||
iface.Flags&net.FlagPointToPoint != 0 ||
(iface.Flags&net.FlagBroadcast == 0 &&
iface.Flags&net.FlagMulticast == 0) ||
len(iface.HardwareAddr) == 0 ||
strings.HasPrefix(iface.Name, "utun") ||
strings.HasPrefix(iface.Name, "llw") ||
strings.HasPrefix(iface.Name, "awdl") {
continue
}
// Verify it's a valid MAC address (should be 6 bytes for IEEE 802 MAC-48)
if len(iface.HardwareAddr) != 6 {
continue
}
nameservers, err := getDHCPNameservers(iface.Name)
if err != nil {
continue
}
// Add unique nameservers to the result, skipping local IPs
for _, ns := range nameservers {
if ip := net.ParseIP(ns); ip != nil {
// skip loopback and local IPs
isLocal := false
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
if ip.String() == v.String() {
isLocal = true
break
}
}
if !isLocal && !seen[ns] {
seen[ns] = true
allNameservers = append(allNameservers, ns)
}
}
}
}
return allNameservers
}

View File

@@ -1,44 +1,473 @@
package ctrld
import (
"context"
"fmt"
"io"
"log"
"net"
"os"
"strings"
"syscall"
"time"
"unsafe"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
"github.com/rs/zerolog"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
const (
maxDNSAdapterRetries = 5
retryDelayDNSAdapter = 1 * time.Second
defaultDNSAdapterTimeout = 10 * time.Second
minDNSServers = 1 // Minimum number of DNS servers we want to find
NetSetupUnknown uint32 = 0
NetSetupWorkgroup uint32 = 1
NetSetupDomain uint32 = 2
NetSetupCloudDomain uint32 = 3
DS_FORCE_REDISCOVERY = 0x00000001
DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010
DS_BACKGROUND_ONLY = 0x00000100
DS_IP_REQUIRED = 0x00000200
DS_IS_DNS_NAME = 0x00020000
DS_RETURN_DNS_NAME = 0x40000000
)
type DomainControllerInfo struct {
DomainControllerName *uint16
DomainControllerAddress *uint16
DomainControllerAddressType uint32
DomainGuid windows.GUID
DomainName *uint16
DnsForestName *uint16
Flags uint32
DcSiteName *uint16
ClientSiteName *uint16
}
func dnsFns() []dnsFn {
return []dnsFn{dnsFromAdapter}
}
func dnsFromAdapter() []string {
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix)
if err != nil {
return nil
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
defer cancel()
var ns []string
var err error
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
for i := 0; i < maxDNSAdapterRetries; i++ {
if ctx.Err() != nil {
Log(context.Background(), logger.Debug(),
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
return nil
}
ns, err = getDNSServers(ctx)
if err == nil && len(ns) >= minDNSServers {
if i > 0 {
Log(context.Background(), logger.Debug(),
"Successfully got DNS servers after %d attempts, found %d servers",
i+1, len(ns))
}
return ns
}
// if osResolver is not initialized, this is likely a command line run
// and ctrld is already on the interface, abort retries
if or == nil {
return ns
}
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get DNS servers, attempt %d: %v", i+1, err)
} else {
Log(context.Background(), logger.Debug(),
"Got insufficient DNS servers, retrying, found %d servers", len(ns))
}
select {
case <-ctx.Done():
return nil
case <-time.After(retryDelayDNSAdapter):
}
}
Log(context.Background(), logger.Debug(),
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
return ns
}
func getDNSServers(ctx context.Context) ([]string, error) {
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
// Check context before making the call
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Get DNS servers from adapters (existing method)
flags := winipcfg.GAAFlagIncludeGateways |
winipcfg.GAAFlagIncludePrefix
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
if err != nil {
return nil, fmt.Errorf("getting adapters: %w", err)
}
Log(context.Background(), logger.Debug(),
"Found network adapters, count=%d", len(aas))
// Try to get domain controller info if domain-joined
var dcServers []string
isDomain := checkDomainJoined()
if isDomain {
domainName, err := getLocalADDomain()
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get local AD domain: %v", err)
} else {
// Load netapi32.dll
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
dsDcName := netapi32.NewProc("DsGetDcNameW")
var info *DomainControllerInfo
flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME)
domainUTF16, err := windows.UTF16PtrFromString(domainName)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to convert domain name to UTF16: %v", err)
} else {
Log(context.Background(), logger.Debug(),
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
// Call DsGetDcNameW with domain name
ret, _, err := dsDcName.Call(
0, // ComputerName - can be NULL
uintptr(unsafe.Pointer(domainUTF16)), // DomainName
0, // DomainGuid - not needed
0, // SiteName - not needed
uintptr(flags), // Flags
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
if ret != 0 {
switch ret {
case 1355: // ERROR_NO_SUCH_DOMAIN
Log(context.Background(), logger.Debug(),
"Domain not found: %s (%d)", domainName, ret)
case 1311: // ERROR_NO_LOGON_SERVERS
Log(context.Background(), logger.Debug(),
"No logon servers available for domain: %s (%d)", domainName, ret)
case 1004: // ERROR_DC_NOT_FOUND
Log(context.Background(), logger.Debug(),
"Domain controller not found for domain: %s (%d)", domainName, ret)
case 1722: // RPC_S_SERVER_UNAVAILABLE
Log(context.Background(), logger.Debug(),
"RPC server unavailable for domain: %s (%d)", domainName, ret)
default:
Log(context.Background(), logger.Debug(),
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
}
} else if info != nil {
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
if info.DomainControllerAddress != nil {
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
Log(context.Background(), logger.Debug(),
"Found domain controller address: %s", dcAddr)
if ip := net.ParseIP(dcAddr); ip != nil {
dcServers = append(dcServers, ip.String())
Log(context.Background(), logger.Debug(),
"Added domain controller DNS servers: %v", dcServers)
}
} else {
Log(context.Background(), logger.Debug(),
"No domain controller address found")
}
}
}
}
}
// Continue with existing adapter DNS collection
ns := make([]string, 0, len(aas)*2)
seen := make(map[string]bool)
addressMap := make(map[string]struct{})
// Collect all local IPs
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
Log(context.Background(), logger.Debug(),
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
continue
}
// Skip if software loopback or other non-physical types
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
Log(context.Background(), logger.Debug(),
"Skipping %s (software loopback)", aa.FriendlyName())
continue
}
Log(context.Background(), logger.Debug(),
"Processing adapter %s", aa.FriendlyName())
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
addressMap[a.Address.IP().String()] = struct{}{}
ip := a.Address.IP().String()
addressMap[ip] = struct{}{}
Log(context.Background(), logger.Debug(),
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
}
}
validInterfacesMap := validInterfaces()
// Collect DNS servers
for _, aa := range aas {
if aa.OperStatus != winipcfg.IfOperStatusUp {
continue
}
// Skip if software loopback or other non-physical types
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
Log(context.Background(), logger.Debug(),
"Skipping %s (software loopback)", aa.FriendlyName())
continue
}
// if not in the validInterfacesMap, skip
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
Log(context.Background(), logger.Debug(),
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
continue
}
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
ip := dns.Address.IP()
if ip == nil || ip.IsLoopback() || seen[ip.String()] {
if ip == nil {
Log(context.Background(), logger.Debug(),
"Skipping nil IP from adapter %s", aa.FriendlyName())
continue
}
if _, ok := addressMap[ip.String()]; ok {
ipStr := ip.String()
l := logger.Debug().
Str("ip", ipStr).
Str("adapter", aa.FriendlyName())
if ip.IsLoopback() {
l.Msg("Skipping loopback IP")
continue
}
seen[ip.String()] = true
ns = append(ns, ip.String())
if seen[ipStr] {
l.Msg("Skipping duplicate IP")
continue
}
if _, ok := addressMap[ipStr]; ok {
l.Msg("Skipping local interface IP")
continue
}
seen[ipStr] = true
ns = append(ns, ipStr)
l.Msg("Added DNS server")
}
}
return ns
// Add DC servers if they're not already in the list
for _, dcServer := range dcServers {
if !seen[dcServer] {
seen[dcServer] = true
ns = append(ns, dcServer)
Log(context.Background(), logger.Debug(),
"Added additional domain controller DNS server: %s", dcServer)
}
}
if len(ns) == 0 {
return nil, fmt.Errorf("no valid DNS servers found")
}
Log(context.Background(), logger.Debug(),
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
len(ns), ns, len(dcServers))
return ns, nil
}
func nameserversFromResolvconf() []string {
return nil
}
// checkDomainJoined checks if the machine is joined to an Active Directory domain
// Returns whether it's domain joined and the domain name if available
func checkDomainJoined() bool {
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
var domain *uint16
var status uint32
err := windows.NetGetJoinInformation(nil, &domain, &status)
if err != nil {
Log(context.Background(), logger.Debug(),
"Failed to get domain join status: %v", err)
return false
}
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
domainName := windows.UTF16PtrToString(domain)
Log(context.Background(), logger.Debug(),
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)",
domainName, status)
// Consider domain or cloud domain as domain-joined
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
Log(context.Background(), logger.Debug(),
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
status,
status == NetSetupDomain,
status == NetSetupCloudDomain,
isDomain)
return isDomain
}
// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*)
// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call.
func getLocalADDomain() (string, error) {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
// 1) Check environment variable
envDomain := os.Getenv("USERDNSDOMAIN")
if envDomain != "" {
return strings.TrimSpace(envDomain), nil
}
// 2) Query WMI via the microsoft/wmi library
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("Win32_ComputerSystem")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
return "", fmt.Errorf("WMI query failed: %v", err)
}
// If no results, return an error
if len(instances) == 0 {
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
}
// We only care about the first row
domainVal, err := instances[0].GetProperty("Domain")
if err != nil {
return "", fmt.Errorf("machine does not appear to have a domain set: %v", err)
}
domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal))
if domainName == "" {
return "", fmt.Errorf("machine does not appear to have a domain set")
}
return domainName, nil
}
// validInterfaces returns a list of all physical interfaces.
// this is a duplicate of what is in net_windows.go, we should
// clean this up so there is only one version
func validInterfaces() map[string]struct{} {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("MSFT_NetAdapter")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
Log(context.Background(), logger.Warn(),
"failed to get wmi network adapter: %v", err)
return nil
}
var adapters []string
for _, i := range instances {
adapter, err := netadapter.NewNetworkAdapter(i)
if err != nil {
Log(context.Background(), logger.Warn(),
"failed to get network adapter: %v", err)
continue
}
name, err := adapter.GetPropertyName()
if err != nil {
Log(context.Background(), logger.Warn(),
"failed to get interface name: %v", err)
continue
}
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
//
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
// if this is a physical adapter or FALSE if this is not a physical adapter."
physical, err := adapter.GetPropertyConnectorPresent()
if err != nil {
Log(context.Background(), logger.Debug(),
"failed to get network adapter connector present property: %v", err)
continue
}
if !physical {
Log(context.Background(), logger.Debug(),
"skipping non-physical adapter: %s", name)
continue
}
// Check if it's a hardware interface. Checking only for connector present is not enough
// because some interfaces are not physical but have a connector.
hardware, err := adapter.GetPropertyHardwareInterface()
if err != nil {
Log(context.Background(), logger.Debug(),
"failed to get network adapter hardware interface property: %v", err)
continue
}
if !hardware {
Log(context.Background(), logger.Debug(),
"skipping non-hardware interface: %s", name)
continue
}
adapters = append(adapters, name)
}
m := make(map[string]struct{})
for _, ifaceName := range adapters {
m[ifaceName] = struct{}{}
}
return m
}

View File

@@ -4,15 +4,17 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"github.com/rs/zerolog"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
)
@@ -30,8 +32,10 @@ const (
ResolverTypeOS = "os"
// ResolverTypeLegacy specifies legacy resolver.
ResolverTypeLegacy = "legacy"
// ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only.
// ResolverTypePrivate is like ResolverTypeOS, but use for private resolver only.
ResolverTypePrivate = "private"
// ResolverTypeLocal is like ResolverTypeOS, but use for local resolver only.
ResolverTypeLocal = "local"
// ResolverTypeSDNS specifies resolver with information encoded using DNS Stamps.
// See: https://dnscrypt.info/stamps-specifications/
ResolverTypeSDNS = "sdns"
@@ -44,8 +48,30 @@ const (
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
// or is the Resolver used for ResolverTypeOS.
var or = newResolverWithNameserver(defaultNameservers())
var localResolver = newLocalResolver()
var (
resolverMutex sync.Mutex
or *osResolver
defaultLocalIPv4 atomic.Value // holds net.IP (IPv4)
defaultLocalIPv6 atomic.Value // holds net.IP (IPv6)
)
func newLocalResolver() Resolver {
var nss []string
for _, addr := range Rfc1918Addresses() {
nss = append(nss, net.JoinHostPort(addr, "53"))
}
return NewResolverWithNameserver(nss)
}
// LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network.
type LanQueryCtxKey struct{}
// LanQueryCtx returns a context.Context with LanQueryCtxKey set.
func LanQueryCtx(ctx context.Context) context.Context {
return context.WithValue(ctx, LanQueryCtxKey{}, true)
}
// defaultNameservers is like nameservers with each element formed "ip:53".
func defaultNameservers() []string {
@@ -63,17 +89,39 @@ func availableNameservers() []string {
// Ignore local addresses to prevent loop.
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
machineIPsMap := make(map[string]struct{}, len(regularIPs))
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
machineIPsMap[v.String()] = struct{}{}
//load the logger
logger := zerolog.New(io.Discard)
if ProxyLogger.Load() != nil {
logger = *ProxyLogger.Load()
}
for _, ns := range nameservers() {
Log(context.Background(), logger.Debug(),
"Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs)
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
ipStr := v.String()
machineIPsMap[ipStr] = struct{}{}
Log(context.Background(), logger.Debug(),
"Added local IP to OS resolverexclusion map: %s", ipStr)
}
systemNameservers := nameservers()
Log(context.Background(), logger.Debug(),
"Got system nameservers: %v", systemNameservers)
for _, ns := range systemNameservers {
if _, ok := machineIPsMap[ns]; ok {
Log(context.Background(), logger.Debug(),
"Skipping local nameserver: %s", ns)
continue
}
if testNameserver(ns) {
nss = append(nss, ns)
}
nss = append(nss, ns)
Log(context.Background(), logger.Debug(),
"Added non-local nameserver: %s", ns)
}
Log(context.Background(), logger.Debug(),
"Final available nameservers: %v", nss)
return nss
}
@@ -82,77 +130,47 @@ func availableNameservers() []string {
//
// It's the caller's responsibility to ensure the system DNS is in a clean state before
// calling this function.
func InitializeOsResolver() []string {
return initializeOsResolver(availableNameservers())
func InitializeOsResolver(guardAgainstNoNameservers bool) []string {
nameservers := availableNameservers()
// if no nameservers, return empty slice so we dont remove all nameservers
if len(nameservers) == 0 && guardAgainstNoNameservers {
return []string{}
}
ns := initializeOsResolver(nameservers)
resolverMutex.Lock()
defer resolverMutex.Unlock()
or = newResolverWithNameserver(ns)
return ns
}
// initializeOsResolver performs logic for choosing OS resolver nameserver.
// The logic:
//
// - First available LAN servers are saved and store.
// - Later calls, if no LAN servers available, the saved servers above will be used.
func initializeOsResolver(servers []string) []string {
var (
nss []string
publicNss []string
)
var (
lastLanServer netip.Addr
curLanServer netip.Addr
curLanServerAvailable bool
)
if p := or.currentLanServer.Load(); p != nil {
curLanServer = *p
or.currentLanServer.Store(nil)
}
if p := or.lastLanServer.Load(); p != nil {
lastLanServer = *p
or.lastLanServer.Store(nil)
}
var lanNss, publicNss []string
// First categorize servers
for _, ns := range servers {
addr, err := netip.ParseAddr(ns)
if err != nil {
continue
}
server := net.JoinHostPort(ns, "53")
// Always use new public nameserver.
if !isLanAddr(addr) {
publicNss = append(publicNss, server)
nss = append(nss, server)
continue
}
// For LAN server, storing only current and last LAN server if any.
if addr.Compare(curLanServer) == 0 {
curLanServerAvailable = true
if isLanAddr(addr) {
lanNss = append(lanNss, server)
} else {
if addr.Compare(lastLanServer) == 0 {
or.lastLanServer.Store(&addr)
} else {
if or.currentLanServer.CompareAndSwap(nil, &addr) {
nss = append(nss, server)
}
}
publicNss = append(publicNss, server)
}
}
// Store current LAN server as last one only if it's still available.
if curLanServerAvailable && curLanServer.IsValid() {
or.lastLanServer.Store(&curLanServer)
nss = append(nss, net.JoinHostPort(curLanServer.String(), "53"))
}
if len(publicNss) == 0 {
publicNss = append(publicNss, controldPublicDnsWithPort)
nss = append(nss, controldPublicDnsWithPort)
}
or.publicServer.Store(&publicNss)
return nss
}
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
func testNameserver(addr string) bool {
msg := new(dns.Msg)
msg.SetQuestion("controld.com.", dns.TypeNS)
client := new(dns.Client)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53"))
if err != nil {
ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr)
if len(publicNss) == 0 {
publicNss = []string{controldPublicDnsWithPort}
}
return err == nil
return slices.Concat(lanNss, publicNss)
}
// Resolver is the interface that wraps the basic DNS operations.
@@ -175,19 +193,23 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
case ResolverTypeDOQ:
return &doqResolver{uc: uc}, nil
case ResolverTypeOS:
if or == nil {
or = newResolverWithNameserver(defaultNameservers())
}
return or, nil
case ResolverTypeLegacy:
return &legacyResolver{uc: uc}, nil
case ResolverTypePrivate:
return NewPrivateResolver(), nil
case ResolverTypeLocal:
return localResolver, nil
}
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
}
type osResolver struct {
currentLanServer atomic.Pointer[netip.Addr]
lastLanServer atomic.Pointer[netip.Addr]
publicServer atomic.Pointer[[]string]
lanServers atomic.Pointer[[]string]
publicServers atomic.Pointer[[]string]
}
type osResolverResult struct {
@@ -197,26 +219,75 @@ type osResolverResult struct {
lan bool
}
type publicResponse struct {
answer *dns.Msg
server string
}
// SetDefaultLocalIPv4 updates the stored local IPv4.
func SetDefaultLocalIPv4(ip net.IP) {
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip)
defaultLocalIPv4.Store(ip)
}
// SetDefaultLocalIPv6 updates the stored local IPv6.
func SetDefaultLocalIPv6(ip net.IP) {
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip)
defaultLocalIPv6.Store(ip)
}
// GetDefaultLocalIPv4 returns the stored local IPv4 or nil if none.
func GetDefaultLocalIPv4() net.IP {
if v := defaultLocalIPv4.Load(); v != nil {
return v.(net.IP)
}
return nil
}
// GetDefaultLocalIPv6 returns the stored local IPv6 or nil if none.
func GetDefaultLocalIPv6() net.IP {
if v := defaultLocalIPv6.Load(); v != nil {
return v.(net.IP)
}
return nil
}
// customDNSExchange wraps the DNS exchange to use our debug dialer.
// It uses dns.ExchangeWithConn so that our custom dialer is used directly.
func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, time.Duration, error) {
baseDialer := &net.Dialer{
Timeout: 3 * time.Second,
Resolver: &net.Resolver{PreferGo: true},
}
if desiredLocalIP != nil {
baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0}
}
dnsClient := &dns.Client{Net: "udp"}
dnsClient.Dialer = baseDialer
return dnsClient.ExchangeContext(ctx, msg, server)
}
// Resolve resolves DNS queries using pre-configured nameservers.
// Query is sent to all nameservers concurrently, and the first
// success response will be returned.
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
publicServers := *o.publicServer.Load()
nss := make([]string, 0, 2)
if p := o.currentLanServer.Load(); p != nil {
nss = append(nss, net.JoinHostPort(p.String(), "53"))
}
if p := o.lastLanServer.Load(); p != nil {
nss = append(nss, net.JoinHostPort(p.String(), "53"))
publicServers := *o.publicServers.Load()
var nss []string
if p := o.lanServers.Load(); p != nil {
nss = append(nss, (*p)...)
}
numServers := len(nss) + len(publicServers)
// If this is a LAN query, skip public DNS.
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
if ok && lan {
numServers -= len(publicServers)
}
if numServers == 0 {
return nil, errors.New("no nameservers available")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
dnsClient := &dns.Client{Net: "udp"}
ch := make(chan *osResolverResult, numServers)
wg := &sync.WaitGroup{}
wg.Add(numServers)
@@ -229,57 +300,86 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
for _, server := range servers {
go func(server string) {
defer wg.Done()
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
var answer *dns.Msg
var err error
var localOSResolverIP net.IP
if runtime.GOOS == "darwin" {
host, _, err := net.SplitHostPort(server)
if err == nil {
ip := net.ParseIP(host)
if ip != nil && ip.To4() == nil {
// IPv6 nameserver; use default IPv6 address (if set)
localOSResolverIP = GetDefaultLocalIPv6()
} else {
localOSResolverIP = GetDefaultLocalIPv4()
}
}
}
answer, _, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP)
ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan}
}(server)
}
}
do(nss, true)
do(publicServers, false)
if !lan {
do(publicServers, false)
}
logAnswer := func(server string) {
if before, _, found := strings.Cut(server, ":"); found {
server = before
host, _, err := net.SplitHostPort(server)
if err != nil {
// If splitting fails, fallback to the original server string
host = server
}
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", server)
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host)
}
var (
nonSuccessAnswer *dns.Msg
nonSuccessServer string
controldSuccessAnswer *dns.Msg
publicServerAnswer *dns.Msg
publicServer string
publicResponses []publicResponse
)
errs := make([]error, 0, numServers)
for res := range ch {
switch {
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
switch {
case res.server == controldPublicDnsWithPort:
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
case !res.lan && publicServerAnswer == nil:
publicServerAnswer = res.answer // use public DNS answer after LAN server..
publicServer = res.server
default:
case res.lan:
// Always prefer LAN responses immediately
Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server)
cancel()
logAnswer(res.server)
return res.answer, nil
case res.server == controldPublicDnsWithPort:
controldSuccessAnswer = res.answer
case !res.lan:
publicResponses = append(publicResponses, publicResponse{
answer: res.answer,
server: res.server,
})
}
case res.answer != nil:
nonSuccessAnswer = res.answer
nonSuccessServer = res.server
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d",
res.server, res.answer.Rcode)
}
errs = append(errs, res.err)
}
if publicServerAnswer != nil {
logAnswer(publicServer)
return publicServerAnswer, nil
if len(publicResponses) > 0 {
resp := publicResponses[0]
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server)
logAnswer(resp.server)
return resp.answer, nil
}
if controldSuccessAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort)
logAnswer(controldPublicDnsWithPort)
return controldSuccessAnswer, nil
}
if nonSuccessAnswer != nil {
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer)
logAnswer(nonSuccessServer)
return nonSuccessAnswer, nil
}
@@ -328,7 +428,11 @@ func LookupIP(domain string) []string {
}
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
nss := defaultNameservers()
if or == nil {
or = newResolverWithNameserver(defaultNameservers())
}
nss := *or.lanServers.Load()
nss = append(nss, *or.publicServers.Load()...)
if withBootstrapDNS {
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
}
@@ -467,17 +571,19 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
// The caller must ensure each server in list is formed "ip:53".
func newResolverWithNameserver(nameservers []string) *osResolver {
r := &osResolver{}
nss := slices.Sorted(slices.Values(nameservers))
for i, ns := range nss {
var publicNss []string
var lanNss []string
for _, ns := range slices.Sorted(slices.Values(nameservers)) {
ip, _, _ := net.SplitHostPort(ns)
addr, _ := netip.ParseAddr(ip)
if isLanAddr(addr) {
r.currentLanServer.Store(&addr)
nss = slices.Delete(nss, i, i+1)
break
lanNss = append(lanNss, ns)
} else {
publicNss = append(publicNss, ns)
}
}
r.publicServer.Store(&nss)
r.lanServers.Store(&lanNss)
r.publicServers.Store(&publicNss)
return r
}

View File

@@ -3,13 +3,10 @@ package ctrld
import (
"context"
"net"
"slices"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/miekg/dns"
)
@@ -20,7 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) {
go func() {
defer cancel()
resolver := &osResolver{}
resolver.publicServer.Store(&[]string{"127.0.0.127:5353"})
resolver.publicServers.Store(&[]string{"127.0.0.127:5353"})
m := new(dns.Msg)
m.SetQuestion("controld.com.", dns.TypeA)
m.RecursionDesired = true
@@ -34,26 +31,51 @@ func Test_osResolver_Resolve(t *testing.T) {
}
}
func Test_osResolver_ResolveLanHostname(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
reqId := "req-id"
ctx = context.WithValue(ctx, ReqIdCtxKey{}, reqId)
ctx = LanQueryCtx(ctx)
go func(ctx context.Context) {
defer cancel()
id, ok := ctx.Value(ReqIdCtxKey{}).(string)
if !ok || id != reqId {
t.Error("missing request id")
return
}
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
if !ok || !lan {
t.Error("not a LAN query")
return
}
resolver := &osResolver{}
resolver.publicServers.Store(&[]string{"76.76.2.0:53"})
m := new(dns.Msg)
m.SetQuestion("controld.com.", dns.TypeA)
m.RecursionDesired = true
_, err := resolver.Resolve(ctx, m)
if err == nil {
t.Error("os resolver succeeded unexpectedly")
return
}
}(ctx)
select {
case <-time.After(10 * time.Second):
t.Error("os resolver hangs")
case <-ctx.Done():
}
}
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
ns := make([]string, 0, 2)
servers := make([]*dns.Server, 0, 2)
successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, dns.RcodeSuccess)
w.WriteMsg(m)
})
nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc {
return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, rcode)
w.WriteMsg(m)
})
}
handlers := []dns.Handler{
nonSuccessHandlerWithRcode(dns.RcodeRefused),
nonSuccessHandlerWithRcode(dns.RcodeNameError),
successHandler,
successHandler(),
}
for i := range handlers {
pc, err := net.ListenPacket("udp", ":0")
@@ -74,7 +96,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
}
}()
resolver := &osResolver{}
resolver.publicServer.Store(&ns)
resolver.publicServers.Store(&ns)
msg := new(dns.Msg)
msg.SetQuestion(".", dns.TypeNS)
answer, err := resolver.Resolve(context.Background(), msg)
@@ -93,7 +115,7 @@ func Test_osResolver_InitializationRace(t *testing.T) {
for range n {
go func() {
defer wg.Done()
InitializeOsResolver()
InitializeOsResolver(false)
}()
}
wg.Wait()
@@ -153,41 +175,18 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H
return server, addr, nil
}
func Test_initializeOsResolver(t *testing.T) {
lanServer1 := "192.168.1.1"
lanServer2 := "10.0.10.69"
wanServer := "1.1.1.1"
publicServers := []string{net.JoinHostPort(wanServer, "53")}
// First initialization.
initializeOsResolver([]string{lanServer1, wanServer})
p := or.currentLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
// No new LAN server, current LAN server -> last LAN server.
initializeOsResolver([]string{lanServer1, wanServer})
p = or.currentLanServer.Load()
assert.Nil(t, p)
p = or.lastLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
// New LAN server detected.
initializeOsResolver([]string{lanServer2, lanServer1, wanServer})
p = or.currentLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer2, p.String())
p = or.lastLanServer.Load()
assert.NotNil(t, p)
assert.Equal(t, lanServer1, p.String())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
// No LAN server available.
initializeOsResolver([]string{wanServer})
assert.Nil(t, or.currentLanServer.Load())
assert.Nil(t, or.lastLanServer.Load())
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
func successHandler() dns.HandlerFunc {
return func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, dns.RcodeSuccess)
w.WriteMsg(m)
}
}
func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc {
return func(w dns.ResponseWriter, msg *dns.Msg) {
m := new(dns.Msg)
m.SetRcode(msg, rcode)
w.WriteMsg(m)
}
}