mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Merge pull request #209 from Control-D-Inc/release-branch-v1.4.0
Release branch v1.4.0
This commit is contained in:
@@ -8,3 +8,8 @@ import (
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule if present.
|
||||
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
hh "github.com/microsoft/wmi/pkg/hardware/host"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
@@ -21,29 +26,48 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
|
||||
// Network rules are lowercase during toml config marshaling,
|
||||
// lowercase the domain here too for consistency.
|
||||
domain = strings.ToLower(domain)
|
||||
domainRuleAdded := addSplitDnsRule(cfg, domain)
|
||||
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
|
||||
return domainRuleAdded || wildcardDomainRuleRuleAdded
|
||||
}
|
||||
|
||||
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
|
||||
// The return value indicates whether the split-rule was added or not.
|
||||
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
|
||||
for n, lc := range cfg.Listener {
|
||||
if lc.Policy == nil {
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{}
|
||||
}
|
||||
domainRule := "*." + strings.TrimPrefix(domain, ".")
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
if _, ok := rule[domainRule]; ok {
|
||||
mainLog.Load().Debug().Msgf("domain rule already exist for listener.%s", n)
|
||||
if _, ok := rule[domain]; ok {
|
||||
mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n)
|
||||
return false
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("adding active directory domain for listener.%s", n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}})
|
||||
mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
cs, err := hh.GetComputerSystem(whost)
|
||||
if cs != nil {
|
||||
defer cs.Close()
|
||||
}
|
||||
return string(output), nil
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pod, err := cs.GetPropertyPartOfDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pod {
|
||||
return cs.GetPropertyDomain()
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
71
cmd/cli/ad_windows_test.go
Normal file
71
cmd/cli/ad_windows_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_getActiveDirectoryDomain(t *testing.T) {
|
||||
start := time.Now()
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
domainPowershell, err := getActiveDirectoryDomainPowershell()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if domain != domainPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func getActiveDirectoryDomainPowershell() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
func Test_addSplitDnsRule(t *testing.T) {
|
||||
newCfg := func(domains ...string) *ctrld.Config {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
lc := cfg.Listener["0"]
|
||||
for _, domain := range domains {
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *ctrld.Config
|
||||
domain string
|
||||
added bool
|
||||
}{
|
||||
{"added", newCfg(), "example.com", true},
|
||||
{"TLD existed", newCfg("example.com"), "*.example.com", true},
|
||||
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
|
||||
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
|
||||
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
added := addSplitDnsRule(tc.cfg, tc.domain)
|
||||
assert.Equal(t, tc.added, added)
|
||||
})
|
||||
}
|
||||
}
|
||||
1104
cmd/cli/cli.go
1104
cmd/cli/cli.go
File diff suppressed because it is too large
Load Diff
1362
cmd/cli/commands.go
Normal file
1362
cmd/cli/commands.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -25,6 +25,10 @@ func newControlClient(addr string) *controlClient {
|
||||
}
|
||||
|
||||
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
|
||||
// for log/send, set the timeout to 5 minutes
|
||||
if path == sendLogsPath {
|
||||
c.c.Timeout = time.Minute * 5
|
||||
}
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -25,8 +27,16 @@ const (
|
||||
deactivationPath = "/deactivation"
|
||||
cdPath = "/cd"
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
type controlServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
@@ -201,15 +211,76 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
res := &ifaceResponse{Name: iface}
|
||||
// p.setDNS is only called when running as a service
|
||||
if !service.Interactive() {
|
||||
<-p.csSetDnsDone
|
||||
if p.csSetDnsOk {
|
||||
w.Write([]byte(iface))
|
||||
return
|
||||
res.Name = p.runningIface
|
||||
res.All = p.requiredMultiNICsConfig
|
||||
res.OK = true
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
if err := json.NewEncoder(w).Encode(res); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
lr, err := p.logReader()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer lr.r.Close()
|
||||
if lr.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(lr.r)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if time.Since(p.internalLogSent) < logSentInterval {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r, err := p.logReader()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
req := &controld.LogsRequest{
|
||||
UID: cdUID,
|
||||
Data: r.r,
|
||||
}
|
||||
mainLog.Load().Debug().Msg("sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
if err := controld.SendLogs(req, cdDev); err != nil {
|
||||
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("sending log file successfully")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
p.internalLogSent = time.Now()
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/logger"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
@@ -41,7 +43,7 @@ const (
|
||||
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "OS resolver",
|
||||
Type: ctrld.ResolverTypeOS,
|
||||
Timeout: 2000,
|
||||
Timeout: 3000,
|
||||
}
|
||||
|
||||
var privateUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
@@ -50,6 +52,12 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
var localUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "Local resolver",
|
||||
Type: ctrld.ResolverTypeLocal,
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
// proxyRequest contains data for proxying a DNS query to upstream.
|
||||
type proxyRequest struct {
|
||||
msg *dns.Msg
|
||||
@@ -76,7 +84,13 @@ type upstreamForResult struct {
|
||||
srcAddr string
|
||||
}
|
||||
|
||||
func (p *prog) serveDNS(listenerNum string) error {
|
||||
func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
|
||||
// Start network monitoring
|
||||
if err := p.monitorNetworkChanges(mainCtx); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to start network monitoring")
|
||||
// Don't return here as we still want DNS service to run
|
||||
}
|
||||
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
// make sure ip is allocated
|
||||
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
|
||||
@@ -106,11 +120,18 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
go p.detectLoop(m)
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
if domain == selfCheckInternalTestDomain {
|
||||
switch {
|
||||
case domain == "":
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeFormatError)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
case domain == selfCheckInternalTestDomain:
|
||||
answer := resolveInternalDomainTestQuery(ctx, domain, m)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
|
||||
p.cache.Purge()
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain)
|
||||
@@ -411,23 +432,19 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
|
||||
leaked := false
|
||||
// If ctrld is going to leak query to OS resolver, check remote upstream in background,
|
||||
// so ctrld could be back to normal operation as long as the network is back online.
|
||||
if len(upstreamConfigs) > 0 && p.leakingQuery.Load() {
|
||||
for n, uc := range upstreamConfigs {
|
||||
go p.checkUpstream(upstreams[n], uc)
|
||||
}
|
||||
upstreamConfigs = nil
|
||||
leaked = true
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
|
||||
}
|
||||
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{upstreamOS}
|
||||
}
|
||||
|
||||
if p.isAdDomainQuery(req.msg) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(),
|
||||
"AD domain query detected for %s in domain %s",
|
||||
req.msg.Question[0].Name, p.adDomain)
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
|
||||
upstreams = []string{upstreamOS}
|
||||
}
|
||||
|
||||
res := &proxyResponse{}
|
||||
|
||||
// LAN/PTR lookup flow:
|
||||
@@ -438,13 +455,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
// 4. Try remote upstream.
|
||||
isLanOrPtrQuery := false
|
||||
if req.ufr.matched {
|
||||
if leaked {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v (leaked)", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
} else {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
} else {
|
||||
switch {
|
||||
case isSrvLookup(req.msg):
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams)
|
||||
case isPrivatePtrLookup(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
|
||||
@@ -452,7 +470,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
res.clientInfo = true
|
||||
return res
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs)
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams)
|
||||
case isLanHostnameQuery(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
@@ -461,7 +480,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
res.clientInfo = true
|
||||
return res
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
|
||||
default:
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
@@ -488,8 +509,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
staleAnswer = answer
|
||||
}
|
||||
}
|
||||
resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
|
||||
resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver")
|
||||
@@ -504,43 +525,53 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
}
|
||||
return dnsResolver.Resolve(resolveCtx, msg)
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
|
||||
}
|
||||
answer, err := resolve1(n, upstreamConfig, msg)
|
||||
answer, err := resolve1(upstream, upstreamConfig, msg)
|
||||
// if we have an answer, we should reset the failure count
|
||||
// we dont use reset here since we dont want to prevent failure counts from being incremented
|
||||
if answer != nil {
|
||||
p.um.mu.Lock()
|
||||
p.um.failureReq[upstream] = 0
|
||||
p.um.down[upstream] = false
|
||||
p.um.mu.Unlock()
|
||||
return answer
|
||||
}
|
||||
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
|
||||
|
||||
// increase failure count when there is no answer
|
||||
// rehardless of what kind of error we get
|
||||
p.um.increaseFailureCount(upstream)
|
||||
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
|
||||
isNetworkErr := errNetworkError(err)
|
||||
if isNetworkErr {
|
||||
p.um.increaseFailureCount(upstreams[n])
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
go p.checkUpstream(upstreams[n], upstreamConfig)
|
||||
}
|
||||
}
|
||||
// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
|
||||
var e net.Error
|
||||
if errors.As(err, &e) && e.Timeout() {
|
||||
upstreamConfig.ReBootstrap()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return answer
|
||||
|
||||
return nil
|
||||
}
|
||||
for n, upstreamConfig := range upstreamConfigs {
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
logger := mainLog.Load().Debug().
|
||||
Str("upstream", upstreamConfig.String()).
|
||||
Str("query", req.msg.Question[0].Name).
|
||||
Bool("is_ad_query", p.isAdDomainQuery(req.msg)).
|
||||
Bool("is_lan_query", isLanOrPtrQuery)
|
||||
|
||||
if p.isLoop(upstreamConfig) {
|
||||
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint)
|
||||
ctrld.Log(ctx, logger, "DNS loop detected")
|
||||
continue
|
||||
}
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, req.msg)
|
||||
answer := resolve(upstreams[n], upstreamConfig, req.msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
|
||||
@@ -587,21 +618,49 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
return res
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
|
||||
if cdUID != "" && p.leakOnUpstreamFailure() {
|
||||
p.leakingQueryMu.Lock()
|
||||
if !p.leakingQueryWasRun {
|
||||
p.leakingQueryWasRun = true
|
||||
go p.performLeakingQuery()
|
||||
|
||||
// if we have no healthy upstreams, trigger recovery flow
|
||||
if p.recoverOnUpstreamFailure() {
|
||||
if p.um.countHealthy(upstreams) == 0 {
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel == nil {
|
||||
var reason RecoveryReason
|
||||
if upstreams[0] == upstreamOS {
|
||||
reason = RecoveryReasonOSFailure
|
||||
} else {
|
||||
reason = RecoveryReasonRegularFailure
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
|
||||
go p.handleRecovery(reason)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
|
||||
}
|
||||
p.recoveryCancelMu.Unlock()
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
|
||||
}
|
||||
p.leakingQueryMu.Unlock()
|
||||
}
|
||||
|
||||
// attempt query to OS resolver while as a retry catch all
|
||||
if upstreams[0] != upstreamOS {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
|
||||
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
|
||||
if answer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
|
||||
res.answer = answer
|
||||
res.upstream = osUpstreamConfig.Endpoint
|
||||
return res
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
|
||||
}
|
||||
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(req.msg, dns.RcodeServerFailure)
|
||||
res.answer = answer
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
|
||||
func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
|
||||
if len(p.localUpstreams) > 0 {
|
||||
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
|
||||
tmp = append(tmp, p.localUpstreams...)
|
||||
@@ -620,6 +679,14 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U
|
||||
return upstreamConfigs
|
||||
}
|
||||
|
||||
func (p *prog) isAdDomainQuery(msg *dns.Msg) bool {
|
||||
if p.adDomain == "" {
|
||||
return false
|
||||
}
|
||||
cDomainName := canonicalName(msg.Question[0].Name)
|
||||
return dns.IsSubDomain(p.adDomain, cDomainName)
|
||||
}
|
||||
|
||||
// canonicalName returns canonical name from FQDN with "." trimmed.
|
||||
func canonicalName(fqdn string) string {
|
||||
q := strings.TrimSpace(fqdn)
|
||||
@@ -916,18 +983,6 @@ func (p *prog) selfUninstallCoolOfPeriod() {
|
||||
p.selfUninstallMu.Unlock()
|
||||
}
|
||||
|
||||
// performLeakingQuery performs necessary works to leak queries to OS resolver.
|
||||
func (p *prog) performLeakingQuery() {
|
||||
mainLog.Load().Warn().Msg("leaking query to OS resolver")
|
||||
// Signal dns watchers to stop, so changes made below won't be reverted.
|
||||
p.leakingQuery.Store(true)
|
||||
p.resetDNS()
|
||||
ns := ctrld.InitializeOsResolver()
|
||||
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
|
||||
p.dnsWg.Wait()
|
||||
p.setDNS()
|
||||
}
|
||||
|
||||
// forceFetchingAPI sends signal to force syncing API config if run in cd mode,
|
||||
// and the domain == "cdUID.verify.controld.com"
|
||||
func (p *prog) forceFetchingAPI(domain string) {
|
||||
@@ -1056,7 +1111,16 @@ func isLanHostnameQuery(m *dns.Msg) bool {
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
return !strings.Contains(name, ".") ||
|
||||
strings.HasSuffix(name, ".domain") ||
|
||||
strings.HasSuffix(name, ".lan")
|
||||
strings.HasSuffix(name, ".lan") ||
|
||||
strings.HasSuffix(name, ".local")
|
||||
}
|
||||
|
||||
// isSrvLookup reports whether DNS message is a SRV query.
|
||||
func isSrvLookup(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
return m.Question[0].Qtype == dns.TypeSRV
|
||||
}
|
||||
|
||||
// isWanClient reports whether the input is a WAN address.
|
||||
@@ -1089,3 +1153,406 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
|
||||
answer.SetReply(m)
|
||||
return answer
|
||||
}
|
||||
|
||||
// FlushDNSCache flushes the DNS cache on macOS.
|
||||
func FlushDNSCache() error {
|
||||
// if not macOS, return
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush the DNS cache via mDNSResponder.
|
||||
// This is typically needed on modern macOS systems.
|
||||
if out, err := exec.Command("killall", "-HUP", "mDNSResponder").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to flush mDNSResponder: %w, output: %s", err, string(out))
|
||||
}
|
||||
|
||||
// Optionally, flush the directory services cache.
|
||||
if out, err := exec.Command("dscacheutil", "-flushcache").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to flush dscacheutil: %w, output: %s", err, string(out))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorNetworkChanges starts monitoring for network interface changes
|
||||
func (p *prog) monitorNetworkChanges(ctx context.Context) error {
|
||||
mon, err := netmon.New(logger.WithPrefix(mainLog.Load().Printf, "netmon: "))
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating network monitor: %w", err)
|
||||
}
|
||||
|
||||
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
|
||||
// Get map of valid interfaces
|
||||
validIfaces := validInterfacesMap()
|
||||
|
||||
isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New)
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Interface("old_state", delta.Old).
|
||||
Interface("new_state", delta.New).
|
||||
Bool("is_major_change", isMajorChange).
|
||||
Msg("Network change detected")
|
||||
|
||||
changed := false
|
||||
activeInterfaceExists := false
|
||||
var changeIPs []netip.Prefix
|
||||
// Check each valid interface for changes
|
||||
for ifaceName := range validIfaces {
|
||||
oldIface, oldExists := delta.Old.Interface[ifaceName]
|
||||
newIface, newExists := delta.New.Interface[ifaceName]
|
||||
if !newExists {
|
||||
continue
|
||||
}
|
||||
|
||||
oldIPs := delta.Old.InterfaceIPs[ifaceName]
|
||||
newIPs := delta.New.InterfaceIPs[ifaceName]
|
||||
|
||||
// if a valid interface did not exist in old
|
||||
// check that its up and has usable IPs
|
||||
if !oldExists {
|
||||
// The interface is new (was not present in the old state).
|
||||
usableNewIPs := filterUsableIPs(newIPs)
|
||||
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
||||
changed = true
|
||||
changeIPs = usableNewIPs
|
||||
mainLog.Load().Debug().
|
||||
Str("interface", ifaceName).
|
||||
Interface("new_ips", usableNewIPs).
|
||||
Msg("Interface newly appeared (was not present in old state)")
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter new IPs to only those that are usable.
|
||||
usableNewIPs := filterUsableIPs(newIPs)
|
||||
|
||||
// Check if interface is up and has usable IPs.
|
||||
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
||||
activeInterfaceExists = true
|
||||
}
|
||||
|
||||
// Compare interface states and IPs (interfaceIPsEqual will itself filter the IPs).
|
||||
if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) {
|
||||
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
||||
changed = true
|
||||
changeIPs = usableNewIPs
|
||||
mainLog.Load().Debug().
|
||||
Str("interface", ifaceName).
|
||||
Interface("old_ips", oldIPs).
|
||||
Interface("new_ips", usableNewIPs).
|
||||
Msg("Interface state or IPs changed")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected")
|
||||
return
|
||||
}
|
||||
|
||||
if !activeInterfaceExists {
|
||||
mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization")
|
||||
return
|
||||
}
|
||||
|
||||
// Get IPs from default route interface in new state
|
||||
selfIP := defaultRouteIP()
|
||||
var ipv6 string
|
||||
|
||||
if delta.New.DefaultRouteInterface != "" {
|
||||
mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface])
|
||||
for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] {
|
||||
ipAddr, _ := netip.ParsePrefix(ip.String())
|
||||
addr := ipAddr.Addr()
|
||||
if selfIP == "" && addr.Is4() {
|
||||
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
|
||||
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
selfIP = addr.String()
|
||||
}
|
||||
}
|
||||
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
ipv6 = addr.String()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no default route interface is set yet, use the changed IPs
|
||||
mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs)
|
||||
for _, ip := range changeIPs {
|
||||
ipAddr, _ := netip.ParsePrefix(ip.String())
|
||||
addr := ipAddr.Addr()
|
||||
if selfIP == "" && addr.Is4() {
|
||||
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
|
||||
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
selfIP = addr.String()
|
||||
}
|
||||
}
|
||||
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
ipv6 = addr.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(selfIP); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv4(ip)
|
||||
if !isMobile() && p.ciTable != nil {
|
||||
p.ciTable.SetSelfIP(selfIP)
|
||||
}
|
||||
}
|
||||
if ip := net.ParseIP(ipv6); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv6(ip)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
|
||||
if p.recoverOnUpstreamFailure() {
|
||||
p.handleRecovery(RecoveryReasonNetworkChange)
|
||||
}
|
||||
})
|
||||
|
||||
mon.Start()
|
||||
mainLog.Load().Debug().Msg("Network monitor started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// interfaceStatesEqual compares two interface states
|
||||
func interfaceStatesEqual(a, b *netmon.Interface) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == b
|
||||
}
|
||||
return a.IsUp() == b.IsUp()
|
||||
}
|
||||
|
||||
// filterUsableIPs is a helper that returns only "usable" IP prefixes,
|
||||
// filtering out link-local, loopback, multicast, unspecified, broadcast, or CGNAT addresses.
|
||||
func filterUsableIPs(prefixes []netip.Prefix) []netip.Prefix {
|
||||
var usable []netip.Prefix
|
||||
for _, p := range prefixes {
|
||||
addr := p.Addr()
|
||||
if addr.IsLinkLocalUnicast() ||
|
||||
addr.IsLoopback() ||
|
||||
addr.IsMulticast() ||
|
||||
addr.IsUnspecified() ||
|
||||
addr.IsLinkLocalMulticast() ||
|
||||
(addr.Is4() && addr.String() == "255.255.255.255") ||
|
||||
tsaddr.CGNATRange().Contains(addr) {
|
||||
continue
|
||||
}
|
||||
usable = append(usable, p)
|
||||
}
|
||||
return usable
|
||||
}
|
||||
|
||||
// Modified interfaceIPsEqual compares only the usable (non-link local, non-loopback, etc.) IP addresses.
|
||||
func interfaceIPsEqual(a, b []netip.Prefix) bool {
|
||||
aUsable := filterUsableIPs(a)
|
||||
bUsable := filterUsableIPs(b)
|
||||
if len(aUsable) != len(bUsable) {
|
||||
return false
|
||||
}
|
||||
|
||||
aMap := make(map[string]bool)
|
||||
for _, ip := range aUsable {
|
||||
aMap[ip.String()] = true
|
||||
}
|
||||
for _, ip := range bUsable {
|
||||
if !aMap[ip.String()] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkUpstreamOnce sends a test query to the specified upstream.
|
||||
// Returns nil if the upstream responds successfully.
|
||||
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
|
||||
mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream)
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
|
||||
return err
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
|
||||
timeout := 1000 * time.Millisecond
|
||||
if uc.Timeout > 0 {
|
||||
timeout = time.Millisecond * time.Duration(uc.Timeout)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
uc.ReBootstrap()
|
||||
mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
|
||||
|
||||
start := time.Now()
|
||||
_, err = resolver.Resolve(ctx, msg)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleRecovery performs a unified recovery by removing DNS settings,
|
||||
// canceling existing recovery checks for network changes, but coalescing duplicate
|
||||
// upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout),
|
||||
// and then re-applying the DNS settings.
|
||||
func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings")
|
||||
|
||||
// For network changes, cancel any existing recovery check because the network state has changed.
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel != nil {
|
||||
mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)")
|
||||
p.recoveryCancel()
|
||||
p.recoveryCancel = nil
|
||||
}
|
||||
p.recoveryCancelMu.Unlock()
|
||||
} else {
|
||||
// For upstream failures, if a recovery is already in progress, do nothing new.
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel != nil {
|
||||
mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
|
||||
p.recoveryCancelMu.Unlock()
|
||||
return
|
||||
}
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
// Create a new recovery context without a fixed timeout.
|
||||
p.recoveryCancelMu.Lock()
|
||||
recoveryCtx, cancel := context.WithCancel(context.Background())
|
||||
p.recoveryCancel = cancel
|
||||
p.recoveryCancelMu.Unlock()
|
||||
|
||||
// Immediately remove our DNS settings from the interface.
|
||||
// set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
|
||||
p.recoveryRunning.Store(true)
|
||||
p.resetDNS()
|
||||
|
||||
// For an OS failure, reinitialize OS resolver nameservers immediately.
|
||||
if reason == RecoveryReasonOSFailure {
|
||||
mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
|
||||
// Build upstream map based on the recovery reason.
|
||||
upstreams := p.buildRecoveryUpstreams(reason)
|
||||
|
||||
// Wait indefinitely until one of the upstreams recovers.
|
||||
recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = nil
|
||||
p.recoveryCancelMu.Unlock()
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
|
||||
|
||||
// reset the upstream failure count and down state
|
||||
p.um.reset(recovered)
|
||||
|
||||
// For network changes we also reinitialize the OS resolver.
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply our DNS settings back and log the interface state.
|
||||
p.setDNS()
|
||||
p.logInterfacesState()
|
||||
|
||||
// allow watchdogs to put the listener back on the interface if its changed for any reason
|
||||
p.recoveryRunning.Store(false)
|
||||
|
||||
// Clear the recovery cancellation for a clean slate.
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = nil
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
|
||||
// It returns the name of the recovered upstream or an error if the check times out.
|
||||
func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) {
|
||||
recoveredCh := make(chan string, 1)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))
|
||||
|
||||
for name, uc := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(name string, uc *ctrld.UpstreamConfig) {
|
||||
defer wg.Done()
|
||||
mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
|
||||
return
|
||||
default:
|
||||
// checkUpstreamOnce will reset any failure counters on success.
|
||||
if err := p.checkUpstreamOnce(name, uc); err == nil {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
|
||||
select {
|
||||
case recoveredCh <- name:
|
||||
mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name)
|
||||
default:
|
||||
mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered")
|
||||
}
|
||||
return
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
}
|
||||
}
|
||||
}(name, uc)
|
||||
}
|
||||
|
||||
var recovered string
|
||||
select {
|
||||
case recovered = <-recoveredCh:
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
wg.Wait()
|
||||
return recovered, nil
|
||||
}
|
||||
|
||||
// buildRecoveryUpstreams constructs the map of upstream configurations to test.
|
||||
// For OS failures we supply the manual OS resolver upstream configuration.
|
||||
// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS).
|
||||
func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig {
|
||||
upstreams := make(map[string]*ctrld.UpstreamConfig)
|
||||
switch reason {
|
||||
case RecoveryReasonOSFailure:
|
||||
upstreams[upstreamOS] = osUpstreamConfig
|
||||
case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure:
|
||||
// Use all configured upstreams except any OS type.
|
||||
for k, uc := range p.cfg.Upstream {
|
||||
if uc.Type != ctrld.ResolverTypeOS {
|
||||
upstreams[upstreamPrefix+k] = uc
|
||||
}
|
||||
}
|
||||
}
|
||||
return upstreams
|
||||
}
|
||||
|
||||
@@ -75,6 +75,7 @@ func Test_canonicalName(t *testing.T) {
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
@@ -365,6 +366,9 @@ func Test_isLanHostnameQuery(t *testing.T) {
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
||||
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
||||
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
@@ -414,6 +418,26 @@ func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSrvLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isSrvLookup bool
|
||||
}{
|
||||
{"SRV", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isSrvLookup(tc.msg); tc.isSrvLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
186
cmd/cli/log_writer.go
Normal file
186
cmd/cli/log_writer.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
logSentInterval = time.Minute
|
||||
logStartEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
logLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
logWarnEndMarker = "\n\n=== WARN_END ===\n\n"
|
||||
)
|
||||
|
||||
type logViewResponse struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type logSentResponse struct {
|
||||
Size int64 `json:"size"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type logReader struct {
|
||||
r io.ReadCloser
|
||||
size int64
|
||||
}
|
||||
|
||||
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
|
||||
type logWriter struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
}
|
||||
|
||||
// newLogWriter creates an internal log writer.
|
||||
func newLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSize)
|
||||
}
|
||||
|
||||
// newSmallLogWriter creates an internal log writer with small buffer size.
|
||||
func newSmallLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSmallSize)
|
||||
}
|
||||
|
||||
// newLogWriterWithSize creates an internal log writer with a given buffer size.
|
||||
func newLogWriterWithSize(size int) *logWriter {
|
||||
lw := &logWriter{size: size}
|
||||
return lw
|
||||
}
|
||||
|
||||
func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
|
||||
// If writing p causes overflows, discard old data.
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
}
|
||||
lw.buf.Reset()
|
||||
lw.buf.Write(buf)
|
||||
lw.buf.WriteString(logStartEndMarker) // indicate that the log was truncated.
|
||||
}
|
||||
// If p is bigger than buffer size, truncate p by half until its size is smaller.
|
||||
for len(p)+lw.buf.Len() > lw.size {
|
||||
p = p[len(p)/2:]
|
||||
}
|
||||
return lw.buf.Write(p)
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
func (p *prog) initInternalLogging(writers []io.Writer) {
|
||||
if !p.needInternalLogging() {
|
||||
return
|
||||
}
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
mainLog.Load().Notice().Msg("internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
})
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
// If ctrld was run without explicit verbose level,
|
||||
// run the internal logging at debug level, so we could
|
||||
// have enough information for troubleshooting.
|
||||
if verbose == 0 {
|
||||
for i := range writers {
|
||||
w := &zerolog.FilteredLevelWriter{
|
||||
Writer: zerolog.LevelWriterAdapter{Writer: writers[i]},
|
||||
Level: zerolog.NoticeLevel,
|
||||
}
|
||||
writers[i] = w
|
||||
}
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
writers = append(writers, lw)
|
||||
writers = append(writers, &zerolog.FilteredLevelWriter{
|
||||
Writer: zerolog.LevelWriterAdapter{Writer: wlw},
|
||||
Level: zerolog.WarnLevel,
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// needInternalLogging reports whether prog needs to run internal logging.
|
||||
func (p *prog) needInternalLogging() bool {
|
||||
// Do not run in non-cd mode.
|
||||
if cdUID == "" {
|
||||
return false
|
||||
}
|
||||
// Do not run if there's already log file.
|
||||
if p.cfg.Service.LogPath != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *prog) logReader() (*logReader, error) {
|
||||
if p.needInternalLogging() {
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
return nil, errors.New("nil internal log writer")
|
||||
}
|
||||
if wlw == nil {
|
||||
return nil, errors.New("nil internal warn log writer")
|
||||
}
|
||||
// Normal log content.
|
||||
lw.mu.Lock()
|
||||
lwReader := bytes.NewReader(lw.buf.Bytes())
|
||||
lwSize := lw.buf.Len()
|
||||
lw.mu.Unlock()
|
||||
// Warn log content.
|
||||
wlw.mu.Lock()
|
||||
wlwReader := bytes.NewReader(wlw.buf.Bytes())
|
||||
wlwSize := wlw.buf.Len()
|
||||
wlw.mu.Unlock()
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logLogEndMarker)), wlwReader)
|
||||
lr := &logReader{r: io.NopCloser(reader)}
|
||||
lr.size = int64(lwSize + wlwSize)
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("internal log is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
if p.cfg.Service.LogPath == "" {
|
||||
return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil
|
||||
}
|
||||
f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lr := &logReader{r: f}
|
||||
if st, err := f.Stat(); err == nil {
|
||||
lr.size = st.Size()
|
||||
} else {
|
||||
return nil, fmt.Errorf("f.Stat: %w", err)
|
||||
}
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("log file is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
49
cmd/cli/log_writer_test.go
Normal file
49
cmd/cli/log_writer_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_logWriter_Write(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
data := strings.Repeat("A", size)
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
newData := "B"
|
||||
halfData := strings.Repeat("A", len(data)/2) + logStartEndMarker
|
||||
lw.Write([]byte(newData))
|
||||
if lw.buf.String() != halfData+newData {
|
||||
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
|
||||
}
|
||||
|
||||
bigData := strings.Repeat("B", 256*1024)
|
||||
expected := halfData + strings.Repeat("B", 16*1024)
|
||||
lw.Write([]byte(bigData))
|
||||
if lw.buf.String() != expected {
|
||||
t.Fatalf("unexpected big buf content: %v", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_ConcurrentWrite(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lw.Write([]byte(strings.Repeat("A", i)))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if lw.buf.Len() > lw.size {
|
||||
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
|
||||
}
|
||||
}
|
||||
@@ -101,9 +101,23 @@ func initConsoleLogging() {
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func initLogging() {
|
||||
func initLogging() []io.Writer {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(true)
|
||||
return initLoggingWithBackup(true)
|
||||
}
|
||||
|
||||
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
|
||||
// to be used for all interactive commands.
|
||||
//
|
||||
// Current log file config will also be ignored.
|
||||
func initInteractiveLogging() {
|
||||
old := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(false)
|
||||
cfg.Service.LogPath = old
|
||||
l := zerolog.New(io.Discard)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// initLoggingWithBackup initializes log setup base on current config.
|
||||
@@ -112,8 +126,8 @@ func initLogging() {
|
||||
// This is only used in runCmd for special handling in case of logging config
|
||||
// change in cd mode. Without special reason, the caller should use initLogging
|
||||
// wrapper instead of calling this function directly.
|
||||
func initLoggingWithBackup(doBackup bool) {
|
||||
writers := []io.Writer{io.Discard}
|
||||
func initLoggingWithBackup(doBackup bool) []io.Writer {
|
||||
var writers []io.Writer
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
@@ -151,21 +165,22 @@ func initLoggingWithBackup(doBackup bool) {
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
return
|
||||
return writers
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
if logLevel == "" {
|
||||
return
|
||||
return writers
|
||||
}
|
||||
level, err := zerolog.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set log level")
|
||||
return
|
||||
return writers
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
return writers
|
||||
}
|
||||
|
||||
func initCache() {
|
||||
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface")
|
||||
}
|
||||
return nil
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
|
||||
52
cmd/cli/net_linux.go
Normal file
52
cmd/cli/net_linux.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// Only non-virtual interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set containing non virtual interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
vis := virtualInterfaces()
|
||||
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
||||
if _, existed := vis[i.Name]; existed {
|
||||
return
|
||||
}
|
||||
m[i.Name] = struct{}{}
|
||||
})
|
||||
// Fallback to default route interface if found nothing.
|
||||
if len(m) == 0 {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return m
|
||||
}
|
||||
m[defaultRoute.InterfaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// virtualInterfaces returns a map of virtual interfaces on current machine.
|
||||
func virtualInterfaces() map[string]struct{} {
|
||||
s := make(map[string]struct{})
|
||||
entries, _ := os.ReadDir("/sys/devices/virtual/net")
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
s[strings.TrimSpace(entry.Name())] = struct{}{}
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -1,11 +1,22 @@
|
||||
//go:build !darwin && !windows
|
||||
//go:build !darwin && !windows && !linux
|
||||
|
||||
package cli
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||
|
||||
func validInterfacesMap() map[string]struct{} { return nil }
|
||||
// validInterfacesMap returns a set containing only default route interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]struct{}{defaultRoute.InterfaceName: {}}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"os"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
return nil
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
@@ -20,15 +26,68 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
|
||||
|
||||
// validInterfacesMap returns a set of all physical interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
m := make(map[string]struct{})
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
for _, ifaceName := range validInterfaces() {
|
||||
m[ifaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// validInterfaces returns a list of all physical interfaces.
|
||||
func validInterfaces() []string {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
|
||||
return nil
|
||||
}
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := adapter.GetPropertyName()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
|
||||
continue
|
||||
}
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||
//
|
||||
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
||||
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||
physical, err := adapter.GetPropertyConnectorPresent()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
|
||||
continue
|
||||
}
|
||||
if !physical {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a hardware interface. Checking only for connector present is not enough
|
||||
// because some interfaces are not physical but have a connector.
|
||||
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
|
||||
continue
|
||||
}
|
||||
if !hardware {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
|
||||
continue
|
||||
}
|
||||
|
||||
adapters = append(adapters, name)
|
||||
}
|
||||
return adapters
|
||||
}
|
||||
|
||||
42
cmd/cli/net_windows_test.go
Normal file
42
cmd/cli/net_windows_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_validInterfaces(t *testing.T) {
|
||||
verbose = 3
|
||||
initConsoleLogging()
|
||||
start := time.Now()
|
||||
ifaces := validInterfaces()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
ifacesPowershell := validInterfacesPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(ifaces)
|
||||
slices.Sort(ifacesPowershell)
|
||||
if !slices.Equal(ifaces, ifacesPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces)
|
||||
}
|
||||
}
|
||||
|
||||
func validInterfacesPowershell() []string {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var res []string
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
res = append(res, ifaceName)
|
||||
}
|
||||
return res
|
||||
}
|
||||
@@ -70,11 +70,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
|
||||
// TODO(cuonglm): use system API
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
if err := setDNS(iface, ns); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name, "empty"}
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
@@ -83,6 +78,15 @@ func resetDNS(iface *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
err = setDNS(iface, ns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
}
|
||||
|
||||
@@ -76,6 +76,12 @@ func resetDNS(iface *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
}
|
||||
|
||||
@@ -195,6 +195,12 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
})
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const (
|
||||
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -30,14 +34,6 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
func setDnsPowershellCmd(iface *net.Interface, nameservers []string) string {
|
||||
nss := make([]string, 0, len(nameservers))
|
||||
for _, ns := range nameservers {
|
||||
nss = append(nss, strconv.Quote(ns))
|
||||
}
|
||||
return fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ServerAddresses (%s)", iface.Index, strings.Join(nss, ","))
|
||||
}
|
||||
|
||||
// setDNS sets the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
if len(nameservers) == 0 {
|
||||
@@ -46,7 +42,7 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
setDNSOnce.Do(func() {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
if hasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
oldForwardersContent, _ := os.ReadFile(file)
|
||||
hasLocalIPv6Listener := needLocalIPv6Listener()
|
||||
@@ -65,9 +61,36 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
}
|
||||
}
|
||||
})
|
||||
out, err := powershell(setDnsPowershellCmd(iface, nameservers))
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
return fmt.Errorf("setDNS: %w", err)
|
||||
}
|
||||
var (
|
||||
serversV4 []netip.Addr
|
||||
serversV6 []netip.Addr
|
||||
)
|
||||
for _, ns := range nameservers {
|
||||
if addr, err := netip.ParseAddr(ns); err == nil {
|
||||
if addr.Is4() {
|
||||
serversV4 = append(serversV4, addr)
|
||||
} else {
|
||||
serversV6 = append(serversV6, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(serversV4) == 0 && len(serversV6) == 0 {
|
||||
return errors.New("invalid DNS nameservers")
|
||||
}
|
||||
if len(serversV4) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
||||
}
|
||||
}
|
||||
if len(serversV6) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -81,7 +104,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
resetDNSOnce.Do(func() {
|
||||
// See corresponding comment in setDNS.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
if hasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
@@ -96,14 +119,23 @@ func resetDNS(iface *net.Interface) error {
|
||||
}
|
||||
})
|
||||
|
||||
// Restoring DHCP settings.
|
||||
cmd := fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ResetServerAddresses", iface.Index)
|
||||
out, err := powershell(cmd)
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
return fmt.Errorf("resetDNS: %w", err)
|
||||
}
|
||||
// Restoring DHCP settings.
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
||||
}
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If there's static DNS saved, restoring it.
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
||||
v4ns := make([]string, 0, 2)
|
||||
v6ns := make([]string, 0, 2)
|
||||
@@ -120,12 +152,14 @@ func resetDNS(iface *net.Interface) error {
|
||||
continue
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
err = setDNS(iface, ns)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
@@ -150,25 +184,31 @@ func currentDNS(iface *net.Interface) []string {
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("luid.GUID: %w", err)
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
interfaceKeyPath := path + guid.String()
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", interfaceKeyPath, err)
|
||||
}
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
value, _, err := k.GetStringValue(key)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
return nil, fmt.Errorf("%s: %w", key, err)
|
||||
}
|
||||
if len(value) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
for _, e := range strings.Split(value, ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
@@ -216,3 +256,9 @@ func removeDnsServerForwarders(nameservers []string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
|
||||
68
cmd/cli/os_windows_test.go
Normal file
68
cmd/cli/os_windows_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func Test_currentStaticDNS(t *testing.T) {
|
||||
iface, err := net.InterfaceByName(defaultIfaceName())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
start := time.Now()
|
||||
staticDns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(staticDns)
|
||||
slices.Sort(staticDnsPowershell)
|
||||
if !slices.Equal(staticDns, staticDnsPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
|
||||
}
|
||||
}
|
||||
|
||||
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
376
cmd/cli/prog.go
376
cmd/cli/prog.go
@@ -45,6 +45,18 @@ const (
|
||||
upstreamOS = upstreamPrefix + "os"
|
||||
upstreamPrivate = upstreamPrefix + "private"
|
||||
dnsWatchdogDefaultInterval = 20 * time.Second
|
||||
ctrldServiceName = "ctrld"
|
||||
)
|
||||
|
||||
// RecoveryReason provides context for why we are waiting for recovery.
|
||||
// recovery involves removing the listener IP from the interface and
|
||||
// waiting for the upstreams to work before returning
|
||||
type RecoveryReason int
|
||||
|
||||
const (
|
||||
RecoveryReasonNetworkChange RecoveryReason = iota
|
||||
RecoveryReasonRegularFailure
|
||||
RecoveryReasonOSFailure
|
||||
)
|
||||
|
||||
// ControlSocketName returns name for control unix socket.
|
||||
@@ -61,8 +73,9 @@ var logf = func(format string, args ...any) {
|
||||
}
|
||||
|
||||
var svcConfig = &service.Config{
|
||||
Name: "ctrld",
|
||||
Name: ctrldServiceName,
|
||||
DisplayName: "Control-D Helper Service",
|
||||
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
|
||||
Option: service.KeyValue{},
|
||||
}
|
||||
|
||||
@@ -84,21 +97,29 @@ type prog struct {
|
||||
dnsWg sync.WaitGroup
|
||||
dnsWatcherClosedOnce sync.Once
|
||||
dnsWatcherStopCh chan struct{}
|
||||
rc *controld.ResolverConfig
|
||||
|
||||
cfg *ctrld.Config
|
||||
localUpstreams []string
|
||||
ptrNameservers []string
|
||||
appCallback *AppCallback
|
||||
cache dnscache.Cacher
|
||||
cacheFlushDomainsMap map[string]struct{}
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
um *upstreamMonitor
|
||||
router router.Router
|
||||
ptrLoopGuard *loopGuard
|
||||
lanLoopGuard *loopGuard
|
||||
metricsQueryStats atomic.Bool
|
||||
queryFromSelfMap sync.Map
|
||||
cfg *ctrld.Config
|
||||
localUpstreams []string
|
||||
ptrNameservers []string
|
||||
appCallback *AppCallback
|
||||
cache dnscache.Cacher
|
||||
cacheFlushDomainsMap map[string]struct{}
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
um *upstreamMonitor
|
||||
router router.Router
|
||||
ptrLoopGuard *loopGuard
|
||||
lanLoopGuard *loopGuard
|
||||
metricsQueryStats atomic.Bool
|
||||
queryFromSelfMap sync.Map
|
||||
initInternalLogWriterOnce sync.Once
|
||||
internalLogWriter *logWriter
|
||||
internalWarnLogWriter *logWriter
|
||||
internalLogSent time.Time
|
||||
runningIface string
|
||||
requiredMultiNICsConfig bool
|
||||
adDomain string
|
||||
|
||||
selfUninstallMu sync.Mutex
|
||||
refusedQueryCount int
|
||||
@@ -108,9 +129,9 @@ type prog struct {
|
||||
loopMu sync.Mutex
|
||||
loop map[string]bool
|
||||
|
||||
leakingQueryMu sync.Mutex
|
||||
leakingQueryWasRun bool
|
||||
leakingQuery atomic.Bool
|
||||
recoveryCancelMu sync.Mutex
|
||||
recoveryCancel context.CancelFunc
|
||||
recoveryRunning atomic.Bool
|
||||
|
||||
started chan struct{}
|
||||
onStartedDone chan struct{}
|
||||
@@ -162,11 +183,13 @@ func (p *prog) runWait() {
|
||||
|
||||
if newCfg == nil {
|
||||
newCfg = &ctrld.Config{}
|
||||
confFile := v.ConfigFileUsed()
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
confFile = configPath
|
||||
}
|
||||
v.SetConfigFile(confFile)
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
logger.Err(err).Msg("could not read new config")
|
||||
waitOldRunDone()
|
||||
@@ -178,10 +201,14 @@ func (p *prog) runWait() {
|
||||
continue
|
||||
}
|
||||
if cdUID != "" {
|
||||
if err := processCDFlags(newCfg); err != nil {
|
||||
if rc, err := processCDFlags(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not fetch ControlD config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
} else {
|
||||
p.mu.Lock()
|
||||
p.rc = rc
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,6 +260,11 @@ func (p *prog) runWait() {
|
||||
}
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if iface == "auto" {
|
||||
iface = defaultIfaceName()
|
||||
p.requiredMultiNICsConfig = requiredMultiNICsConfig()
|
||||
}
|
||||
p.runningIface = iface
|
||||
if runtime.GOOS == "darwin" {
|
||||
p.onStopped = append(p.onStopped, func() {
|
||||
if !service.Interactive() {
|
||||
@@ -245,11 +277,12 @@ func (p *prog) preRun() {
|
||||
func (p *prog) postRun() {
|
||||
if !service.Interactive() {
|
||||
p.resetDNS()
|
||||
ns := ctrld.InitializeOsResolver()
|
||||
ns := ctrld.InitializeOsResolver(false)
|
||||
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
|
||||
p.setDNS()
|
||||
p.csSetDnsDone <- struct{}{}
|
||||
close(p.csSetDnsDone)
|
||||
p.logInterfacesState()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,7 +321,24 @@ func (p *prog) apiConfigReload() {
|
||||
cdDeactivationPin.Store(defaultDeactivationPin)
|
||||
}
|
||||
|
||||
if resolverConfig.Ctrld.CustomConfig == "" {
|
||||
p.mu.Lock()
|
||||
rc := p.rc
|
||||
p.rc = resolverConfig
|
||||
p.mu.Unlock()
|
||||
noCustomConfig := resolverConfig.Ctrld.CustomConfig == ""
|
||||
noExcludeListChanged := true
|
||||
if rc != nil {
|
||||
slices.Sort(rc.Exclude)
|
||||
slices.Sort(resolverConfig.Exclude)
|
||||
noExcludeListChanged = slices.Equal(rc.Exclude, resolverConfig.Exclude)
|
||||
}
|
||||
if noCustomConfig && noExcludeListChanged {
|
||||
return
|
||||
}
|
||||
|
||||
if noCustomConfig && !noExcludeListChanged {
|
||||
logger.Debug().Msg("exclude list changes detected, reloading...")
|
||||
p.apiReloadCh <- nil
|
||||
return
|
||||
}
|
||||
|
||||
@@ -401,6 +451,10 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if domain, err := getActiveDirectoryDomain(); err == nil && domain != "" && hasLocalDnsServerRunning() {
|
||||
mainLog.Load().Debug().Msgf("active directory domain: %s", domain)
|
||||
p.adDomain = domain
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(p.cfg.Listener))
|
||||
@@ -429,12 +483,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
}
|
||||
}
|
||||
p.setupUpstream(p.cfg)
|
||||
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers)
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
p.setupClientInfoDiscover(defaultRouteIP())
|
||||
}
|
||||
|
||||
// context for managing spawn goroutines.
|
||||
@@ -446,8 +495,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
p.ciTable.Init()
|
||||
p.ciTable.RefreshLoop(ctx)
|
||||
p.runClientInfoDiscover(ctx)
|
||||
}()
|
||||
go p.watchLinkState(ctx)
|
||||
}
|
||||
@@ -463,9 +511,10 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
}
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
if err := p.serveDNS(listenerNum); err != nil {
|
||||
if err := p.serveDNS(ctx, listenerNum); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("end of serveDNS listener.%s: %s", listenerNum, addr)
|
||||
}(listenerNum)
|
||||
}
|
||||
go func() {
|
||||
@@ -511,16 +560,33 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
if !reload {
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
initLoggingWithBackup(false)
|
||||
logWriters := initLoggingWithBackup(false)
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
go p.apiConfigReload()
|
||||
p.postRun()
|
||||
p.initInternalLogging(logWriters)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// setupClientInfoDiscover performs necessary works for running client info discover.
|
||||
func (p *prog) setupClientInfoDiscover(selfIP string) {
|
||||
p.ciTable = clientinfo.NewTable(&cfg, selfIP, cdUID, p.ptrNameservers)
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
}
|
||||
|
||||
// runClientInfoDiscover runs the client info discover.
|
||||
func (p *prog) runClientInfoDiscover(ctx context.Context) {
|
||||
p.ciTable.Init()
|
||||
p.ciTable.RefreshLoop(ctx)
|
||||
}
|
||||
|
||||
// metricsEnabled reports whether prometheus exporter is enabled/disabled.
|
||||
func (p *prog) metricsEnabled() bool {
|
||||
return p.cfg.Service.MetricsQueryStats || p.cfg.Service.MetricsListener != ""
|
||||
@@ -529,7 +595,9 @@ func (p *prog) metricsEnabled() bool {
|
||||
func (p *prog) Stop(s service.Service) error {
|
||||
p.stopDnsWatchers()
|
||||
mainLog.Load().Debug().Msg("dns watchers stopped")
|
||||
mainLog.Load().Info().Msg("Service stopped")
|
||||
defer func() {
|
||||
mainLog.Load().Info().Msg("Service stopped")
|
||||
}()
|
||||
close(p.stopCh)
|
||||
if err := p.deAllocateIP(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("de-allocate ip failed")
|
||||
@@ -579,27 +647,42 @@ func (p *prog) setDNS() {
|
||||
if cfg.Listener == nil {
|
||||
return
|
||||
}
|
||||
if iface == "" {
|
||||
if p.runningIface == "" {
|
||||
return
|
||||
}
|
||||
runningIface := iface
|
||||
|
||||
// allIfaces tracks whether we should set DNS for all physical interfaces.
|
||||
allIfaces := false
|
||||
if runningIface == "auto" {
|
||||
runningIface = defaultIfaceName()
|
||||
// If runningIface is "auto", it means user does not specify "--iface" flag.
|
||||
// In this case, ctrld has to set DNS for all physical interfaces, so
|
||||
// thing will still work when user switch from one to the other.
|
||||
allIfaces = requiredMultiNICsConfig()
|
||||
}
|
||||
allIfaces := p.requiredMultiNICsConfig
|
||||
lc := cfg.FirstListener()
|
||||
if lc == nil {
|
||||
return
|
||||
}
|
||||
logger := mainLog.Load().With().Str("iface", runningIface).Logger()
|
||||
netIface, err := netInterface(runningIface)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("could not get interface")
|
||||
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
|
||||
const maxDNSRetryAttempts = 3
|
||||
const retryDelay = 1 * time.Second
|
||||
var netIface *net.Interface
|
||||
var err error
|
||||
for attempt := 1; attempt <= maxDNSRetryAttempts; attempt++ {
|
||||
netIface, err = netInterface(p.runningIface)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if attempt < maxDNSRetryAttempts {
|
||||
// Try to find a different working interface
|
||||
newIface := findWorkingInterface(p.runningIface)
|
||||
if newIface != p.runningIface {
|
||||
p.runningIface = newIface
|
||||
logger = mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
logger.Info().Msg("switched to new interface")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Warn().Err(err).Int("attempt", attempt).Msg("could not get interface, retrying...")
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
logger.Error().Err(err).Msg("could not get interface after all attempts")
|
||||
return
|
||||
}
|
||||
if err := setupNetworkManager(); err != nil {
|
||||
@@ -686,12 +769,13 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
|
||||
if !requiredMultiNICsConfig() {
|
||||
return
|
||||
}
|
||||
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
|
||||
logger.Debug().Msg("start DNS settings watchdog")
|
||||
|
||||
mainLog.Load().Debug().Msg("start DNS settings watchdog")
|
||||
ns := nameservers
|
||||
slices.Sort(ns)
|
||||
ticker := time.NewTicker(p.dnsWatchdogDuration())
|
||||
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
@@ -700,7 +784,7 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
|
||||
mainLog.Load().Debug().Msg("stop dns watchdog")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if p.leakingQuery.Load() {
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if dnsChanged(iface, ns) {
|
||||
@@ -726,22 +810,19 @@ func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces
|
||||
}
|
||||
|
||||
func (p *prog) resetDNS() {
|
||||
if iface == "" {
|
||||
if p.runningIface == "" {
|
||||
mainLog.Load().Debug().Msg("no running interface, skipping resetDNS")
|
||||
return
|
||||
}
|
||||
runningIface := iface
|
||||
allIfaces := false
|
||||
if runningIface == "auto" {
|
||||
runningIface = defaultIfaceName()
|
||||
// See corresponding comments in (*prog).setDNS function.
|
||||
allIfaces = requiredMultiNICsConfig()
|
||||
}
|
||||
logger := mainLog.Load().With().Str("iface", runningIface).Logger()
|
||||
netIface, err := netInterface(runningIface)
|
||||
// See corresponding comments in (*prog).setDNS function.
|
||||
allIfaces := p.requiredMultiNICsConfig
|
||||
logger := mainLog.Load().With().Str("iface", p.runningIface).Logger()
|
||||
netIface, err := netInterface(p.runningIface)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("could not get interface")
|
||||
return
|
||||
}
|
||||
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
@@ -757,16 +838,157 @@ func (p *prog) resetDNS() {
|
||||
}
|
||||
}
|
||||
|
||||
// leakOnUpstreamFailure reports whether ctrld should leak query to OS resolver when failed to connect all upstreams.
|
||||
func (p *prog) leakOnUpstreamFailure() bool {
|
||||
if ptr := p.cfg.Service.LeakOnUpstreamFailure; ptr != nil {
|
||||
return *ptr
|
||||
func (p *prog) logInterfacesState() {
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
addrs, err := i.Addrs()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get addresses")
|
||||
}
|
||||
nss, err := currentStaticDNS(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Str("interface", i.Name).Err(err).Msg("failed to get DNS")
|
||||
}
|
||||
if len(nss) == 0 {
|
||||
nss = currentDNS(i)
|
||||
}
|
||||
mainLog.Load().Debug().
|
||||
Any("addrs", addrs).
|
||||
Strs("nameservers", nss).
|
||||
Int("index", i.Index).
|
||||
Msgf("interface state: %s", i.Name)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// findWorkingInterface looks for a network interface with a valid IP configuration
|
||||
func findWorkingInterface(currentIface string) string {
|
||||
// Helper to check if IP is valid (not link-local)
|
||||
isValidIP := func(ip net.IP) bool {
|
||||
return ip != nil &&
|
||||
!ip.IsLinkLocalUnicast() &&
|
||||
!ip.IsLinkLocalMulticast() &&
|
||||
!ip.IsLoopback() &&
|
||||
!ip.IsUnspecified()
|
||||
}
|
||||
// Default is false on routers, since this leaking is only useful for devices that move between networks.
|
||||
if router.Name() != "" {
|
||||
|
||||
// Helper to check if interface has valid IP configuration
|
||||
hasValidIPConfig := func(iface *net.Interface) bool {
|
||||
if iface == nil || iface.Flags&net.FlagUp == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Str("interface", iface.Name).
|
||||
Err(err).
|
||||
Msg("failed to get interface addresses")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
// Check for IP network
|
||||
if ipNet, ok := addr.(*net.IPNet); ok {
|
||||
if isValidIP(ipNet.IP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
// Get default route interface
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Msg("failed to get default route")
|
||||
} else {
|
||||
mainLog.Load().Debug().
|
||||
Str("default_route_iface", defaultRoute.InterfaceName).
|
||||
Msg("found default route")
|
||||
}
|
||||
|
||||
// Get all interfaces
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to list network interfaces")
|
||||
return currentIface // Return current interface as fallback
|
||||
}
|
||||
|
||||
var firstWorkingIface string
|
||||
var currentIfaceValid bool
|
||||
|
||||
// Single pass through interfaces
|
||||
for _, iface := range ifaces {
|
||||
// Must be physical (has MAC address)
|
||||
if len(iface.HardwareAddr) == 0 {
|
||||
continue
|
||||
}
|
||||
// Skip interfaces that are:
|
||||
// - Loopback
|
||||
// - Not up
|
||||
// - Point-to-point (like VPN tunnels)
|
||||
if iface.Flags&net.FlagLoopback != 0 ||
|
||||
iface.Flags&net.FlagUp == 0 ||
|
||||
iface.Flags&net.FlagPointToPoint != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !hasValidIPConfig(&iface) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Found working physical interface
|
||||
if err == nil && defaultRoute.InterfaceName == iface.Name {
|
||||
// Found interface with default route - use it immediately
|
||||
mainLog.Load().Info().
|
||||
Str("old_iface", currentIface).
|
||||
Str("new_iface", iface.Name).
|
||||
Msg("switching to interface with default route")
|
||||
return iface.Name
|
||||
}
|
||||
|
||||
// Keep track of first working interface as fallback
|
||||
if firstWorkingIface == "" {
|
||||
firstWorkingIface = iface.Name
|
||||
}
|
||||
|
||||
// Check if this is our current interface
|
||||
if iface.Name == currentIface {
|
||||
currentIfaceValid = true
|
||||
}
|
||||
}
|
||||
|
||||
// Return interfaces in order of preference:
|
||||
// 1. Current interface if it's still valid
|
||||
if currentIfaceValid {
|
||||
mainLog.Load().Debug().
|
||||
Str("interface", currentIface).
|
||||
Msg("keeping current interface")
|
||||
return currentIface
|
||||
}
|
||||
|
||||
// 2. First working interface found
|
||||
if firstWorkingIface != "" {
|
||||
mainLog.Load().Info().
|
||||
Str("old_iface", currentIface).
|
||||
Str("new_iface", firstWorkingIface).
|
||||
Msg("switching to first working physical interface")
|
||||
return firstWorkingIface
|
||||
}
|
||||
|
||||
// 3. Fall back to current interface if nothing else works
|
||||
mainLog.Load().Warn().
|
||||
Str("current_iface", currentIface).
|
||||
Msg("no working physical interface found, keeping current")
|
||||
return currentIface
|
||||
}
|
||||
|
||||
// recoverOnUpstreamFailure reports whether ctrld should recover from upstream failure.
|
||||
func (p *prog) recoverOnUpstreamFailure() bool {
|
||||
// Default is false on routers, since this recovery flow is only useful for devices that move between networks.
|
||||
return router.Name() == ""
|
||||
}
|
||||
|
||||
func randomLocalIP() string {
|
||||
@@ -947,7 +1169,7 @@ func canBeLocalUpstream(addr string) bool {
|
||||
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
|
||||
validIfacesMap := validInterfacesMap()
|
||||
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
||||
// Skip loopback/virtual interface.
|
||||
// Skip loopback/virtual/down interface.
|
||||
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -956,9 +1178,12 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.
|
||||
return
|
||||
}
|
||||
netIface := i.Interface
|
||||
if err := patchNetIfaceName(netIface); err != nil {
|
||||
if patched, err := patchNetIfaceName(netIface); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to patch net interface name")
|
||||
return
|
||||
} else if !patched {
|
||||
// The interface is not functional, skipping.
|
||||
return
|
||||
}
|
||||
// Skip excluded interface.
|
||||
if netIface.Name == excludeIfaceName {
|
||||
@@ -1025,7 +1250,16 @@ func savedStaticDnsSettingsFilePath(iface *net.Interface) string {
|
||||
func savedStaticNameservers(iface *net.Interface) []string {
|
||||
file := savedStaticDnsSettingsFilePath(iface)
|
||||
if data, _ := os.ReadFile(file); len(data) > 0 {
|
||||
return strings.Split(string(data), ",")
|
||||
saveValues := strings.Split(string(data), ",")
|
||||
returnValues := []string{}
|
||||
// check each one, if its in loopback range, remove it
|
||||
for _, v := range saveValues {
|
||||
if net.ParseIP(v).IsLoopback() {
|
||||
continue
|
||||
}
|
||||
returnValues = append(returnValues, v)
|
||||
}
|
||||
return returnValues
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1044,7 +1278,7 @@ func dnsChanged(iface *net.Interface, nameservers []string) bool {
|
||||
|
||||
// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then.
|
||||
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
|
||||
var uer *controld.UtilityErrorResponse
|
||||
var uer *controld.ErrorResponse
|
||||
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
p.stopDnsWatchers()
|
||||
|
||||
|
||||
@@ -3,11 +3,38 @@ package cli
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
@@ -40,7 +67,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.leakingQuery.Load() {
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
@@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
if retry < maxRetries-1 {
|
||||
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
|
||||
@@ -156,17 +156,18 @@ func (l *launchd) Status() (service.Status, error) {
|
||||
type task struct {
|
||||
f func() error
|
||||
abortOnError bool
|
||||
Name string
|
||||
}
|
||||
|
||||
func doTasks(tasks []task) bool {
|
||||
var prevErr error
|
||||
for _, task := range tasks {
|
||||
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error())
|
||||
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
|
||||
return false
|
||||
}
|
||||
prevErr = err
|
||||
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -13,3 +13,8 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool { return false }
|
||||
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
|
||||
@@ -2,9 +2,14 @@ package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
@@ -28,6 +33,67 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
return token.IsMember(sid)
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions checks if the given service
|
||||
// has the correct failure actions configured, and updates them if not.
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil // no-op on non-Windows
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// 1. Retrieve the current config
|
||||
cfg, err := s.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update the Description
|
||||
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||
|
||||
// 3. Apply the updated config
|
||||
if err := s.UpdateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then proceed with existing actions, e.g. setting failure actions
|
||||
actions := []mgr.RecoveryAction{
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
}
|
||||
|
||||
// Set the recovery actions (3 restarts, reset period = 120).
|
||||
err = s.SetRecoveryActions(actions, 120)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that failure actions are NOT triggered on user-initiated stops.
|
||||
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
|
||||
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
|
||||
|
||||
if err := windows.ChangeServiceConfig2(
|
||||
s.Handle,
|
||||
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
|
||||
(*byte)(unsafe.Pointer(&failureActionsFlag)),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openLogFile(path string, mode int) (*os.File, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||
@@ -79,3 +145,23 @@ func openLogFile(path string, mode int) (*os.File, error) {
|
||||
|
||||
return os.NewFile(uintptr(handle), path), nil
|
||||
}
|
||||
|
||||
const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool {
|
||||
h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
p := windows.ProcessEntry32{Size: processEntrySize}
|
||||
for {
|
||||
e := windows.Process32Next(h, &p)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
25
cmd/cli/service_windows_test.go
Normal file
25
cmd/cli/service_windows_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_hasLocalDnsServerRunning(t *testing.T) {
|
||||
start := time.Now()
|
||||
hasDns := hasLocalDnsServerRunning()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
hasDnsPowershell := hasLocalDnsServerRunningPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if hasDns != hasDnsPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns)
|
||||
}
|
||||
}
|
||||
|
||||
func hasLocalDnsServerRunningPowershell() bool {
|
||||
_, err := powershell("Get-Process -Name DNS")
|
||||
return err == nil
|
||||
}
|
||||
@@ -1,18 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 100
|
||||
maxFailureRequest = 50
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
@@ -21,18 +18,24 @@ const (
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
checking map[string]bool
|
||||
down map[string]bool
|
||||
failureReq map[string]uint64
|
||||
recovered map[string]bool
|
||||
|
||||
// failureTimerActive tracks if a timer is already running for a given upstream.
|
||||
failureTimerActive map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
recovered: make(map[string]bool),
|
||||
failureTimerActive: make(map[string]bool),
|
||||
}
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
@@ -42,14 +45,47 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increase failed queries count for an upstream by 1.
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down
|
||||
// within 10 seconds if failures persist, without spawning duplicate goroutines.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
if um.recovered[upstream] {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream)
|
||||
return
|
||||
}
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
um.down[upstream] = failedCount >= maxFailureRequest
|
||||
|
||||
// Log the updated failure count.
|
||||
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// If this is the first failure and no timer is running, start a 10-second timer.
|
||||
if failedCount == 1 && !um.failureTimerActive[upstream] {
|
||||
um.failureTimerActive[upstream] = true
|
||||
go func(upstream string) {
|
||||
time.Sleep(10 * time.Second)
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
// If no success occurred during the 10-second window (i.e. counter remains > 0)
|
||||
// and the upstream is not in a recovered state, mark it as down.
|
||||
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
}
|
||||
// Reset the timer flag so that a new timer can be spawned if needed.
|
||||
um.failureTimerActive[upstream] = false
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
@@ -63,56 +99,28 @@ func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
um.failureReq[upstream] = 0
|
||||
um.down[upstream] = false
|
||||
}
|
||||
|
||||
// checkUpstream checks the given upstream status, periodically sending query to upstream
|
||||
// until successfully. An upstream status/counter will be reset once it becomes reachable.
|
||||
func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
|
||||
p.um.mu.Lock()
|
||||
isChecking := p.um.checking[upstream]
|
||||
if isChecking {
|
||||
p.um.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.um.checking[upstream] = true
|
||||
p.um.mu.Unlock()
|
||||
defer func() {
|
||||
p.um.mu.Lock()
|
||||
p.um.checking[upstream] = false
|
||||
p.um.mu.Unlock()
|
||||
um.recovered[upstream] = true
|
||||
um.mu.Unlock()
|
||||
go func() {
|
||||
// debounce the recovery to avoid incrementing failure counts already in flight
|
||||
time.Sleep(1 * time.Second)
|
||||
um.mu.Lock()
|
||||
um.recovered[upstream] = false
|
||||
um.mu.Unlock()
|
||||
}()
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
|
||||
return
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
|
||||
check := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
uc.ReBootstrap()
|
||||
_, err := resolver.Resolve(ctx, msg)
|
||||
return err
|
||||
}
|
||||
for {
|
||||
if err := check(); err == nil {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
|
||||
p.um.reset(upstream)
|
||||
if p.leakingQuery.CompareAndSwap(true, false) {
|
||||
p.leakingQueryMu.Lock()
|
||||
p.leakingQueryWasRun = false
|
||||
p.leakingQueryMu.Unlock()
|
||||
mainLog.Load().Warn().Msg("stop leaking query")
|
||||
}
|
||||
return
|
||||
}
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
}
|
||||
}
|
||||
|
||||
// countHealthy returns the number of upstreams in the provided map that are considered healthy.
|
||||
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
|
||||
var count int
|
||||
um.mu.RLock()
|
||||
for _, upstream := range upstreams {
|
||||
if !um.down[upstream] {
|
||||
count++
|
||||
}
|
||||
}
|
||||
um.mu.RUnlock()
|
||||
return count
|
||||
}
|
||||
|
||||
15
config.go
15
config.go
@@ -205,7 +205,7 @@ type ServiceConfig struct {
|
||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp kea-dhcp4"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
@@ -384,7 +384,7 @@ func (uc *UpstreamConfig) IsDiscoverable() bool {
|
||||
return *uc.Discoverable
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate:
|
||||
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate, ResolverTypeLocal:
|
||||
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
@@ -458,7 +458,7 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
@@ -886,3 +886,12 @@ func upstreamUID() string {
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the UpstreamConfig for logging.
|
||||
func (uc *UpstreamConfig) String() string {
|
||||
if uc == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}",
|
||||
uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack)
|
||||
}
|
||||
|
||||
@@ -2,12 +2,16 @@ package ctrld
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
|
||||
l := zerolog.New(os.Stdout)
|
||||
ProxyLogger.Store(&l)
|
||||
uc := &UpstreamConfig{
|
||||
Name: "test",
|
||||
Type: ResolverTypeDOH,
|
||||
|
||||
@@ -34,7 +34,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
rt := &http3.RoundTripper{}
|
||||
rt := &http3.Transport{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
@@ -64,7 +64,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) {
|
||||
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
|
||||
rt.CloseIdleConnections()
|
||||
})
|
||||
return rt
|
||||
|
||||
@@ -111,6 +111,7 @@ func TestConfigValidation(t *testing.T) {
|
||||
{"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false},
|
||||
{"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false},
|
||||
{"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true},
|
||||
{"kea dhcp4 format", configWithDhcp4KeaFormat(t), false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -307,6 +308,12 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithDhcp4KeaFormat(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.DHCPLeaseFileFormat = "kea-dhcp4"
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "/1.1.1.1"
|
||||
|
||||
23
go.mod
23
go.mod
@@ -9,6 +9,7 @@ require (
|
||||
github.com/ameshkov/dnsstamps v1.0.3
|
||||
github.com/coreos/go-systemd/v22 v22.5.0
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf
|
||||
github.com/docker/go-units v0.5.0
|
||||
github.com/frankban/quicktest v1.14.6
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/go-playground/validator/v10 v10.11.1
|
||||
@@ -20,6 +21,7 @@ require (
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
|
||||
github.com/kardianos/service v1.2.1
|
||||
github.com/mdlayher/ndp v1.0.1
|
||||
github.com/microsoft/wmi v0.24.5
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/minio/selfupdate v0.6.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
@@ -27,16 +29,16 @@ require (
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/prometheus/client_model v0.5.0
|
||||
github.com/prometheus/prom2json v1.3.3
|
||||
github.com/quic-go/quic-go v0.42.0
|
||||
github.com/quic-go/quic-go v0.48.2
|
||||
github.com/rs/zerolog v1.28.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.16.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/net v0.27.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/sys v0.22.0
|
||||
golang.org/x/net v0.33.0
|
||||
golang.org/x/sync v0.10.0
|
||||
golang.org/x/sys v0.29.0
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
tailscale.com v1.74.0
|
||||
)
|
||||
@@ -49,12 +51,14 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-playground/locales v0.14.0 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
|
||||
@@ -72,10 +76,11 @@ require (
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/prometheus/common v0.48.0 // indirect
|
||||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.4 // indirect
|
||||
github.com/rogpeppe/go-internal v1.11.0 // indirect
|
||||
github.com/spf13/afero v1.9.5 // indirect
|
||||
@@ -87,10 +92,10 @@ require (
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||
golang.org/x/crypto v0.25.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect
|
||||
golang.org/x/crypto v0.31.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
golang.org/x/mod v0.19.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/tools v0.23.0 // indirect
|
||||
google.golang.org/protobuf v1.33.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
@@ -99,4 +104,4 @@ require (
|
||||
|
||||
replace github.com/mr-karan/doggo => github.com/Windscribe/doggo v0.0.0-20220919152748-2c118fc391f8
|
||||
|
||||
replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be
|
||||
replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c
|
||||
|
||||
54
go.sum
54
go.sum
@@ -42,8 +42,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
|
||||
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
|
||||
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik=
|
||||
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w=
|
||||
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
|
||||
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
|
||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
|
||||
@@ -74,6 +74,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa h1:h8TfIT1xc8FWbwwpmHn1J5i43Y0uZP97GqasGCzSRJk=
|
||||
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
@@ -91,6 +93,8 @@ github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f
|
||||
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
||||
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
|
||||
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
|
||||
@@ -162,6 +166,8 @@ github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlG
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||
@@ -207,11 +213,10 @@ github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
|
||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
@@ -227,6 +232,8 @@ github.com/mdlayher/packet v1.1.2 h1:3Up1NG6LZrsgDVn6X4L9Ge/iyRyxFEFD9o6Pr3Q1nQY
|
||||
github.com/mdlayher/packet v1.1.2/go.mod h1:GEu1+n9sG5VtiRE4SydOmX5GTwyyYlteZiFU+x0kew4=
|
||||
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
|
||||
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
|
||||
github.com/microsoft/wmi v0.24.5 h1:NT+WqhjKbEcg3ldmDsRMarWgHGkpeW+gMopSCfON0kM=
|
||||
github.com/microsoft/wmi v0.24.5/go.mod h1:1zbdSF0A+5OwTUII5p3hN7/K6KF2m3o27pSG6Y51VU8=
|
||||
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
|
||||
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
|
||||
github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU=
|
||||
@@ -245,6 +252,7 @@ github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFu
|
||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -261,10 +269,10 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k
|
||||
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
|
||||
github.com/prometheus/prom2json v1.3.3 h1:IYfSMiZ7sSOfliBoo89PcufjWO4eAR0gznGcETyaUgo=
|
||||
github.com/prometheus/prom2json v1.3.3/go.mod h1:Pv4yIPktEkK7btWsrUTWDDDrnpUrAELaOCj+oFwlgmc=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
|
||||
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
|
||||
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
|
||||
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
@@ -274,7 +282,7 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM=
|
||||
github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ=
|
||||
@@ -338,8 +346,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm
|
||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
|
||||
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -350,8 +358,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
|
||||
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
|
||||
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
|
||||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA=
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
@@ -409,8 +417,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
|
||||
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -430,8 +438,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -472,16 +480,16 @@ golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
|
||||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -492,8 +500,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
||||
@@ -77,6 +77,7 @@ type Table struct {
|
||||
hostnameResolvers []HostnameResolver
|
||||
refreshers []refresher
|
||||
initOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
refreshInterval int
|
||||
|
||||
dhcp *dhcp
|
||||
@@ -90,7 +91,9 @@ type Table struct {
|
||||
vni *virtualNetworkIface
|
||||
svcCfg ctrld.ServiceConfig
|
||||
quitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
selfIP string
|
||||
selfIPLock sync.RWMutex
|
||||
cdUID string
|
||||
ptrNameservers []string
|
||||
}
|
||||
@@ -103,6 +106,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
|
||||
return &Table{
|
||||
svcCfg: cfg.Service,
|
||||
quitCh: make(chan struct{}),
|
||||
stopCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
ptrNameservers: ns,
|
||||
@@ -120,24 +124,59 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) {
|
||||
// RefreshLoop runs all the refresher to update new client info data.
|
||||
func (t *Table) RefreshLoop(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Second * time.Duration(t.refreshInterval))
|
||||
defer timer.Stop()
|
||||
defer func() {
|
||||
timer.Stop()
|
||||
close(t.quitCh)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
t.Refresh()
|
||||
case <-t.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
close(t.quitCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes all client info discovers.
|
||||
func (t *Table) Init() {
|
||||
t.initOnce.Do(t.init)
|
||||
}
|
||||
|
||||
// Refresh forces all discovers to retrieve new data.
|
||||
func (t *Table) Refresh() {
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops all the discovers.
|
||||
// It blocks until all the discovers done.
|
||||
func (t *Table) Stop() {
|
||||
t.stopOnce.Do(func() {
|
||||
close(t.stopCh)
|
||||
})
|
||||
<-t.quitCh
|
||||
}
|
||||
|
||||
// SelfIP returns the selfIP value of the Table in a thread-safe manner.
|
||||
func (t *Table) SelfIP() string {
|
||||
t.selfIPLock.RLock()
|
||||
defer t.selfIPLock.RUnlock()
|
||||
return t.selfIP
|
||||
}
|
||||
|
||||
// SetSelfIP sets the selfIP value of the Table in a thread-safe manner.
|
||||
func (t *Table) SetSelfIP(ip string) {
|
||||
t.selfIPLock.Lock()
|
||||
defer t.selfIPLock.Unlock()
|
||||
t.selfIP = ip
|
||||
t.dhcp.selfIP = t.selfIP
|
||||
t.dhcp.addSelf()
|
||||
}
|
||||
|
||||
func (t *Table) init() {
|
||||
// Custom client ID presents, use it as the only source.
|
||||
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
|
||||
@@ -381,9 +420,7 @@ func (t *Table) lookupHostnameAll(ip, mac string) []*hostnameEntry {
|
||||
|
||||
// ListClients returns list of clients discovered by ctrld.
|
||||
func (t *Table) ListClients() []*Client {
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
t.Refresh()
|
||||
ipMap := make(map[string]*Client)
|
||||
il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni}
|
||||
for _, ir := range il {
|
||||
|
||||
@@ -25,8 +25,12 @@ import (
|
||||
const (
|
||||
apiDomainCom = "api.controld.com"
|
||||
apiDomainDev = "api.controld.dev"
|
||||
resolverDataURLCom = "https://api.controld.com/utility"
|
||||
resolverDataURLDev = "https://api.controld.dev/utility"
|
||||
apiURLCom = "https://api.controld.com"
|
||||
apiURLDev = "https://api.controld.dev"
|
||||
resolverDataURLCom = apiURLCom + "/utility"
|
||||
resolverDataURLDev = apiURLDev + "/utility"
|
||||
logURLCom = apiURLCom + "/logs"
|
||||
logURLDev = apiURLDev + "/logs"
|
||||
InvalidConfigCode = 40402
|
||||
)
|
||||
|
||||
@@ -49,14 +53,14 @@ type utilityResponse struct {
|
||||
} `json:"body"`
|
||||
}
|
||||
|
||||
type UtilityErrorResponse struct {
|
||||
type ErrorResponse struct {
|
||||
ErrorField struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func (u UtilityErrorResponse) Error() string {
|
||||
func (u ErrorResponse) Error() string {
|
||||
return u.ErrorField.Message
|
||||
}
|
||||
|
||||
@@ -71,6 +75,12 @@ type UtilityOrgRequest struct {
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
// LogsRequest contains request data for sending runtime logs to API.
|
||||
type LogsRequest struct {
|
||||
UID string `json:"uid"`
|
||||
Data io.ReadCloser `json:"-"`
|
||||
}
|
||||
|
||||
// FetchResolverConfig fetch Control D config for given uid.
|
||||
func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) {
|
||||
uid, clientID := ParseRawUID(rawUID)
|
||||
@@ -123,6 +133,81 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
transport := apiTransport(cdDev)
|
||||
client := http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
d := json.NewDecoder(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errResp := &ErrorResponse{}
|
||||
if err := d.Decode(errResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, errResp
|
||||
}
|
||||
|
||||
ur := &utilityResponse{}
|
||||
if err := d.Decode(ur); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ur.Body.Resolver, nil
|
||||
}
|
||||
|
||||
// SendLogs sends runtime log to ControlD API.
|
||||
func SendLogs(lr *LogsRequest, cdDev bool) error {
|
||||
defer lr.Data.Close()
|
||||
apiUrl := logURLCom
|
||||
if cdDev {
|
||||
apiUrl = logURLDev
|
||||
}
|
||||
req, err := http.NewRequest("POST", apiUrl, lr.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http.NewRequest: %w", err)
|
||||
}
|
||||
q := req.URL.Query()
|
||||
q.Set("uid", lr.UID)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
transport := apiTransport(cdDev)
|
||||
client := http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SendLogs client.Do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
d := json.NewDecoder(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errResp := &ErrorResponse{}
|
||||
if err := d.Decode(errResp); err != nil {
|
||||
return err
|
||||
}
|
||||
return errResp
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseRawUID parse the input raw UID, returning real UID and ClientID.
|
||||
// The raw UID can have 2 forms:
|
||||
//
|
||||
// - <uid>
|
||||
// - <uid>/<client_id>
|
||||
func ParseRawUID(rawUID string) (string, string) {
|
||||
uid, clientID, _ := strings.Cut(rawUID, "/")
|
||||
return uid, clientID
|
||||
}
|
||||
|
||||
// apiTransport returns an HTTP transport for connecting to ControlD API endpoint.
|
||||
func apiTransport(cdDev bool) *http.Transport {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
apiDomain := apiDomainCom
|
||||
@@ -143,41 +228,8 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
return d.DialContext(ctx, network, addrs)
|
||||
}
|
||||
|
||||
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("client.Do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
d := json.NewDecoder(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errResp := &UtilityErrorResponse{}
|
||||
if err := d.Decode(errResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, errResp
|
||||
}
|
||||
|
||||
ur := &utilityResponse{}
|
||||
if err := d.Decode(ur); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ur.Body.Resolver, nil
|
||||
}
|
||||
|
||||
// ParseRawUID parse the input raw UID, returning real UID and ClientID.
|
||||
// The raw UID can have 2 forms:
|
||||
//
|
||||
// - <uid>
|
||||
// - <uid>/<client_id>
|
||||
func ParseRawUID(rawUID string) (string, string) {
|
||||
uid, clientID, _ := strings.Cut(rawUID, "/")
|
||||
return uid, clientID
|
||||
return transport
|
||||
}
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromRIB, dnsFromIPConfig}
|
||||
return []dnsFn{dnsFromRIB}
|
||||
}
|
||||
|
||||
func dnsFromRIB() []string {
|
||||
@@ -49,18 +46,6 @@ func dnsFromRIB() []string {
|
||||
return dns
|
||||
}
|
||||
|
||||
func dnsFromIPConfig() []string {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server")
|
||||
out, _ := cmd.Output()
|
||||
if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil {
|
||||
return []string{ip.String()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toNetIP(addr route.Addr) net.IP {
|
||||
switch t := addr.(type) {
|
||||
case *route.Inet4Addr:
|
||||
|
||||
217
nameservers_darwin.go
Normal file
217
nameservers_darwin.go
Normal file
@@ -0,0 +1,217 @@
|
||||
//go:build darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers}
|
||||
}
|
||||
|
||||
// dnsFromResolvConf reads nameservers from /etc/resolv.conf
|
||||
func dnsFromResolvConf() []string {
|
||||
const (
|
||||
maxRetries = 10
|
||||
retryInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
|
||||
var dns []string
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(retryInterval)
|
||||
}
|
||||
|
||||
nss := resolvconffile.NameServers("")
|
||||
var localDNS []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, ns := range nss {
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback IPs
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
ipStr := v.String()
|
||||
if ip.String() == ipStr {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !seen[ip.String()] {
|
||||
seen[ip.String()] = true
|
||||
localDNS = append(localDNS, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we successfully read the file and found nameservers, return them
|
||||
if len(localDNS) > 0 {
|
||||
return localDNS
|
||||
}
|
||||
}
|
||||
|
||||
return dns
|
||||
}
|
||||
|
||||
func getDNSFromScutil() []string {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
const (
|
||||
maxRetries = 10
|
||||
retryInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
|
||||
var nameservers []string
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(retryInterval)
|
||||
}
|
||||
|
||||
cmd := exec.Command("scutil", "--dns")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Error(), "failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var localDNS []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "nameserver[") {
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) == 2 {
|
||||
ns := strings.TrimSpace(parts[1])
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback IPs
|
||||
isLocal := false
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
ipStr := v.String()
|
||||
if ip.String() == ipStr {
|
||||
isLocal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isLocal && !seen[ip.String()] {
|
||||
seen[ip.String()] = true
|
||||
localDNS = append(localDNS, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
Log(context.Background(), logger.Error(), "error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// If we successfully read the output and found nameservers, return them
|
||||
if len(localDNS) > 0 {
|
||||
return localDNS
|
||||
}
|
||||
}
|
||||
|
||||
return nameservers
|
||||
}
|
||||
|
||||
func getDHCPNameservers(iface string) ([]string, error) {
|
||||
// Run the ipconfig command for the given interface.
|
||||
cmd := exec.Command("ipconfig", "getpacket", iface)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error running ipconfig: %v", err)
|
||||
}
|
||||
|
||||
// Look for a line like:
|
||||
// domain_name_servers = 192.168.1.1 8.8.8.8;
|
||||
re := regexp.MustCompile(`domain_name_servers\s*=\s*(.*);`)
|
||||
matches := re.FindStringSubmatch(string(output))
|
||||
if len(matches) < 2 {
|
||||
return nil, fmt.Errorf("no DHCP nameservers found")
|
||||
}
|
||||
|
||||
// Split the nameservers by whitespace.
|
||||
nameservers := strings.Fields(matches[1])
|
||||
return nameservers, nil
|
||||
}
|
||||
|
||||
func getAllDHCPNameservers() []string {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
|
||||
var allNameservers []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, iface := range interfaces {
|
||||
// Skip interfaces that are:
|
||||
// - down
|
||||
// - loopback
|
||||
// - not physical (virtual)
|
||||
// - point-to-point (like VPN interfaces)
|
||||
// - without MAC address (non-physical)
|
||||
if iface.Flags&net.FlagUp == 0 ||
|
||||
iface.Flags&net.FlagLoopback != 0 ||
|
||||
iface.Flags&net.FlagPointToPoint != 0 ||
|
||||
(iface.Flags&net.FlagBroadcast == 0 &&
|
||||
iface.Flags&net.FlagMulticast == 0) ||
|
||||
len(iface.HardwareAddr) == 0 ||
|
||||
strings.HasPrefix(iface.Name, "utun") ||
|
||||
strings.HasPrefix(iface.Name, "llw") ||
|
||||
strings.HasPrefix(iface.Name, "awdl") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify it's a valid MAC address (should be 6 bytes for IEEE 802 MAC-48)
|
||||
if len(iface.HardwareAddr) != 6 {
|
||||
continue
|
||||
}
|
||||
|
||||
nameservers, err := getDHCPNameservers(iface.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add unique nameservers to the result, skipping local IPs
|
||||
for _, ns := range nameservers {
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback and local IPs
|
||||
isLocal := false
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
if ip.String() == v.String() {
|
||||
isLocal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isLocal && !seen[ns] {
|
||||
seen[ns] = true
|
||||
allNameservers = append(allNameservers, ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allNameservers
|
||||
}
|
||||
@@ -1,44 +1,473 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
const (
|
||||
maxDNSAdapterRetries = 5
|
||||
retryDelayDNSAdapter = 1 * time.Second
|
||||
defaultDNSAdapterTimeout = 10 * time.Second
|
||||
minDNSServers = 1 // Minimum number of DNS servers we want to find
|
||||
NetSetupUnknown uint32 = 0
|
||||
NetSetupWorkgroup uint32 = 1
|
||||
NetSetupDomain uint32 = 2
|
||||
NetSetupCloudDomain uint32 = 3
|
||||
DS_FORCE_REDISCOVERY = 0x00000001
|
||||
DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010
|
||||
DS_BACKGROUND_ONLY = 0x00000100
|
||||
DS_IP_REQUIRED = 0x00000200
|
||||
DS_IS_DNS_NAME = 0x00020000
|
||||
DS_RETURN_DNS_NAME = 0x40000000
|
||||
)
|
||||
|
||||
type DomainControllerInfo struct {
|
||||
DomainControllerName *uint16
|
||||
DomainControllerAddress *uint16
|
||||
DomainControllerAddressType uint32
|
||||
DomainGuid windows.GUID
|
||||
DomainName *uint16
|
||||
DnsForestName *uint16
|
||||
Flags uint32
|
||||
DcSiteName *uint16
|
||||
ClientSiteName *uint16
|
||||
}
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromAdapter}
|
||||
}
|
||||
|
||||
func dnsFromAdapter() []string {
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix)
|
||||
if err != nil {
|
||||
return nil
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
|
||||
defer cancel()
|
||||
|
||||
var ns []string
|
||||
var err error
|
||||
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
|
||||
for i := 0; i < maxDNSAdapterRetries; i++ {
|
||||
if ctx.Err() != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
||||
return nil
|
||||
}
|
||||
|
||||
ns, err = getDNSServers(ctx)
|
||||
if err == nil && len(ns) >= minDNSServers {
|
||||
if i > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Successfully got DNS servers after %d attempts, found %d servers",
|
||||
i+1, len(ns))
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// if osResolver is not initialized, this is likely a command line run
|
||||
// and ctrld is already on the interface, abort retries
|
||||
if or == nil {
|
||||
return ns
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get DNS servers, attempt %d: %v", i+1, err)
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got insufficient DNS servers, retrying, found %d servers", len(ns))
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(retryDelayDNSAdapter):
|
||||
}
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
|
||||
return ns
|
||||
}
|
||||
|
||||
func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
// Check context before making the call
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Get DNS servers from adapters (existing method)
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting adapters: %w", err)
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Found network adapters, count=%d", len(aas))
|
||||
|
||||
// Try to get domain controller info if domain-joined
|
||||
var dcServers []string
|
||||
isDomain := checkDomainJoined()
|
||||
if isDomain {
|
||||
domainName, err := getLocalADDomain()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get local AD domain: %v", err)
|
||||
} else {
|
||||
// Load netapi32.dll
|
||||
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
||||
dsDcName := netapi32.NewProc("DsGetDcNameW")
|
||||
|
||||
var info *DomainControllerInfo
|
||||
flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME)
|
||||
|
||||
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to convert domain name to UTF16: %v", err)
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
|
||||
|
||||
// Call DsGetDcNameW with domain name
|
||||
ret, _, err := dsDcName.Call(
|
||||
0, // ComputerName - can be NULL
|
||||
uintptr(unsafe.Pointer(domainUTF16)), // DomainName
|
||||
0, // DomainGuid - not needed
|
||||
0, // SiteName - not needed
|
||||
uintptr(flags), // Flags
|
||||
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
|
||||
|
||||
if ret != 0 {
|
||||
switch ret {
|
||||
case 1355: // ERROR_NO_SUCH_DOMAIN
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain not found: %s (%d)", domainName, ret)
|
||||
case 1311: // ERROR_NO_LOGON_SERVERS
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"No logon servers available for domain: %s (%d)", domainName, ret)
|
||||
case 1004: // ERROR_DC_NOT_FOUND
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain controller not found for domain: %s (%d)", domainName, ret)
|
||||
case 1722: // RPC_S_SERVER_UNAVAILABLE
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"RPC server unavailable for domain: %s (%d)", domainName, ret)
|
||||
default:
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
|
||||
}
|
||||
} else if info != nil {
|
||||
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
|
||||
|
||||
if info.DomainControllerAddress != nil {
|
||||
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
||||
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Found domain controller address: %s", dcAddr)
|
||||
|
||||
if ip := net.ParseIP(dcAddr); ip != nil {
|
||||
dcServers = append(dcServers, ip.String())
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added domain controller DNS servers: %v", dcServers)
|
||||
}
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"No domain controller address found")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Continue with existing adapter DNS collection
|
||||
ns := make([]string, 0, len(aas)*2)
|
||||
seen := make(map[string]bool)
|
||||
addressMap := make(map[string]struct{})
|
||||
|
||||
// Collect all local IPs
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Processing adapter %s", aa.FriendlyName())
|
||||
|
||||
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
|
||||
addressMap[a.Address.IP().String()] = struct{}{}
|
||||
ip := a.Address.IP().String()
|
||||
addressMap[ip] = struct{}{}
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
|
||||
}
|
||||
}
|
||||
|
||||
validInterfacesMap := validInterfaces()
|
||||
|
||||
// Collect DNS servers
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
// if not in the validInterfacesMap, skip
|
||||
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
||||
ip := dns.Address.IP()
|
||||
if ip == nil || ip.IsLoopback() || seen[ip.String()] {
|
||||
if ip == nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping nil IP from adapter %s", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
if _, ok := addressMap[ip.String()]; ok {
|
||||
|
||||
ipStr := ip.String()
|
||||
l := logger.Debug().
|
||||
Str("ip", ipStr).
|
||||
Str("adapter", aa.FriendlyName())
|
||||
|
||||
if ip.IsLoopback() {
|
||||
l.Msg("Skipping loopback IP")
|
||||
continue
|
||||
}
|
||||
seen[ip.String()] = true
|
||||
ns = append(ns, ip.String())
|
||||
if seen[ipStr] {
|
||||
l.Msg("Skipping duplicate IP")
|
||||
continue
|
||||
}
|
||||
if _, ok := addressMap[ipStr]; ok {
|
||||
l.Msg("Skipping local interface IP")
|
||||
continue
|
||||
}
|
||||
|
||||
seen[ipStr] = true
|
||||
ns = append(ns, ipStr)
|
||||
l.Msg("Added DNS server")
|
||||
}
|
||||
}
|
||||
return ns
|
||||
|
||||
// Add DC servers if they're not already in the list
|
||||
for _, dcServer := range dcServers {
|
||||
if !seen[dcServer] {
|
||||
seen[dcServer] = true
|
||||
ns = append(ns, dcServer)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added additional domain controller DNS server: %s", dcServer)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ns) == 0 {
|
||||
return nil, fmt.Errorf("no valid DNS servers found")
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
|
||||
len(ns), ns, len(dcServers))
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
func nameserversFromResolvconf() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
||||
// Returns whether it's domain joined and the domain name if available
|
||||
func checkDomainJoined() bool {
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
var domain *uint16
|
||||
var status uint32
|
||||
|
||||
err := windows.NetGetJoinInformation(nil, &domain, &status)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get domain join status: %v", err)
|
||||
return false
|
||||
}
|
||||
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
|
||||
|
||||
domainName := windows.UTF16PtrToString(domain)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain join status: domain=%s status=%d (Unknown=0, Workgroup=1, Domain=2, CloudDomain=3)",
|
||||
domainName, status)
|
||||
|
||||
// Consider domain or cloud domain as domain-joined
|
||||
isDomain := status == NetSetupDomain || status == NetSetupCloudDomain
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Is domain joined? status=%d, traditional=%v, cloud=%v, result=%v",
|
||||
status,
|
||||
status == NetSetupDomain,
|
||||
status == NetSetupCloudDomain,
|
||||
isDomain)
|
||||
|
||||
return isDomain
|
||||
}
|
||||
|
||||
// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*)
|
||||
// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call.
|
||||
func getLocalADDomain() (string, error) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
// 1) Check environment variable
|
||||
envDomain := os.Getenv("USERDNSDOMAIN")
|
||||
if envDomain != "" {
|
||||
return strings.TrimSpace(envDomain), nil
|
||||
}
|
||||
|
||||
// 2) Query WMI via the microsoft/wmi library
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("WMI query failed: %v", err)
|
||||
}
|
||||
|
||||
// If no results, return an error
|
||||
if len(instances) == 0 {
|
||||
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
|
||||
}
|
||||
|
||||
// We only care about the first row
|
||||
domainVal, err := instances[0].GetProperty("Domain")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("machine does not appear to have a domain set: %v", err)
|
||||
}
|
||||
|
||||
domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal))
|
||||
if domainName == "" {
|
||||
return "", fmt.Errorf("machine does not appear to have a domain set")
|
||||
}
|
||||
return domainName, nil
|
||||
}
|
||||
|
||||
// validInterfaces returns a list of all physical interfaces.
|
||||
// this is a duplicate of what is in net_windows.go, we should
|
||||
// clean this up so there is only one version
|
||||
func validInterfaces() map[string]struct{} {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
//load the logger
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get wmi network adapter: %v", err)
|
||||
return nil
|
||||
}
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get network adapter: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := adapter.GetPropertyName()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get interface name: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||
//
|
||||
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
||||
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||
physical, err := adapter.GetPropertyConnectorPresent()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"failed to get network adapter connector present property: %v", err)
|
||||
continue
|
||||
}
|
||||
if !physical {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"skipping non-physical adapter: %s", name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a hardware interface. Checking only for connector present is not enough
|
||||
// because some interfaces are not physical but have a connector.
|
||||
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"failed to get network adapter hardware interface property: %v", err)
|
||||
continue
|
||||
}
|
||||
if !hardware {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"skipping non-hardware interface: %s", name)
|
||||
continue
|
||||
}
|
||||
|
||||
adapters = append(adapters, name)
|
||||
}
|
||||
|
||||
m := make(map[string]struct{})
|
||||
for _, ifaceName := range adapters {
|
||||
m[ifaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
310
resolver.go
310
resolver.go
@@ -4,15 +4,17 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/zerolog"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
@@ -30,8 +32,10 @@ const (
|
||||
ResolverTypeOS = "os"
|
||||
// ResolverTypeLegacy specifies legacy resolver.
|
||||
ResolverTypeLegacy = "legacy"
|
||||
// ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only.
|
||||
// ResolverTypePrivate is like ResolverTypeOS, but use for private resolver only.
|
||||
ResolverTypePrivate = "private"
|
||||
// ResolverTypeLocal is like ResolverTypeOS, but use for local resolver only.
|
||||
ResolverTypeLocal = "local"
|
||||
// ResolverTypeSDNS specifies resolver with information encoded using DNS Stamps.
|
||||
// See: https://dnscrypt.info/stamps-specifications/
|
||||
ResolverTypeSDNS = "sdns"
|
||||
@@ -44,8 +48,30 @@ const (
|
||||
|
||||
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = newResolverWithNameserver(defaultNameservers())
|
||||
var localResolver = newLocalResolver()
|
||||
|
||||
var (
|
||||
resolverMutex sync.Mutex
|
||||
or *osResolver
|
||||
defaultLocalIPv4 atomic.Value // holds net.IP (IPv4)
|
||||
defaultLocalIPv6 atomic.Value // holds net.IP (IPv6)
|
||||
)
|
||||
|
||||
func newLocalResolver() Resolver {
|
||||
var nss []string
|
||||
for _, addr := range Rfc1918Addresses() {
|
||||
nss = append(nss, net.JoinHostPort(addr, "53"))
|
||||
}
|
||||
return NewResolverWithNameserver(nss)
|
||||
}
|
||||
|
||||
// LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network.
|
||||
type LanQueryCtxKey struct{}
|
||||
|
||||
// LanQueryCtx returns a context.Context with LanQueryCtxKey set.
|
||||
func LanQueryCtx(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, LanQueryCtxKey{}, true)
|
||||
}
|
||||
|
||||
// defaultNameservers is like nameservers with each element formed "ip:53".
|
||||
func defaultNameservers() []string {
|
||||
@@ -63,17 +89,39 @@ func availableNameservers() []string {
|
||||
// Ignore local addresses to prevent loop.
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
machineIPsMap[v.String()] = struct{}{}
|
||||
|
||||
//load the logger
|
||||
logger := zerolog.New(io.Discard)
|
||||
if ProxyLogger.Load() != nil {
|
||||
logger = *ProxyLogger.Load()
|
||||
}
|
||||
for _, ns := range nameservers() {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs)
|
||||
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
ipStr := v.String()
|
||||
machineIPsMap[ipStr] = struct{}{}
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added local IP to OS resolverexclusion map: %s", ipStr)
|
||||
}
|
||||
|
||||
systemNameservers := nameservers()
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got system nameservers: %v", systemNameservers)
|
||||
|
||||
for _, ns := range systemNameservers {
|
||||
if _, ok := machineIPsMap[ns]; ok {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping local nameserver: %s", ns)
|
||||
continue
|
||||
}
|
||||
if testNameserver(ns) {
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
nss = append(nss, ns)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added non-local nameserver: %s", ns)
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Final available nameservers: %v", nss)
|
||||
return nss
|
||||
}
|
||||
|
||||
@@ -82,77 +130,47 @@ func availableNameservers() []string {
|
||||
//
|
||||
// It's the caller's responsibility to ensure the system DNS is in a clean state before
|
||||
// calling this function.
|
||||
func InitializeOsResolver() []string {
|
||||
return initializeOsResolver(availableNameservers())
|
||||
func InitializeOsResolver(guardAgainstNoNameservers bool) []string {
|
||||
nameservers := availableNameservers()
|
||||
// if no nameservers, return empty slice so we dont remove all nameservers
|
||||
if len(nameservers) == 0 && guardAgainstNoNameservers {
|
||||
return []string{}
|
||||
}
|
||||
ns := initializeOsResolver(nameservers)
|
||||
resolverMutex.Lock()
|
||||
defer resolverMutex.Unlock()
|
||||
or = newResolverWithNameserver(ns)
|
||||
return ns
|
||||
}
|
||||
|
||||
// initializeOsResolver performs logic for choosing OS resolver nameserver.
|
||||
// The logic:
|
||||
//
|
||||
// - First available LAN servers are saved and store.
|
||||
// - Later calls, if no LAN servers available, the saved servers above will be used.
|
||||
func initializeOsResolver(servers []string) []string {
|
||||
var (
|
||||
nss []string
|
||||
publicNss []string
|
||||
)
|
||||
var (
|
||||
lastLanServer netip.Addr
|
||||
curLanServer netip.Addr
|
||||
curLanServerAvailable bool
|
||||
)
|
||||
if p := or.currentLanServer.Load(); p != nil {
|
||||
curLanServer = *p
|
||||
or.currentLanServer.Store(nil)
|
||||
}
|
||||
if p := or.lastLanServer.Load(); p != nil {
|
||||
lastLanServer = *p
|
||||
or.lastLanServer.Store(nil)
|
||||
}
|
||||
|
||||
var lanNss, publicNss []string
|
||||
|
||||
// First categorize servers
|
||||
for _, ns := range servers {
|
||||
addr, err := netip.ParseAddr(ns)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
server := net.JoinHostPort(ns, "53")
|
||||
// Always use new public nameserver.
|
||||
if !isLanAddr(addr) {
|
||||
publicNss = append(publicNss, server)
|
||||
nss = append(nss, server)
|
||||
continue
|
||||
}
|
||||
// For LAN server, storing only current and last LAN server if any.
|
||||
if addr.Compare(curLanServer) == 0 {
|
||||
curLanServerAvailable = true
|
||||
if isLanAddr(addr) {
|
||||
lanNss = append(lanNss, server)
|
||||
} else {
|
||||
if addr.Compare(lastLanServer) == 0 {
|
||||
or.lastLanServer.Store(&addr)
|
||||
} else {
|
||||
if or.currentLanServer.CompareAndSwap(nil, &addr) {
|
||||
nss = append(nss, server)
|
||||
}
|
||||
}
|
||||
publicNss = append(publicNss, server)
|
||||
}
|
||||
}
|
||||
// Store current LAN server as last one only if it's still available.
|
||||
if curLanServerAvailable && curLanServer.IsValid() {
|
||||
or.lastLanServer.Store(&curLanServer)
|
||||
nss = append(nss, net.JoinHostPort(curLanServer.String(), "53"))
|
||||
}
|
||||
if len(publicNss) == 0 {
|
||||
publicNss = append(publicNss, controldPublicDnsWithPort)
|
||||
nss = append(nss, controldPublicDnsWithPort)
|
||||
}
|
||||
or.publicServer.Store(&publicNss)
|
||||
return nss
|
||||
}
|
||||
|
||||
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
|
||||
func testNameserver(addr string) bool {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("controld.com.", dns.TypeNS)
|
||||
client := new(dns.Client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53"))
|
||||
if err != nil {
|
||||
ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr)
|
||||
if len(publicNss) == 0 {
|
||||
publicNss = []string{controldPublicDnsWithPort}
|
||||
}
|
||||
return err == nil
|
||||
|
||||
return slices.Concat(lanNss, publicNss)
|
||||
}
|
||||
|
||||
// Resolver is the interface that wraps the basic DNS operations.
|
||||
@@ -175,19 +193,23 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
case ResolverTypeDOQ:
|
||||
return &doqResolver{uc: uc}, nil
|
||||
case ResolverTypeOS:
|
||||
if or == nil {
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
return or, nil
|
||||
case ResolverTypeLegacy:
|
||||
return &legacyResolver{uc: uc}, nil
|
||||
case ResolverTypePrivate:
|
||||
return NewPrivateResolver(), nil
|
||||
case ResolverTypeLocal:
|
||||
return localResolver, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
|
||||
}
|
||||
|
||||
type osResolver struct {
|
||||
currentLanServer atomic.Pointer[netip.Addr]
|
||||
lastLanServer atomic.Pointer[netip.Addr]
|
||||
publicServer atomic.Pointer[[]string]
|
||||
lanServers atomic.Pointer[[]string]
|
||||
publicServers atomic.Pointer[[]string]
|
||||
}
|
||||
|
||||
type osResolverResult struct {
|
||||
@@ -197,26 +219,75 @@ type osResolverResult struct {
|
||||
lan bool
|
||||
}
|
||||
|
||||
type publicResponse struct {
|
||||
answer *dns.Msg
|
||||
server string
|
||||
}
|
||||
|
||||
// SetDefaultLocalIPv4 updates the stored local IPv4.
|
||||
func SetDefaultLocalIPv4(ip net.IP) {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip)
|
||||
defaultLocalIPv4.Store(ip)
|
||||
}
|
||||
|
||||
// SetDefaultLocalIPv6 updates the stored local IPv6.
|
||||
func SetDefaultLocalIPv6(ip net.IP) {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip)
|
||||
defaultLocalIPv6.Store(ip)
|
||||
}
|
||||
|
||||
// GetDefaultLocalIPv4 returns the stored local IPv4 or nil if none.
|
||||
func GetDefaultLocalIPv4() net.IP {
|
||||
if v := defaultLocalIPv4.Load(); v != nil {
|
||||
return v.(net.IP)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefaultLocalIPv6 returns the stored local IPv6 or nil if none.
|
||||
func GetDefaultLocalIPv6() net.IP {
|
||||
if v := defaultLocalIPv6.Load(); v != nil {
|
||||
return v.(net.IP)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// customDNSExchange wraps the DNS exchange to use our debug dialer.
|
||||
// It uses dns.ExchangeWithConn so that our custom dialer is used directly.
|
||||
func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, time.Duration, error) {
|
||||
baseDialer := &net.Dialer{
|
||||
Timeout: 3 * time.Second,
|
||||
Resolver: &net.Resolver{PreferGo: true},
|
||||
}
|
||||
if desiredLocalIP != nil {
|
||||
baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0}
|
||||
}
|
||||
dnsClient := &dns.Client{Net: "udp"}
|
||||
dnsClient.Dialer = baseDialer
|
||||
return dnsClient.ExchangeContext(ctx, msg, server)
|
||||
}
|
||||
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
publicServers := *o.publicServer.Load()
|
||||
nss := make([]string, 0, 2)
|
||||
if p := o.currentLanServer.Load(); p != nil {
|
||||
nss = append(nss, net.JoinHostPort(p.String(), "53"))
|
||||
}
|
||||
if p := o.lastLanServer.Load(); p != nil {
|
||||
nss = append(nss, net.JoinHostPort(p.String(), "53"))
|
||||
publicServers := *o.publicServers.Load()
|
||||
var nss []string
|
||||
if p := o.lanServers.Load(); p != nil {
|
||||
nss = append(nss, (*p)...)
|
||||
}
|
||||
numServers := len(nss) + len(publicServers)
|
||||
// If this is a LAN query, skip public DNS.
|
||||
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
|
||||
if ok && lan {
|
||||
numServers -= len(publicServers)
|
||||
}
|
||||
if numServers == 0 {
|
||||
return nil, errors.New("no nameservers available")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
dnsClient := &dns.Client{Net: "udp"}
|
||||
ch := make(chan *osResolverResult, numServers)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(numServers)
|
||||
@@ -229,57 +300,86 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
for _, server := range servers {
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||
var answer *dns.Msg
|
||||
var err error
|
||||
var localOSResolverIP net.IP
|
||||
if runtime.GOOS == "darwin" {
|
||||
host, _, err := net.SplitHostPort(server)
|
||||
if err == nil {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.To4() == nil {
|
||||
// IPv6 nameserver; use default IPv6 address (if set)
|
||||
localOSResolverIP = GetDefaultLocalIPv6()
|
||||
} else {
|
||||
localOSResolverIP = GetDefaultLocalIPv4()
|
||||
}
|
||||
}
|
||||
}
|
||||
answer, _, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP)
|
||||
ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan}
|
||||
}(server)
|
||||
}
|
||||
}
|
||||
do(nss, true)
|
||||
do(publicServers, false)
|
||||
if !lan {
|
||||
do(publicServers, false)
|
||||
}
|
||||
|
||||
logAnswer := func(server string) {
|
||||
if before, _, found := strings.Cut(server, ":"); found {
|
||||
server = before
|
||||
host, _, err := net.SplitHostPort(server)
|
||||
if err != nil {
|
||||
// If splitting fails, fallback to the original server string
|
||||
host = server
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", server)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host)
|
||||
}
|
||||
var (
|
||||
nonSuccessAnswer *dns.Msg
|
||||
nonSuccessServer string
|
||||
controldSuccessAnswer *dns.Msg
|
||||
publicServerAnswer *dns.Msg
|
||||
publicServer string
|
||||
publicResponses []publicResponse
|
||||
)
|
||||
errs := make([]error, 0, numServers)
|
||||
for res := range ch {
|
||||
switch {
|
||||
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
|
||||
switch {
|
||||
case res.server == controldPublicDnsWithPort:
|
||||
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
|
||||
case !res.lan && publicServerAnswer == nil:
|
||||
publicServerAnswer = res.answer // use public DNS answer after LAN server..
|
||||
publicServer = res.server
|
||||
default:
|
||||
case res.lan:
|
||||
// Always prefer LAN responses immediately
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server)
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
case res.server == controldPublicDnsWithPort:
|
||||
controldSuccessAnswer = res.answer
|
||||
case !res.lan:
|
||||
publicResponses = append(publicResponses, publicResponse{
|
||||
answer: res.answer,
|
||||
server: res.server,
|
||||
})
|
||||
}
|
||||
case res.answer != nil:
|
||||
nonSuccessAnswer = res.answer
|
||||
nonSuccessServer = res.server
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d",
|
||||
res.server, res.answer.Rcode)
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
if publicServerAnswer != nil {
|
||||
logAnswer(publicServer)
|
||||
return publicServerAnswer, nil
|
||||
|
||||
if len(publicResponses) > 0 {
|
||||
resp := publicResponses[0]
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server)
|
||||
logAnswer(resp.server)
|
||||
return resp.answer, nil
|
||||
}
|
||||
if controldSuccessAnswer != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort)
|
||||
logAnswer(controldPublicDnsWithPort)
|
||||
return controldSuccessAnswer, nil
|
||||
}
|
||||
if nonSuccessAnswer != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s", nonSuccessServer)
|
||||
logAnswer(nonSuccessServer)
|
||||
return nonSuccessAnswer, nil
|
||||
}
|
||||
@@ -328,7 +428,11 @@ func LookupIP(domain string) []string {
|
||||
}
|
||||
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
nss := defaultNameservers()
|
||||
if or == nil {
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
nss := *or.lanServers.Load()
|
||||
nss = append(nss, *or.publicServers.Load()...)
|
||||
if withBootstrapDNS {
|
||||
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
|
||||
}
|
||||
@@ -467,17 +571,19 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
// The caller must ensure each server in list is formed "ip:53".
|
||||
func newResolverWithNameserver(nameservers []string) *osResolver {
|
||||
r := &osResolver{}
|
||||
nss := slices.Sorted(slices.Values(nameservers))
|
||||
for i, ns := range nss {
|
||||
var publicNss []string
|
||||
var lanNss []string
|
||||
for _, ns := range slices.Sorted(slices.Values(nameservers)) {
|
||||
ip, _, _ := net.SplitHostPort(ns)
|
||||
addr, _ := netip.ParseAddr(ip)
|
||||
if isLanAddr(addr) {
|
||||
r.currentLanServer.Store(&addr)
|
||||
nss = slices.Delete(nss, i, i+1)
|
||||
break
|
||||
lanNss = append(lanNss, ns)
|
||||
} else {
|
||||
publicNss = append(publicNss, ns)
|
||||
}
|
||||
}
|
||||
r.publicServer.Store(&nss)
|
||||
r.lanServers.Store(&lanNss)
|
||||
r.publicServers.Store(&publicNss)
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
113
resolver_test.go
113
resolver_test.go
@@ -3,13 +3,10 @@ package ctrld
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@@ -20,7 +17,7 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
go func() {
|
||||
defer cancel()
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServer.Store(&[]string{"127.0.0.127:5353"})
|
||||
resolver.publicServers.Store(&[]string{"127.0.0.127:5353"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
@@ -34,26 +31,51 @@ func Test_osResolver_Resolve(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_ResolveLanHostname(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
reqId := "req-id"
|
||||
ctx = context.WithValue(ctx, ReqIdCtxKey{}, reqId)
|
||||
ctx = LanQueryCtx(ctx)
|
||||
|
||||
go func(ctx context.Context) {
|
||||
defer cancel()
|
||||
id, ok := ctx.Value(ReqIdCtxKey{}).(string)
|
||||
if !ok || id != reqId {
|
||||
t.Error("missing request id")
|
||||
return
|
||||
}
|
||||
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
|
||||
if !ok || !lan {
|
||||
t.Error("not a LAN query")
|
||||
return
|
||||
}
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServers.Store(&[]string{"76.76.2.0:53"})
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
_, err := resolver.Resolve(ctx, m)
|
||||
if err == nil {
|
||||
t.Error("os resolver succeeded unexpectedly")
|
||||
return
|
||||
}
|
||||
}(ctx)
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Error("os resolver hangs")
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
ns := make([]string, 0, 2)
|
||||
servers := make([]*dns.Server, 0, 2)
|
||||
successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, dns.RcodeSuccess)
|
||||
w.WriteMsg(m)
|
||||
})
|
||||
nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc {
|
||||
return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, rcode)
|
||||
w.WriteMsg(m)
|
||||
})
|
||||
}
|
||||
|
||||
handlers := []dns.Handler{
|
||||
nonSuccessHandlerWithRcode(dns.RcodeRefused),
|
||||
nonSuccessHandlerWithRcode(dns.RcodeNameError),
|
||||
successHandler,
|
||||
successHandler(),
|
||||
}
|
||||
for i := range handlers {
|
||||
pc, err := net.ListenPacket("udp", ":0")
|
||||
@@ -74,7 +96,7 @@ func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
resolver := &osResolver{}
|
||||
resolver.publicServer.Store(&ns)
|
||||
resolver.publicServers.Store(&ns)
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
answer, err := resolver.Resolve(context.Background(), msg)
|
||||
@@ -93,7 +115,7 @@ func Test_osResolver_InitializationRace(t *testing.T) {
|
||||
for range n {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
InitializeOsResolver()
|
||||
InitializeOsResolver(false)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
@@ -153,41 +175,18 @@ func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.H
|
||||
return server, addr, nil
|
||||
}
|
||||
|
||||
func Test_initializeOsResolver(t *testing.T) {
|
||||
lanServer1 := "192.168.1.1"
|
||||
lanServer2 := "10.0.10.69"
|
||||
wanServer := "1.1.1.1"
|
||||
publicServers := []string{net.JoinHostPort(wanServer, "53")}
|
||||
|
||||
// First initialization.
|
||||
initializeOsResolver([]string{lanServer1, wanServer})
|
||||
p := or.currentLanServer.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.Equal(t, lanServer1, p.String())
|
||||
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
|
||||
|
||||
// No new LAN server, current LAN server -> last LAN server.
|
||||
initializeOsResolver([]string{lanServer1, wanServer})
|
||||
p = or.currentLanServer.Load()
|
||||
assert.Nil(t, p)
|
||||
p = or.lastLanServer.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.Equal(t, lanServer1, p.String())
|
||||
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
|
||||
|
||||
// New LAN server detected.
|
||||
initializeOsResolver([]string{lanServer2, lanServer1, wanServer})
|
||||
p = or.currentLanServer.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.Equal(t, lanServer2, p.String())
|
||||
p = or.lastLanServer.Load()
|
||||
assert.NotNil(t, p)
|
||||
assert.Equal(t, lanServer1, p.String())
|
||||
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
|
||||
|
||||
// No LAN server available.
|
||||
initializeOsResolver([]string{wanServer})
|
||||
assert.Nil(t, or.currentLanServer.Load())
|
||||
assert.Nil(t, or.lastLanServer.Load())
|
||||
assert.True(t, slices.Equal(*or.publicServer.Load(), publicServers))
|
||||
func successHandler() dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, dns.RcodeSuccess)
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
|
||||
func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(msg, rcode)
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user