mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-16 10:22:45 +00:00
Compare commits
59 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6586fd360 | ||
|
|
33a6db2599 | ||
|
|
70b0c4f7b9 | ||
|
|
5af3ec4f7b | ||
|
|
79476add12 | ||
|
|
1634a06330 | ||
|
|
a007394f60 | ||
|
|
62a0ba8731 | ||
|
|
e8d3ed1acd | ||
|
|
8b98faa441 | ||
|
|
30320ec9c7 | ||
|
|
5f4a399850 | ||
|
|
82e0d4b0c4 | ||
|
|
95a9df826d | ||
|
|
3b71d26cf3 | ||
|
|
c233ad9b1b | ||
|
|
12d6484b1c | ||
|
|
bc7b1cc6d8 | ||
|
|
ec684348ed | ||
|
|
18a19a3aa2 | ||
|
|
905f2d08c5 | ||
|
|
04947b4d87 | ||
|
|
72bf80533e | ||
|
|
9ddedf926e | ||
|
|
139dd62ff3 | ||
|
|
50ef00526e | ||
|
|
80cf79b9cb | ||
|
|
e6ad39b070 | ||
|
|
56f9c72569 | ||
|
|
dc48c908b8 | ||
|
|
9b0f0e792a | ||
|
|
b3eebb19b6 | ||
|
|
c24589a5be | ||
|
|
1e1c5a4dc8 | ||
|
|
339023421a | ||
|
|
a00d2a431a | ||
|
|
5aca118dbb | ||
|
|
411f7434f4 | ||
|
|
34801382f5 | ||
|
|
b9f2259ae4 | ||
|
|
19020a96bf | ||
|
|
96085147ff | ||
|
|
f3dd344026 | ||
|
|
486096416f | ||
|
|
5710f2e984 | ||
|
|
09936f1f07 | ||
|
|
0d6ca57536 | ||
|
|
3ddcb84db8 | ||
|
|
1012bf063f | ||
|
|
b8155e6182 | ||
|
|
9a34df61bb | ||
|
|
fbb879edf9 | ||
|
|
ac97c88876 | ||
|
|
a1fda2c0de | ||
|
|
f499770d45 | ||
|
|
4769da4ef4 | ||
|
|
c2556a8e39 | ||
|
|
29bf329f6a | ||
|
|
1dee4305bc |
@@ -4,6 +4,8 @@
|
|||||||
[](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
|
[](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
|
||||||
[](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
|
[](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
A highly configurable DNS forwarding proxy with support for:
|
A highly configurable DNS forwarding proxy with support for:
|
||||||
- Multiple listeners for incoming queries
|
- Multiple listeners for incoming queries
|
||||||
- Multiple upstreams with fallbacks
|
- Multiple upstreams with fallbacks
|
||||||
@@ -103,9 +105,11 @@ Available Commands:
|
|||||||
start Quick start service and configure DNS on interface
|
start Quick start service and configure DNS on interface
|
||||||
stop Quick stop service and remove DNS from interface
|
stop Quick stop service and remove DNS from interface
|
||||||
restart Restart the ctrld service
|
restart Restart the ctrld service
|
||||||
|
reload Reload the ctrld service
|
||||||
status Show status of the ctrld service
|
status Show status of the ctrld service
|
||||||
uninstall Stop and uninstall the ctrld service
|
uninstall Stop and uninstall the ctrld service
|
||||||
clients Manage clients
|
clients Manage clients
|
||||||
|
upgrade Upgrading ctrld to latest version
|
||||||
|
|
||||||
Flags:
|
Flags:
|
||||||
-h, --help help for ctrld
|
-h, --help help for ctrld
|
||||||
|
|||||||
679
cmd/cli/cli.go
679
cmd/cli/cli.go
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@ func Test_writeConfigFile(t *testing.T) {
|
|||||||
_, err := os.Stat(configPath)
|
_, err := os.Stat(configPath)
|
||||||
assert.True(t, os.IsNotExist(err))
|
assert.True(t, os.IsNotExist(err))
|
||||||
|
|
||||||
assert.NoError(t, writeConfigFile())
|
assert.NoError(t, writeConfigFile(&cfg))
|
||||||
|
|
||||||
_, err = os.Stat(configPath)
|
_, err = os.Stat(configPath)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/kardianos/service"
|
||||||
|
|
||||||
dto "github.com/prometheus/client_model/go"
|
dto "github.com/prometheus/client_model/go"
|
||||||
|
|
||||||
"github.com/Control-D-Inc/ctrld"
|
"github.com/Control-D-Inc/ctrld"
|
||||||
@@ -22,6 +24,7 @@ const (
|
|||||||
reloadPath = "/reload"
|
reloadPath = "/reload"
|
||||||
deactivationPath = "/deactivation"
|
deactivationPath = "/deactivation"
|
||||||
cdPath = "/cd"
|
cdPath = "/cd"
|
||||||
|
ifacePath = "/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
type controlServer struct {
|
type controlServer struct {
|
||||||
@@ -70,7 +73,7 @@ func (p *prog) registerControlServerHandler() {
|
|||||||
sort.Slice(clients, func(i, j int) bool {
|
sort.Slice(clients, func(i, j int) bool {
|
||||||
return clients[i].IP.Less(clients[j].IP)
|
return clients[i].IP.Less(clients[j].IP)
|
||||||
})
|
})
|
||||||
if p.cfg.Service.MetricsQueryStats {
|
if p.metricsQueryStats.Load() {
|
||||||
for _, client := range clients {
|
for _, client := range clients {
|
||||||
client.IncludeQueryCount = true
|
client.IncludeQueryCount = true
|
||||||
dm := &dto.Metric{}
|
dm := &dto.Metric{}
|
||||||
@@ -175,10 +178,22 @@ func (p *prog) registerControlServerHandler() {
|
|||||||
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||||
if cdUID != "" {
|
if cdUID != "" {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(cdUID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
}))
|
}))
|
||||||
|
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||||
|
// p.setDNS is only called when running as a service
|
||||||
|
if !service.Interactive() {
|
||||||
|
<-p.csSetDnsDone
|
||||||
|
if p.csSetDnsOk {
|
||||||
|
w.Write([]byte(iface))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func jsonResponse(next http.Handler) http.Handler {
|
func jsonResponse(next http.Handler) http.Handler {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
|
|
||||||
"github.com/Control-D-Inc/ctrld"
|
"github.com/Control-D-Inc/ctrld"
|
||||||
|
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||||
)
|
)
|
||||||
@@ -32,6 +33,9 @@ const (
|
|||||||
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
||||||
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
||||||
EDNS0_OPTION_MAC = 0xFDE9
|
EDNS0_OPTION_MAC = 0xFDE9
|
||||||
|
|
||||||
|
// selfUninstallMaxQueries is number of REFUSED queries seen before checking for self-uninstallation.
|
||||||
|
selfUninstallMaxQueries = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||||
@@ -89,6 +93,7 @@ func (p *prog) serveDNS(listenerNum string) error {
|
|||||||
_ = w.WriteMsg(answer)
|
_ = w.WriteMsg(answer)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
listenerConfig := p.cfg.Listener[listenerNum]
|
||||||
reqId := requestID()
|
reqId := requestID()
|
||||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||||
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
||||||
@@ -143,6 +148,7 @@ func (p *prog) serveDNS(listenerNum string) error {
|
|||||||
failoverRcodes: failoverRcode,
|
failoverRcodes: failoverRcode,
|
||||||
ufr: ur,
|
ufr: ur,
|
||||||
})
|
})
|
||||||
|
go p.doSelfUninstall(pr.answer)
|
||||||
answer = pr.answer
|
answer = pr.answer
|
||||||
rtt := time.Since(t)
|
rtt := time.Since(t)
|
||||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||||
@@ -210,12 +216,9 @@ func (p *prog) serveDNS(listenerNum string) error {
|
|||||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||||
s, errCh := runDNSServer(addr, proto, handler)
|
s, errCh := runDNSServer(addr, proto, handler)
|
||||||
defer s.Shutdown()
|
defer s.Shutdown()
|
||||||
select {
|
|
||||||
case err := <-errCh:
|
p.started <- struct{}{}
|
||||||
return err
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
p.started <- struct{}{}
|
|
||||||
}
|
|
||||||
select {
|
select {
|
||||||
case <-p.stopCh:
|
case <-p.stopCh:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -752,20 +755,19 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha
|
|||||||
Handler: handler,
|
Handler: handler,
|
||||||
}
|
}
|
||||||
|
|
||||||
waitLock := sync.Mutex{}
|
startedCh := make(chan struct{})
|
||||||
waitLock.Lock()
|
s.NotifyStartedFunc = func() { sync.OnceFunc(func() { close(startedCh) })() }
|
||||||
s.NotifyStartedFunc = waitLock.Unlock
|
|
||||||
|
|
||||||
errCh := make(chan error)
|
errCh := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(errCh)
|
defer close(errCh)
|
||||||
if err := s.ListenAndServe(); err != nil {
|
if err := s.ListenAndServe(); err != nil {
|
||||||
waitLock.Unlock()
|
s.NotifyStartedFunc()
|
||||||
mainLog.Load().Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
mainLog.Load().Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
waitLock.Lock()
|
<-startedCh
|
||||||
return s, errCh
|
return s, errCh
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -840,6 +842,51 @@ func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// doSelfUninstall performs self-uninstall if these condition met:
|
||||||
|
//
|
||||||
|
// - There is only 1 ControlD upstream in-use.
|
||||||
|
// - Number of refused queries seen so far equals to selfUninstallMaxQueries.
|
||||||
|
// - The cdUID is deleted.
|
||||||
|
func (p *prog) doSelfUninstall(answer *dns.Msg) {
|
||||||
|
if !p.canSelfUninstall.Load() || answer == nil || answer.Rcode != dns.RcodeRefused {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.selfUninstallMu.Lock()
|
||||||
|
defer p.selfUninstallMu.Unlock()
|
||||||
|
if p.checkingSelfUninstall {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := mainLog.Load().With().Str("mode", "self-uninstall").Logger()
|
||||||
|
if p.refusedQueryCount > selfUninstallMaxQueries {
|
||||||
|
p.checkingSelfUninstall = true
|
||||||
|
_, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||||
|
logger.Debug().Msg("maximum number of refused queries reached, checking device status")
|
||||||
|
selfUninstallCheck(err, p, logger)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||||
|
}
|
||||||
|
// Cool-of period to prevent abusing the API.
|
||||||
|
go p.selfUninstallCoolOfPeriod()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.refusedQueryCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// selfUninstallCoolOfPeriod waits for 30 minutes before
|
||||||
|
// calling API again for checking ControlD device status.
|
||||||
|
func (p *prog) selfUninstallCoolOfPeriod() {
|
||||||
|
t := time.NewTimer(time.Minute * 30)
|
||||||
|
defer t.Stop()
|
||||||
|
<-t.C
|
||||||
|
p.selfUninstallMu.Lock()
|
||||||
|
p.checkingSelfUninstall = false
|
||||||
|
p.refusedQueryCount = 0
|
||||||
|
p.selfUninstallMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// queryFromSelf reports whether the input IP is from device running ctrld.
|
// queryFromSelf reports whether the input IP is from device running ctrld.
|
||||||
func queryFromSelf(ip string) bool {
|
func queryFromSelf(ip string) bool {
|
||||||
netIP := netip.MustParseAddr(ip)
|
netIP := netip.MustParseAddr(ip)
|
||||||
|
|||||||
@@ -105,6 +105,10 @@ func (p *prog) checkDnsLoop() {
|
|||||||
for uid := range p.loop {
|
for uid := range p.loop {
|
||||||
msg := loopTestMsg(uid)
|
msg := loopTestMsg(uid)
|
||||||
uc := upstream[uid]
|
uc := upstream[uid]
|
||||||
|
// Skipping upstream which is being marked as down.
|
||||||
|
if uc == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
resolver, err := ctrld.NewResolver(uc)
|
resolver, err := ctrld.NewResolver(uc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||||
|
|||||||
@@ -35,6 +35,9 @@ var (
|
|||||||
nextdns string
|
nextdns string
|
||||||
cdUpstreamProto string
|
cdUpstreamProto string
|
||||||
deactivationPin int64
|
deactivationPin int64
|
||||||
|
skipSelfChecks bool
|
||||||
|
cleanup bool
|
||||||
|
startOnly bool
|
||||||
|
|
||||||
mainLog atomic.Pointer[zerolog.Logger]
|
mainLog atomic.Pointer[zerolog.Logger]
|
||||||
consoleWriter zerolog.ConsoleWriter
|
consoleWriter zerolog.ConsoleWriter
|
||||||
@@ -62,8 +65,11 @@ func Main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func normalizeLogFilePath(logFilePath string) string {
|
func normalizeLogFilePath(logFilePath string) string {
|
||||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
// In cleanup mode, we always want the full log file path.
|
||||||
return logFilePath
|
if !cleanup {
|
||||||
|
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||||
|
return logFilePath
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if homedir != "" {
|
if homedir != "" {
|
||||||
return filepath.Join(homedir, logFilePath)
|
return filepath.Join(homedir, logFilePath)
|
||||||
@@ -120,14 +126,14 @@ func initLoggingWithBackup(doBackup bool) {
|
|||||||
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
|
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
|
||||||
if doBackup {
|
if doBackup {
|
||||||
// Backup old log file with .1 suffix.
|
// Backup old log file with .1 suffix.
|
||||||
if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) {
|
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
|
||||||
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
|
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
|
||||||
} else {
|
} else {
|
||||||
// Backup was created, set flags for truncating old log file.
|
// Backup was created, set flags for truncating old log file.
|
||||||
flags = os.O_CREATE | os.O_RDWR
|
flags = os.O_CREATE | os.O_RDWR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logFile, err := os.OpenFile(logFilePath, flags, os.FileMode(0o600))
|
logFile, err := openLogFile(logFilePath, flags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
|
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
|||||||
|
|
||||||
reg := prometheus.NewRegistry()
|
reg := prometheus.NewRegistry()
|
||||||
// Register queries count stats if enabled.
|
// Register queries count stats if enabled.
|
||||||
if cfg.Service.MetricsQueryStats {
|
if p.metricsQueryStats.Load() {
|
||||||
reg.MustRegister(statsQueriesCount)
|
reg.MustRegister(statsQueriesCount)
|
||||||
reg.MustRegister(statsClientQueriesCount)
|
reg.MustRegister(statsClientQueriesCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,20 +43,32 @@ func networkServiceName(ifaceName string, r io.Reader) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// validInterface reports whether the *net.Interface is a valid one, which includes:
|
// validInterface reports whether the *net.Interface is a valid one.
|
||||||
//
|
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||||
// - en0: physical wireless
|
_, ok := validIfacesMap[iface.Name]
|
||||||
// - en1: Thunderbolt 1
|
return ok
|
||||||
// - en2: Thunderbolt 2
|
}
|
||||||
// - en3: Thunderbolt 3
|
|
||||||
// - en4: Thunderbolt 4
|
func validInterfacesMap() map[string]struct{} {
|
||||||
//
|
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
|
||||||
// For full list, see: https://unix.stackexchange.com/questions/603506/what-are-these-ifconfig-interfaces-on-macos
|
if err != nil {
|
||||||
func validInterface(iface *net.Interface) bool {
|
return nil
|
||||||
switch iface.Name {
|
}
|
||||||
case "en0", "en1", "en2", "en3", "en4":
|
return parseListAllHardwarePorts(bytes.NewReader(b))
|
||||||
return true
|
}
|
||||||
default:
|
|
||||||
return false
|
// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports"
|
||||||
}
|
// and returns map presents all hardware ports.
|
||||||
|
func parseListAllHardwarePorts(r io.Reader) map[string]struct{} {
|
||||||
|
m := make(map[string]struct{})
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
after, ok := strings.CutPrefix(line, "Device: ")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m[after] = struct{}{}
|
||||||
|
}
|
||||||
|
return m
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"maps"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -57,3 +58,47 @@ func Test_networkServiceName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const listallhardwareportsOutput = `
|
||||||
|
Hardware Port: Ethernet Adapter (en6)
|
||||||
|
Device: en6
|
||||||
|
Ethernet Address: 3a:3e:fc:1e:ab:41
|
||||||
|
|
||||||
|
Hardware Port: Ethernet Adapter (en7)
|
||||||
|
Device: en7
|
||||||
|
Ethernet Address: 3a:3e:fc:1e:ab:42
|
||||||
|
|
||||||
|
Hardware Port: Thunderbolt Bridge
|
||||||
|
Device: bridge0
|
||||||
|
Ethernet Address: 36:21:bb:3a:7a:40
|
||||||
|
|
||||||
|
Hardware Port: Wi-Fi
|
||||||
|
Device: en0
|
||||||
|
Ethernet Address: a0:78:17:68:56:3f
|
||||||
|
|
||||||
|
Hardware Port: Thunderbolt 1
|
||||||
|
Device: en1
|
||||||
|
Ethernet Address: 36:21:bb:3a:7a:40
|
||||||
|
|
||||||
|
Hardware Port: Thunderbolt 2
|
||||||
|
Device: en2
|
||||||
|
Ethernet Address: 36:21:bb:3a:7a:44
|
||||||
|
|
||||||
|
VLAN Configurations
|
||||||
|
===================
|
||||||
|
`
|
||||||
|
|
||||||
|
func Test_parseListAllHardwarePorts(t *testing.T) {
|
||||||
|
expected := map[string]struct{}{
|
||||||
|
"en0": {},
|
||||||
|
"en1": {},
|
||||||
|
"en2": {},
|
||||||
|
"en6": {},
|
||||||
|
"en7": {},
|
||||||
|
"bridge0": {},
|
||||||
|
}
|
||||||
|
m := parseListAllHardwarePorts(strings.NewReader(listallhardwareportsOutput))
|
||||||
|
if !maps.Equal(m, expected) {
|
||||||
|
t.Errorf("unexpected output, want: %v, got: %v", expected, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,4 +6,6 @@ import "net"
|
|||||||
|
|
||||||
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
||||||
|
|
||||||
func validInterface(iface *net.Interface) bool { return true }
|
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||||
|
|
||||||
|
func validInterfacesMap() map[string]struct{} { return nil }
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ func patchNetIfaceName(iface *net.Interface) error {
|
|||||||
|
|
||||||
// validInterface reports whether the *net.Interface is a valid one.
|
// validInterface reports whether the *net.Interface is a valid one.
|
||||||
// On Windows, only physical interfaces are considered valid.
|
// On Windows, only physical interfaces are considered valid.
|
||||||
func validInterface(iface *net.Interface) bool {
|
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||||
if iface == nil {
|
if iface == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -19,3 +19,5 @@ func validInterface(iface *net.Interface) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validInterfacesMap() map[string]struct{} { return nil }
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import (
|
|||||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||||
|
|
||||||
// allocate loopback ip
|
// allocate loopback ip
|
||||||
// sudo ip a add 127.0.0.2/24 dev lo
|
// sudo ip a add 127.0.0.2/24 dev lo
|
||||||
func allocateIP(ip string) error {
|
func allocateIP(ip string) error {
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
forwardersFilename = ".forwarders.txt"
|
|
||||||
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||||
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||||
)
|
)
|
||||||
@@ -40,18 +39,13 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
|||||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||||
// Configuring the Dns server to forward queries to ctrld instead.
|
// Configuring the Dns server to forward queries to ctrld instead.
|
||||||
if windowsHasLocalDnsServerRunning() {
|
if windowsHasLocalDnsServerRunning() {
|
||||||
file := absHomeDir(forwardersFilename)
|
file := absHomeDir(windowsForwardersFilename)
|
||||||
if data, _ := os.ReadFile(file); len(data) > 0 {
|
oldForwardersContent, _ := os.ReadFile(file)
|
||||||
if err := removeDnsServerForwarders(strings.Split(string(data), ",")); err != nil {
|
|
||||||
mainLog.Load().Error().Err(err).Msg("could not remove current forwarders settings")
|
|
||||||
} else {
|
|
||||||
mainLog.Load().Debug().Msg("removed current forwarders settings.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil {
|
if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||||
}
|
}
|
||||||
if err := addDnsServerForwarders(nameservers); err != nil {
|
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
||||||
|
if err := addDnsServerForwarders(nameservers, oldForwarders); err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -77,7 +71,7 @@ func resetDNS(iface *net.Interface) error {
|
|||||||
resetDNSOnce.Do(func() {
|
resetDNSOnce.Do(func() {
|
||||||
// See corresponding comment in setDNS.
|
// See corresponding comment in setDNS.
|
||||||
if windowsHasLocalDnsServerRunning() {
|
if windowsHasLocalDnsServerRunning() {
|
||||||
file := absHomeDir(forwardersFilename)
|
file := absHomeDir(windowsForwardersFilename)
|
||||||
content, err := os.ReadFile(file)
|
content, err := os.ReadFile(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
||||||
@@ -213,14 +207,32 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
|||||||
return ns, nil
|
return ns, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addDnsServerForwarders adds given nameservers to DNS server forwarders list.
|
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
||||||
func addDnsServerForwarders(nameservers []string) error {
|
// and also removing old forwarders if provided.
|
||||||
for _, ns := range nameservers {
|
func addDnsServerForwarders(nameservers, old []string) error {
|
||||||
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", ns)
|
newForwardersMap := make(map[string]struct{})
|
||||||
if out, err := powershell(cmd); err != nil {
|
newForwarders := make([]string, len(nameservers))
|
||||||
return fmt.Errorf("%w: %s", err, string(out))
|
for i := range nameservers {
|
||||||
|
newForwardersMap[nameservers[i]] = struct{}{}
|
||||||
|
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
|
||||||
|
}
|
||||||
|
oldForwarders := old[:0]
|
||||||
|
for _, fwd := range old {
|
||||||
|
if _, ok := newForwardersMap[fwd]; !ok {
|
||||||
|
oldForwarders = append(oldForwarders, fwd)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// NOTE: It is important to add new forwarder before removing old one.
|
||||||
|
// Testing on Windows Server 2022 shows that removing forwarder1
|
||||||
|
// then adding forwarder2 sometimes ends up adding both of them
|
||||||
|
// to the forwarders list.
|
||||||
|
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
|
||||||
|
if len(oldForwarders) > 0 {
|
||||||
|
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
|
||||||
|
}
|
||||||
|
if out, err := powershell(cmd); err != nil {
|
||||||
|
return fmt.Errorf("%w: %s", err, string(out))
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
301
cmd/cli/prog.go
301
cmd/cli/prog.go
@@ -12,19 +12,24 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"tailscale.com/net/interfaces"
|
"tailscale.com/net/interfaces"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
|
|
||||||
"github.com/Control-D-Inc/ctrld"
|
"github.com/Control-D-Inc/ctrld"
|
||||||
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
||||||
|
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||||
)
|
)
|
||||||
@@ -38,6 +43,7 @@ const (
|
|||||||
upstreamPrefix = "upstream."
|
upstreamPrefix = "upstream."
|
||||||
upstreamOS = upstreamPrefix + "os"
|
upstreamOS = upstreamPrefix + "os"
|
||||||
upstreamPrivate = upstreamPrefix + "private"
|
upstreamPrivate = upstreamPrefix + "private"
|
||||||
|
dnsWatchdogDefaultInterval = 20 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// ControlSocketName returns name for control unix socket.
|
// ControlSocketName returns name for control unix socket.
|
||||||
@@ -62,13 +68,19 @@ var svcConfig = &service.Config{
|
|||||||
var useSystemdResolved = false
|
var useSystemdResolved = false
|
||||||
|
|
||||||
type prog struct {
|
type prog struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
waitCh chan struct{}
|
waitCh chan struct{}
|
||||||
stopCh chan struct{}
|
stopCh chan struct{}
|
||||||
reloadCh chan struct{} // For Windows.
|
reloadCh chan struct{} // For Windows.
|
||||||
reloadDoneCh chan struct{}
|
reloadDoneCh chan struct{}
|
||||||
logConn net.Conn
|
apiReloadCh chan *ctrld.Config
|
||||||
cs *controlServer
|
logConn net.Conn
|
||||||
|
cs *controlServer
|
||||||
|
csSetDnsDone chan struct{}
|
||||||
|
csSetDnsOk bool
|
||||||
|
dnsWatchDogOnce sync.Once
|
||||||
|
dnsWg sync.WaitGroup
|
||||||
|
dnsWatcherStopCh chan struct{}
|
||||||
|
|
||||||
cfg *ctrld.Config
|
cfg *ctrld.Config
|
||||||
localUpstreams []string
|
localUpstreams []string
|
||||||
@@ -82,6 +94,12 @@ type prog struct {
|
|||||||
router router.Router
|
router router.Router
|
||||||
ptrLoopGuard *loopGuard
|
ptrLoopGuard *loopGuard
|
||||||
lanLoopGuard *loopGuard
|
lanLoopGuard *loopGuard
|
||||||
|
metricsQueryStats atomic.Bool
|
||||||
|
|
||||||
|
selfUninstallMu sync.Mutex
|
||||||
|
refusedQueryCount int
|
||||||
|
canSelfUninstall atomic.Bool
|
||||||
|
checkingSelfUninstall bool
|
||||||
|
|
||||||
loopMu sync.Mutex
|
loopMu sync.Mutex
|
||||||
loop map[string]bool
|
loop map[string]bool
|
||||||
@@ -115,11 +133,15 @@ func (p *prog) runWait() {
|
|||||||
p.run(reload, reloadCh)
|
p.run(reload, reloadCh)
|
||||||
reload = true
|
reload = true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var newCfg *ctrld.Config
|
||||||
select {
|
select {
|
||||||
case sig := <-reloadSigCh:
|
case sig := <-reloadSigCh:
|
||||||
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
||||||
case <-p.reloadCh:
|
case <-p.reloadCh:
|
||||||
logger.Notice().Msg("reloading...")
|
logger.Notice().Msg("reloading...")
|
||||||
|
case apiCfg := <-p.apiReloadCh:
|
||||||
|
newCfg = apiCfg
|
||||||
case <-p.stopCh:
|
case <-p.stopCh:
|
||||||
close(reloadCh)
|
close(reloadCh)
|
||||||
return
|
return
|
||||||
@@ -129,28 +151,31 @@ func (p *prog) runWait() {
|
|||||||
close(reloadCh)
|
close(reloadCh)
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
newCfg := &ctrld.Config{}
|
|
||||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
if newCfg == nil {
|
||||||
ctrld.InitConfig(v, "ctrld")
|
newCfg = &ctrld.Config{}
|
||||||
if configPath != "" {
|
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||||
v.SetConfigFile(configPath)
|
ctrld.InitConfig(v, "ctrld")
|
||||||
}
|
if configPath != "" {
|
||||||
if err := v.ReadInConfig(); err != nil {
|
v.SetConfigFile(configPath)
|
||||||
logger.Err(err).Msg("could not read new config")
|
}
|
||||||
waitOldRunDone()
|
if err := v.ReadInConfig(); err != nil {
|
||||||
continue
|
logger.Err(err).Msg("could not read new config")
|
||||||
}
|
|
||||||
if err := v.Unmarshal(&newCfg); err != nil {
|
|
||||||
logger.Err(err).Msg("could not unmarshal new config")
|
|
||||||
waitOldRunDone()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if cdUID != "" {
|
|
||||||
if err := processCDFlags(newCfg); err != nil {
|
|
||||||
logger.Err(err).Msg("could not fetch ControlD config")
|
|
||||||
waitOldRunDone()
|
waitOldRunDone()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if err := v.Unmarshal(&newCfg); err != nil {
|
||||||
|
logger.Err(err).Msg("could not unmarshal new config")
|
||||||
|
waitOldRunDone()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cdUID != "" {
|
||||||
|
if err := processCDFlags(newCfg); err != nil {
|
||||||
|
logger.Err(err).Msg("could not fetch ControlD config")
|
||||||
|
waitOldRunDone()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
waitOldRunDone()
|
waitOldRunDone()
|
||||||
@@ -176,6 +201,10 @@ func (p *prog) runWait() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := writeConfigFile(newCfg); err != nil {
|
||||||
|
logger.Err(err).Msg("could not write new config")
|
||||||
|
}
|
||||||
|
|
||||||
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
||||||
// upstream config because its initialization function have not been called yet.
|
// upstream config because its initialization function have not been called yet.
|
||||||
mainLog.Load().Debug().Msg("setup upstream with new config")
|
mainLog.Load().Debug().Msg("setup upstream with new config")
|
||||||
@@ -186,6 +215,7 @@ func (p *prog) runWait() {
|
|||||||
p.mu.Unlock()
|
p.mu.Unlock()
|
||||||
|
|
||||||
logger.Notice().Msg("reloading config successfully")
|
logger.Notice().Msg("reloading config successfully")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case p.reloadDoneCh <- struct{}{}:
|
case p.reloadDoneCh <- struct{}{}:
|
||||||
default:
|
default:
|
||||||
@@ -194,9 +224,6 @@ func (p *prog) runWait() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *prog) preRun() {
|
func (p *prog) preRun() {
|
||||||
if !service.Interactive() {
|
|
||||||
p.setDNS()
|
|
||||||
}
|
|
||||||
if runtime.GOOS == "darwin" {
|
if runtime.GOOS == "darwin" {
|
||||||
p.onStopped = append(p.onStopped, func() {
|
p.onStopped = append(p.onStopped, func() {
|
||||||
if !service.Interactive() {
|
if !service.Interactive() {
|
||||||
@@ -206,12 +233,76 @@ func (p *prog) preRun() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *prog) postRun() {
|
||||||
|
if !service.Interactive() {
|
||||||
|
p.resetDNS()
|
||||||
|
ns := ctrld.InitializeOsResolver()
|
||||||
|
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
|
||||||
|
p.setDNS()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// apiConfigReload calls API to check for latest config update then reload ctrld if necessary.
|
||||||
|
func (p *prog) apiConfigReload() {
|
||||||
|
if cdUID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
secs := 3600
|
||||||
|
if p.cfg.Service.RefetchTime != nil && *p.cfg.Service.RefetchTime > 0 {
|
||||||
|
secs = *p.cfg.Service.RefetchTime
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(time.Duration(secs) * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
logger := mainLog.Load().With().Str("mode", "api-reload").Logger()
|
||||||
|
logger.Debug().Msg("starting custom config reload timer")
|
||||||
|
lastUpdated := time.Now().Unix()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||||
|
selfUninstallCheck(err, p, logger)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resolverConfig.Ctrld.CustomConfig == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resolverConfig.Ctrld.CustomLastUpdate > lastUpdated {
|
||||||
|
lastUpdated = time.Now().Unix()
|
||||||
|
cfg := &ctrld.Config{}
|
||||||
|
if err := validateCdRemoteConfig(resolverConfig, cfg); err != nil {
|
||||||
|
logger.Warn().Err(err).Msg("skipping invalid custom config")
|
||||||
|
if _, err := controld.UpdateCustomLastFailed(cdUID, rootCmd.Version, cdDev, true); err != nil {
|
||||||
|
logger.Error().Err(err).Msg("could not mark custom last update failed")
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
setListenerDefaultValue(cfg)
|
||||||
|
logger.Debug().Msg("custom config changes detected, reloading...")
|
||||||
|
p.apiReloadCh <- cfg
|
||||||
|
} else {
|
||||||
|
logger.Debug().Msg("custom config does not change")
|
||||||
|
}
|
||||||
|
case <-p.stopCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
||||||
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
||||||
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
||||||
|
isControlDUpstream := false
|
||||||
for n := range cfg.Upstream {
|
for n := range cfg.Upstream {
|
||||||
uc := cfg.Upstream[n]
|
uc := cfg.Upstream[n]
|
||||||
uc.Init()
|
uc.Init()
|
||||||
|
isControlDUpstream = isControlDUpstream || uc.IsControlD()
|
||||||
if uc.BootstrapIP == "" {
|
if uc.BootstrapIP == "" {
|
||||||
uc.SetupBootstrapIP()
|
uc.SetupBootstrapIP()
|
||||||
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||||
@@ -228,6 +319,10 @@ func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
|||||||
ptrNameservers = append(ptrNameservers, uc.Endpoint)
|
ptrNameservers = append(ptrNameservers, uc.Endpoint)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Self-uninstallation is ok If there is only 1 ControlD upstream, and no remote config.
|
||||||
|
if len(cfg.Upstream) == 1 && isControlDUpstream {
|
||||||
|
p.canSelfUninstall.Store(true)
|
||||||
|
}
|
||||||
p.localUpstreams = localUpstreams
|
p.localUpstreams = localUpstreams
|
||||||
p.ptrNameservers = ptrNameservers
|
p.ptrNameservers = ptrNameservers
|
||||||
}
|
}
|
||||||
@@ -249,12 +344,21 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
|||||||
numListeners := len(p.cfg.Listener)
|
numListeners := len(p.cfg.Listener)
|
||||||
if !reload {
|
if !reload {
|
||||||
p.started = make(chan struct{}, numListeners)
|
p.started = make(chan struct{}, numListeners)
|
||||||
|
if p.cs != nil {
|
||||||
|
p.csSetDnsDone = make(chan struct{}, 1)
|
||||||
|
p.registerControlServerHandler()
|
||||||
|
if err := p.cs.start(); err != nil {
|
||||||
|
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
||||||
|
}
|
||||||
|
mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
p.onStartedDone = make(chan struct{})
|
p.onStartedDone = make(chan struct{})
|
||||||
p.loop = make(map[string]bool)
|
p.loop = make(map[string]bool)
|
||||||
p.lanLoopGuard = newLoopGuard()
|
p.lanLoopGuard = newLoopGuard()
|
||||||
p.ptrLoopGuard = newLoopGuard()
|
p.ptrLoopGuard = newLoopGuard()
|
||||||
p.cacheFlushDomainsMap = nil
|
p.cacheFlushDomainsMap = nil
|
||||||
|
p.metricsQueryStats.Store(p.cfg.Service.MetricsQueryStats)
|
||||||
if p.cfg.Service.CacheEnable {
|
if p.cfg.Service.CacheEnable {
|
||||||
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -381,12 +485,8 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
|||||||
if p.logConn != nil {
|
if p.logConn != nil {
|
||||||
_ = p.logConn.Close()
|
_ = p.logConn.Close()
|
||||||
}
|
}
|
||||||
if p.cs != nil {
|
go p.apiConfigReload()
|
||||||
p.registerControlServerHandler()
|
p.postRun()
|
||||||
if err := p.cs.start(); err != nil {
|
|
||||||
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
@@ -430,17 +530,25 @@ func (p *prog) deAllocateIP() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *prog) setDNS() {
|
func (p *prog) setDNS() {
|
||||||
|
setDnsOK := false
|
||||||
|
defer func() {
|
||||||
|
p.csSetDnsOk = setDnsOK
|
||||||
|
p.csSetDnsDone <- struct{}{}
|
||||||
|
close(p.csSetDnsDone)
|
||||||
|
}()
|
||||||
|
|
||||||
if cfg.Listener == nil {
|
if cfg.Listener == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if iface == "" {
|
if iface == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
runningIface := iface
|
||||||
// allIfaces tracks whether we should set DNS for all physical interfaces.
|
// allIfaces tracks whether we should set DNS for all physical interfaces.
|
||||||
allIfaces := false
|
allIfaces := false
|
||||||
if iface == "auto" {
|
if runningIface == "auto" {
|
||||||
iface = defaultIfaceName()
|
runningIface = defaultIfaceName()
|
||||||
// If iface is "auto", it means user does not specify "--iface" flag.
|
// If runningIface is "auto", it means user does not specify "--iface" flag.
|
||||||
// In this case, ctrld has to set DNS for all physical interfaces, so
|
// In this case, ctrld has to set DNS for all physical interfaces, so
|
||||||
// thing will still work when user switch from one to the other.
|
// thing will still work when user switch from one to the other.
|
||||||
allIfaces = requiredMultiNICsConfig()
|
allIfaces = requiredMultiNICsConfig()
|
||||||
@@ -449,8 +557,8 @@ func (p *prog) setDNS() {
|
|||||||
if lc == nil {
|
if lc == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger := mainLog.Load().With().Str("iface", iface).Logger()
|
logger := mainLog.Load().With().Str("iface", runningIface).Logger()
|
||||||
netIface, err := netInterface(iface)
|
netIface, err := netInterface(runningIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error().Err(err).Msg("could not get interface")
|
logger.Error().Err(err).Msg("could not get interface")
|
||||||
return
|
return
|
||||||
@@ -484,33 +592,108 @@ func (p *prog) setDNS() {
|
|||||||
logger.Error().Err(err).Msgf("could not set DNS for interface")
|
logger.Error().Err(err).Msgf("could not set DNS for interface")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
setDnsOK = true
|
||||||
logger.Debug().Msg("setting DNS successfully")
|
logger.Debug().Msg("setting DNS successfully")
|
||||||
if shouldWatchResolvconf() {
|
if shouldWatchResolvconf() {
|
||||||
servers := make([]netip.Addr, len(nameservers))
|
servers := make([]netip.Addr, len(nameservers))
|
||||||
for i := range nameservers {
|
for i := range nameservers {
|
||||||
servers[i] = netip.MustParseAddr(nameservers[i])
|
servers[i] = netip.MustParseAddr(nameservers[i])
|
||||||
}
|
}
|
||||||
go watchResolvConf(netIface, servers, setResolvConf)
|
p.dnsWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer p.dnsWg.Done()
|
||||||
|
p.watchResolvConf(netIface, servers, setResolvConf)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
if allIfaces {
|
if allIfaces {
|
||||||
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
|
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
|
||||||
return setDnsIgnoreUnusableInterface(i, nameservers)
|
return setDnsIgnoreUnusableInterface(i, nameservers)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
if p.dnsWatchdogEnabled() {
|
||||||
|
p.dnsWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer p.dnsWg.Done()
|
||||||
|
p.dnsWatchdog(netIface, nameservers, allIfaces)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsWatchdogEnabled reports whether DNS watchdog is enabled.
|
||||||
|
func (p *prog) dnsWatchdogEnabled() bool {
|
||||||
|
if ptr := p.cfg.Service.DnsWatchdogEnabled; ptr != nil {
|
||||||
|
return *ptr
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsWatchdogDuration returns the time duration between each DNS watchdog loop.
|
||||||
|
func (p *prog) dnsWatchdogDuration() time.Duration {
|
||||||
|
if ptr := p.cfg.Service.DnsWatchdogInvterval; ptr != nil {
|
||||||
|
if (*ptr).Seconds() > 0 {
|
||||||
|
return *ptr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dnsWatchdogDefaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsWatchdog watches for DNS changes on Darwin and Windows then re-applying ctrld's settings.
|
||||||
|
// This is only works when deactivation pin set.
|
||||||
|
func (p *prog) dnsWatchdog(iface *net.Interface, nameservers []string, allIfaces bool) {
|
||||||
|
if !requiredMultiNICsConfig() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.dnsWatchDogOnce.Do(func() {
|
||||||
|
mainLog.Load().Debug().Msg("start DNS settings watchdog")
|
||||||
|
ns := nameservers
|
||||||
|
slices.Sort(ns)
|
||||||
|
ticker := time.NewTicker(p.dnsWatchdogDuration())
|
||||||
|
logger := mainLog.Load().With().Str("iface", iface.Name).Logger()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-p.dnsWatcherStopCh:
|
||||||
|
return
|
||||||
|
case <-p.stopCh:
|
||||||
|
mainLog.Load().Debug().Msg("stop dns watchdog")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if dnsChanged(iface, ns) {
|
||||||
|
logger.Debug().Msg("DNS settings were changed, re-applying settings")
|
||||||
|
if err := setDNS(iface, ns); err != nil {
|
||||||
|
mainLog.Load().Error().Err(err).Str("iface", iface.Name).Msgf("could not re-apply DNS settings")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if allIfaces {
|
||||||
|
withEachPhysicalInterfaces(iface.Name, "re-applying DNS", func(i *net.Interface) error {
|
||||||
|
if dnsChanged(i, ns) {
|
||||||
|
if err := setDnsIgnoreUnusableInterface(i, nameservers); err != nil {
|
||||||
|
mainLog.Load().Error().Err(err).Str("iface", i.Name).Msgf("could not re-apply DNS settings")
|
||||||
|
} else {
|
||||||
|
mainLog.Load().Debug().Msgf("re-applying DNS for interface %q successfully", i.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *prog) resetDNS() {
|
func (p *prog) resetDNS() {
|
||||||
if iface == "" {
|
if iface == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
runningIface := iface
|
||||||
allIfaces := false
|
allIfaces := false
|
||||||
if iface == "auto" {
|
if runningIface == "auto" {
|
||||||
iface = defaultIfaceName()
|
runningIface = defaultIfaceName()
|
||||||
// See corresponding comments in (*prog).setDNS function.
|
// See corresponding comments in (*prog).setDNS function.
|
||||||
allIfaces = requiredMultiNICsConfig()
|
allIfaces = requiredMultiNICsConfig()
|
||||||
}
|
}
|
||||||
logger := mainLog.Load().With().Str("iface", iface).Logger()
|
logger := mainLog.Load().With().Str("iface", runningIface).Logger()
|
||||||
netIface, err := netInterface(iface)
|
netIface, err := netInterface(runningIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error().Err(err).Msg("could not get interface")
|
logger.Error().Err(err).Msg("could not get interface")
|
||||||
return
|
return
|
||||||
@@ -706,13 +889,14 @@ func canBeLocalUpstream(addr string) bool {
|
|||||||
// the interface that matches excludeIfaceName. The context is used to clarify the
|
// the interface that matches excludeIfaceName. The context is used to clarify the
|
||||||
// log message when error happens.
|
// log message when error happens.
|
||||||
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
|
func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.Interface) error) {
|
||||||
|
validIfacesMap := validInterfacesMap()
|
||||||
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||||
// Skip loopback/virtual interface.
|
// Skip loopback/virtual interface.
|
||||||
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
|
if i.IsLoopback() || len(i.HardwareAddr) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Skip invalid interface.
|
// Skip invalid interface.
|
||||||
if !validInterface(i.Interface) {
|
if !validInterface(i.Interface, validIfacesMap) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
netIface := i.Interface
|
netIface := i.Interface
|
||||||
@@ -726,7 +910,9 @@ func withEachPhysicalInterfaces(excludeIfaceName, context string, f func(i *net.
|
|||||||
}
|
}
|
||||||
// TODO: investigate whether we should report this error?
|
// TODO: investigate whether we should report this error?
|
||||||
if err := f(netIface); err == nil {
|
if err := f(netIface); err == nil {
|
||||||
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
|
if context != "" {
|
||||||
|
mainLog.Load().Debug().Msgf("%s for interface %q successfully", context, i.Name)
|
||||||
|
}
|
||||||
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
|
} else if !errors.Is(err, errSaveCurrentStaticDNSNotSupported) {
|
||||||
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
|
mainLog.Load().Err(err).Msgf("%s for interface %q failed", context, i.Name)
|
||||||
}
|
}
|
||||||
@@ -785,3 +971,24 @@ func savedStaticNameservers(iface *net.Interface) []string {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dnsChanged reports whether DNS settings for given interface was changed.
|
||||||
|
// The caller must sort the nameservers before calling this function.
|
||||||
|
func dnsChanged(iface *net.Interface, nameservers []string) bool {
|
||||||
|
curNameservers, _ := currentStaticDNS(iface)
|
||||||
|
slices.Sort(curNameservers)
|
||||||
|
return !slices.Equal(curNameservers, nameservers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// selfUninstallCheck checks if the error dues to controld.InvalidConfigCode, perform self-uninstall then.
|
||||||
|
func selfUninstallCheck(uninstallErr error, p *prog, logger zerolog.Logger) {
|
||||||
|
var uer *controld.UtilityErrorResponse
|
||||||
|
if errors.As(uninstallErr, &uer) && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||||
|
// Ensure all DNS watchers goroutine are terminated, so it won't mess up with self-uninstall.
|
||||||
|
close(p.dnsWatcherStopCh)
|
||||||
|
p.dnsWg.Wait()
|
||||||
|
|
||||||
|
// Perform self-uninstall now.
|
||||||
|
selfUninstall(p, logger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,12 @@
|
|||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
|
|
||||||
@@ -24,12 +29,34 @@ func setDependencies(svc *service.Config) {
|
|||||||
"After=network-online.target",
|
"After=network-online.target",
|
||||||
"Wants=NetworkManager-wait-online.service",
|
"Wants=NetworkManager-wait-online.service",
|
||||||
"After=NetworkManager-wait-online.service",
|
"After=NetworkManager-wait-online.service",
|
||||||
"Wants=systemd-networkd-wait-online.service",
|
|
||||||
"Wants=nss-lookup.target",
|
"Wants=nss-lookup.target",
|
||||||
"After=nss-lookup.target",
|
"After=nss-lookup.target",
|
||||||
}
|
}
|
||||||
|
if out, _ := exec.Command("networkctl", "--no-pager").CombinedOutput(); len(out) > 0 {
|
||||||
|
if wantsSystemDNetworkdWaitOnline(bytes.NewReader(out)) {
|
||||||
|
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||||
svc.WorkingDirectory = dir
|
svc.WorkingDirectory = dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wantsSystemDNetworkdWaitOnline reports whether "systemd-networkd-wait-online" service
|
||||||
|
// is required to be added to ctrld dependencies services.
|
||||||
|
// The input reader r is the output of "networkctl --no-pager" command.
|
||||||
|
func wantsSystemDNetworkdWaitOnline(r io.Reader) bool {
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
// Skip header
|
||||||
|
scanner.Scan()
|
||||||
|
configured := false
|
||||||
|
for scanner.Scan() {
|
||||||
|
fields := strings.Fields(scanner.Text())
|
||||||
|
if len(fields) > 0 && fields[len(fields)-1] == "configured" {
|
||||||
|
configured = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return configured
|
||||||
|
}
|
||||||
|
|||||||
48
cmd/cli/prog_linux_test.go
Normal file
48
cmd/cli/prog_linux_test.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
networkctlUnmanagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||||
|
1 lo loopback carrier unmanaged
|
||||||
|
2 wlp0s20f3 wlan routable unmanaged
|
||||||
|
3 tailscale0 none routable unmanaged
|
||||||
|
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||||
|
5 docker0 bridge no-carrier unmanaged
|
||||||
|
|
||||||
|
5 links listed.
|
||||||
|
`
|
||||||
|
networkctlManagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||||
|
1 lo loopback carrier unmanaged
|
||||||
|
2 wlp0s20f3 wlan routable configured
|
||||||
|
3 tailscale0 none routable unmanaged
|
||||||
|
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||||
|
5 docker0 bridge no-carrier unmanaged
|
||||||
|
|
||||||
|
5 links listed.
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_wantsSystemDNetworkdWaitOnline(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
r io.Reader
|
||||||
|
required bool
|
||||||
|
}{
|
||||||
|
{"unmanaged", strings.NewReader(networkctlUnmanagedOutput), false},
|
||||||
|
{"managed", strings.NewReader(networkctlManagedOutput), true},
|
||||||
|
{"empty", strings.NewReader(""), false},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if required := wantsSystemDNetworkdWaitOnline(tc.r); required != tc.required {
|
||||||
|
t.Errorf("wants %v got %v", tc.required, required)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
57
cmd/cli/prog_test.go
Normal file
57
cmd/cli/prog_test.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Control-D-Inc/ctrld"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
|
||||||
|
p := &prog{cfg: &ctrld.Config{}}
|
||||||
|
|
||||||
|
// Default value is true.
|
||||||
|
assert.True(t, p.dnsWatchdogEnabled())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
enabled bool
|
||||||
|
}{
|
||||||
|
{"enabled", true},
|
||||||
|
{"disabled", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
p.cfg.Service.DnsWatchdogEnabled = &tc.enabled
|
||||||
|
assert.Equal(t, tc.enabled, p.dnsWatchdogEnabled())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_prog_dnsWatchdogInterval(t *testing.T) {
|
||||||
|
p := &prog{cfg: &ctrld.Config{}}
|
||||||
|
|
||||||
|
// Default value is 20s.
|
||||||
|
assert.Equal(t, dnsWatchdogDefaultInterval, p.dnsWatchdogDuration())
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
duration time.Duration
|
||||||
|
expected time.Duration
|
||||||
|
}{
|
||||||
|
{"valid", time.Minute, time.Minute},
|
||||||
|
{"zero", 0, dnsWatchdogDefaultInterval},
|
||||||
|
{"nagative", time.Duration(-1 * time.Minute), dnsWatchdogDefaultInterval},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
p.cfg.Service.DnsWatchdogInvterval = &tc.duration
|
||||||
|
assert.Equal(t, tc.expected, p.dnsWatchdogDuration())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -51,7 +51,7 @@ var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
|||||||
|
|
||||||
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
|
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
|
||||||
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
|
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
|
||||||
if p.cfg.Service.MetricsQueryStats {
|
if p.metricsQueryStats.Load() {
|
||||||
c.WithLabelValues(lvs...).Inc()
|
c.WithLabelValues(lvs...).Inc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,15 +8,15 @@ import (
|
|||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
resolvConfPath = "/etc/resolv.conf"
|
|
||||||
resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
|
||||||
)
|
|
||||||
|
|
||||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||||
// and reverting to the original config set by ctrld.
|
// and reverting to the original config set by ctrld.
|
||||||
func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||||
mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file")
|
resolvConfPath := "/etc/resolv.conf"
|
||||||
|
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
|
||||||
|
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
|
||||||
|
resolvConfPath = rp
|
||||||
|
}
|
||||||
|
mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath)
|
||||||
watcher, err := fsnotify.NewWatcher()
|
watcher, err := fsnotify.NewWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||||
@@ -28,12 +28,17 @@ func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface
|
|||||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||||
watchDir := filepath.Dir(resolvConfPath)
|
watchDir := filepath.Dir(resolvConfPath)
|
||||||
if err := watcher.Add(watchDir); err != nil {
|
if err := watcher.Add(watchDir); err != nil {
|
||||||
mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list")
|
mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-p.dnsWatcherStopCh:
|
||||||
|
return
|
||||||
|
case <-p.stopCh:
|
||||||
|
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||||
|
return
|
||||||
case event, ok := <-watcher.Events:
|
case event, ok := <-watcher.Events:
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,15 +3,44 @@ package cli
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const resolvConfPath = "/etc/resolv.conf"
|
||||||
|
|
||||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||||
servers := make([]string, len(ns))
|
servers := make([]string, len(ns))
|
||||||
for i := range ns {
|
for i := range ns {
|
||||||
servers[i] = ns[i].String()
|
servers[i] = ns[i].String()
|
||||||
}
|
}
|
||||||
return setDNS(iface, servers)
|
if err := setDNS(iface, servers); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
slices.Sort(servers)
|
||||||
|
curNs := currentDNS(iface)
|
||||||
|
slices.Sort(curNs)
|
||||||
|
if !slices.Equal(curNs, servers) {
|
||||||
|
c, err := resolvconffile.ParseFile(resolvConfPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.Nameservers = ns
|
||||||
|
f, err := os.Create(resolvConfPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err := c.Write(f); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return f.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||||
|
|||||||
7
cmd/cli/self_delete_others.go
Normal file
7
cmd/cli/self_delete_others.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package cli
|
||||||
|
|
||||||
|
var supportedSelfDelete = true
|
||||||
|
|
||||||
|
func selfDeleteExe() error { return nil }
|
||||||
134
cmd/cli/self_delete_windows.go
Normal file
134
cmd/cli/self_delete_windows.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
// Copied from https://github.com/secur30nly/go-self-delete
|
||||||
|
// with modification to suitable for ctrld usage.
|
||||||
|
|
||||||
|
/*
|
||||||
|
License: MIT Licence
|
||||||
|
|
||||||
|
References:
|
||||||
|
- https://github.com/LloydLabs/delete-self-poc
|
||||||
|
- https://twitter.com/jonasLyk/status/1350401461985955840
|
||||||
|
*/
|
||||||
|
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var supportedSelfDelete = false
|
||||||
|
|
||||||
|
type FILE_RENAME_INFO struct {
|
||||||
|
Union struct {
|
||||||
|
ReplaceIfExists bool
|
||||||
|
Flags uint32
|
||||||
|
}
|
||||||
|
RootDirectory windows.Handle
|
||||||
|
FileNameLength uint32
|
||||||
|
FileName [1]uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type FILE_DISPOSITION_INFO struct {
|
||||||
|
DeleteFile bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
|
||||||
|
handle, err := windows.CreateFile(
|
||||||
|
pwPath,
|
||||||
|
windows.DELETE,
|
||||||
|
0,
|
||||||
|
nil,
|
||||||
|
windows.OPEN_EXISTING,
|
||||||
|
windows.FILE_ATTRIBUTE_NORMAL,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return handle, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dsRenameHandle(hHandle windows.Handle) error {
|
||||||
|
var fRename FILE_RENAME_INFO
|
||||||
|
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
lpwStream := &DS_STREAM_RENAME[0]
|
||||||
|
fRename.FileNameLength = uint32(unsafe.Sizeof(lpwStream))
|
||||||
|
|
||||||
|
windows.NewLazyDLL("kernel32.dll").NewProc("RtlCopyMemory").Call(
|
||||||
|
uintptr(unsafe.Pointer(&fRename.FileName[0])),
|
||||||
|
uintptr(unsafe.Pointer(lpwStream)),
|
||||||
|
unsafe.Sizeof(lpwStream),
|
||||||
|
)
|
||||||
|
|
||||||
|
err = windows.SetFileInformationByHandle(
|
||||||
|
hHandle,
|
||||||
|
windows.FileRenameInfo,
|
||||||
|
(*byte)(unsafe.Pointer(&fRename)),
|
||||||
|
uint32(unsafe.Sizeof(fRename)+unsafe.Sizeof(lpwStream)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dsDepositeHandle(hHandle windows.Handle) error {
|
||||||
|
var fDelete FILE_DISPOSITION_INFO
|
||||||
|
fDelete.DeleteFile = true
|
||||||
|
|
||||||
|
err := windows.SetFileInformationByHandle(
|
||||||
|
hHandle,
|
||||||
|
windows.FileDispositionInfo,
|
||||||
|
(*byte)(unsafe.Pointer(&fDelete)),
|
||||||
|
uint32(unsafe.Sizeof(fDelete)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func selfDeleteExe() error {
|
||||||
|
var wcPath [windows.MAX_PATH + 1]uint16
|
||||||
|
var hCurrent windows.Handle
|
||||||
|
|
||||||
|
_, err := windows.GetModuleFileName(0, &wcPath[0], windows.MAX_PATH)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||||
|
if err != nil || hCurrent == windows.InvalidHandle {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := dsRenameHandle(hCurrent); err != nil {
|
||||||
|
_ = windows.CloseHandle(hCurrent)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = windows.CloseHandle(hCurrent)
|
||||||
|
|
||||||
|
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||||
|
if err != nil || hCurrent == windows.InvalidHandle {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := dsDepositeHandle(hCurrent); err != nil {
|
||||||
|
_ = windows.CloseHandle(hCurrent)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return windows.CloseHandle(hCurrent)
|
||||||
|
}
|
||||||
16
cmd/cli/self_kill_others.go
Normal file
16
cmd/cli/self_kill_others.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build !unix
|
||||||
|
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||||
|
if uninstallInvalidCdUID(p, logger, false) {
|
||||||
|
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
45
cmd/cli/self_kill_unix.go
Normal file
45
cmd/cli/self_kill_unix.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
selfUninstallLinux(p, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
bin, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal().Err(err).Msg("could not determine executable")
|
||||||
|
}
|
||||||
|
args := []string{"uninstall"}
|
||||||
|
if !deactivationPinNotSet() {
|
||||||
|
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin))
|
||||||
|
}
|
||||||
|
cmd := exec.Command(bin, args...)
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
logger.Fatal().Err(err).Msg("could not start self uninstall command")
|
||||||
|
}
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||||
|
_ = cmd.Wait()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func selfUninstallLinux(p *prog, logger zerolog.Logger) {
|
||||||
|
if uninstallInvalidCdUID(p, logger, true) {
|
||||||
|
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,13 +21,16 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
|
|||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
|
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
|
||||||
return &procd{&sysV{s}}, nil
|
return &procd{sysV: &sysV{s}, svcConfig: c}, nil
|
||||||
case router.IsGLiNet():
|
case router.IsGLiNet():
|
||||||
return &sysV{s}, nil
|
return &sysV{s}, nil
|
||||||
case s.Platform() == "unix-systemv":
|
case s.Platform() == "unix-systemv":
|
||||||
return &sysV{s}, nil
|
return &sysV{s}, nil
|
||||||
case s.Platform() == "linux-systemd":
|
case s.Platform() == "linux-systemd":
|
||||||
return &systemd{s}, nil
|
return &systemd{s}, nil
|
||||||
|
case s.Platform() == "darwin-launchd":
|
||||||
|
return newLaunchd(s), nil
|
||||||
|
|
||||||
}
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
@@ -89,25 +92,31 @@ func (s *sysV) Status() (service.Status, error) {
|
|||||||
// like old GL.iNET Opal router.
|
// like old GL.iNET Opal router.
|
||||||
type procd struct {
|
type procd struct {
|
||||||
*sysV
|
*sysV
|
||||||
|
svcConfig *service.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *procd) Status() (service.Status, error) {
|
func (s *procd) Status() (service.Status, error) {
|
||||||
if !s.installed() {
|
if !s.installed() {
|
||||||
return service.StatusUnknown, service.ErrNotInstalled
|
return service.StatusUnknown, service.ErrNotInstalled
|
||||||
}
|
}
|
||||||
exe, err := os.Executable()
|
bin := s.svcConfig.Executable
|
||||||
if err != nil {
|
if bin == "" {
|
||||||
return service.StatusUnknown, nil
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return service.StatusUnknown, nil
|
||||||
|
}
|
||||||
|
bin = exe
|
||||||
}
|
}
|
||||||
|
|
||||||
// Looking for something like "/sbin/ctrld run ".
|
// Looking for something like "/sbin/ctrld run ".
|
||||||
shellCmd := fmt.Sprintf("ps | grep -q %q", exe+" [r]un ")
|
shellCmd := fmt.Sprintf("ps | grep -q %q", bin+" [r]un ")
|
||||||
if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil {
|
if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil {
|
||||||
return service.StatusStopped, nil
|
return service.StatusStopped, nil
|
||||||
}
|
}
|
||||||
return service.StatusRunning, nil
|
return service.StatusRunning, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// procd wraps a service.Service, and provide status command to
|
// systemd wraps a service.Service, and provide status command to
|
||||||
// report the status correctly.
|
// report the status correctly.
|
||||||
type systemd struct {
|
type systemd struct {
|
||||||
service.Service
|
service.Service
|
||||||
@@ -121,6 +130,29 @@ func (s *systemd) Status() (service.Status, error) {
|
|||||||
return s.Service.Status()
|
return s.Service.Status()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newLaunchd(s service.Service) *launchd {
|
||||||
|
return &launchd{
|
||||||
|
Service: s,
|
||||||
|
statusErrMsg: "Permission denied",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// launchd wraps a service.Service, and provide status command to
|
||||||
|
// report the status correctly when not running as root on Darwin.
|
||||||
|
//
|
||||||
|
// TODO: remove this wrapper once https://github.com/kardianos/service/issues/400 fixed.
|
||||||
|
type launchd struct {
|
||||||
|
service.Service
|
||||||
|
statusErrMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *launchd) Status() (service.Status, error) {
|
||||||
|
if os.Geteuid() != 0 {
|
||||||
|
return service.StatusUnknown, errors.New(l.statusErrMsg)
|
||||||
|
}
|
||||||
|
return l.Service.Status()
|
||||||
|
}
|
||||||
|
|
||||||
type task struct {
|
type task struct {
|
||||||
f func() error
|
f func() error
|
||||||
abortOnError bool
|
abortOnError bool
|
||||||
|
|||||||
@@ -9,3 +9,7 @@ import (
|
|||||||
func hasElevatedPrivilege() (bool, error) {
|
func hasElevatedPrivilege() (bool, error) {
|
||||||
return os.Geteuid() == 0, nil
|
return os.Geteuid() == 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func openLogFile(path string, flags int) (*os.File, error) {
|
||||||
|
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package cli
|
package cli
|
||||||
|
|
||||||
import "golang.org/x/sys/windows"
|
import (
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
func hasElevatedPrivilege() (bool, error) {
|
func hasElevatedPrivilege() (bool, error) {
|
||||||
var sid *windows.SID
|
var sid *windows.SID
|
||||||
@@ -22,3 +27,55 @@ func hasElevatedPrivilege() (bool, error) {
|
|||||||
token := windows.Token(0)
|
token := windows.Token(0)
|
||||||
return token.IsMember(sid)
|
return token.IsMember(sid)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func openLogFile(path string, mode int) (*os.File, error) {
|
||||||
|
if len(path) == 0 {
|
||||||
|
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||||
|
}
|
||||||
|
|
||||||
|
pathP, err := syscall.UTF16PtrFromString(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var access uint32
|
||||||
|
switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
|
||||||
|
case os.O_RDONLY:
|
||||||
|
access = windows.GENERIC_READ
|
||||||
|
case os.O_WRONLY:
|
||||||
|
access = windows.GENERIC_WRITE
|
||||||
|
case os.O_RDWR:
|
||||||
|
access = windows.GENERIC_READ | windows.GENERIC_WRITE
|
||||||
|
}
|
||||||
|
if mode&os.O_CREATE != 0 {
|
||||||
|
access |= windows.GENERIC_WRITE
|
||||||
|
}
|
||||||
|
if mode&os.O_APPEND != 0 {
|
||||||
|
access &^= windows.GENERIC_WRITE
|
||||||
|
access |= windows.FILE_APPEND_DATA
|
||||||
|
}
|
||||||
|
|
||||||
|
shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
|
||||||
|
|
||||||
|
var sa *syscall.SecurityAttributes
|
||||||
|
|
||||||
|
var createMode uint32
|
||||||
|
switch {
|
||||||
|
case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL):
|
||||||
|
createMode = windows.CREATE_NEW
|
||||||
|
case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC):
|
||||||
|
createMode = windows.CREATE_ALWAYS
|
||||||
|
case mode&os.O_CREATE == os.O_CREATE:
|
||||||
|
createMode = windows.OPEN_ALWAYS
|
||||||
|
case mode&os.O_TRUNC == os.O_TRUNC:
|
||||||
|
createMode = windows.TRUNCATE_EXISTING
|
||||||
|
default:
|
||||||
|
createMode = windows.OPEN_EXISTING
|
||||||
|
}
|
||||||
|
|
||||||
|
handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &os.PathError{Path: path, Op: "open", Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.NewFile(uintptr(handle), path), nil
|
||||||
|
}
|
||||||
|
|||||||
62
config.go
62
config.go
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
"tailscale.com/logtail/backoff"
|
"tailscale.com/logtail/backoff"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
@@ -188,27 +189,30 @@ func (c *Config) FirstUpstream() *UpstreamConfig {
|
|||||||
|
|
||||||
// ServiceConfig specifies the general ctrld config.
|
// ServiceConfig specifies the general ctrld config.
|
||||||
type ServiceConfig struct {
|
type ServiceConfig struct {
|
||||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||||
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
|
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
|
||||||
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
||||||
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
|
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
|
||||||
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
|
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
|
||||||
Daemon bool `mapstructure:"-" toml:"-"`
|
DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"`
|
||||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"`
|
||||||
|
RefetchTime *int `mapstructure:"refetch_time" toml:"refetch_time,omitempty"`
|
||||||
|
Daemon bool `mapstructure:"-" toml:"-"`
|
||||||
|
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
|
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
|
||||||
@@ -316,7 +320,7 @@ func (uc *UpstreamConfig) Init() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if uc.IPStack == "" {
|
if uc.IPStack == "" {
|
||||||
if uc.isControlD() {
|
if uc.IsControlD() {
|
||||||
uc.IPStack = IpStackSplit
|
uc.IPStack = IpStackSplit
|
||||||
} else {
|
} else {
|
||||||
uc.IPStack = IpStackBoth
|
uc.IPStack = IpStackBoth
|
||||||
@@ -354,7 +358,7 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
|
|||||||
}
|
}
|
||||||
switch uc.Type {
|
switch uc.Type {
|
||||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||||
if uc.isControlD() || uc.isNextDNS() {
|
if uc.IsControlD() || uc.isNextDNS() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -401,7 +405,7 @@ func (uc *UpstreamConfig) UID() string {
|
|||||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
||||||
isControlD := uc.isControlD()
|
isControlD := uc.IsControlD()
|
||||||
for {
|
for {
|
||||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
||||||
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
||||||
@@ -486,6 +490,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
|||||||
ClientSessionCache: tls.NewLRUClientSessionCache(0),
|
ClientSessionCache: tls.NewLRUClientSessionCache(0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prevent bad tcp connection hanging the requests for too long.
|
||||||
|
// See: https://github.com/golang/go/issues/36026
|
||||||
|
if t2, err := http2.ConfigureTransports(transport); err == nil {
|
||||||
|
t2.ReadIdleTimeout = 10 * time.Second
|
||||||
|
t2.PingTimeout = 5 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
dialerTimeoutMs := 2000
|
dialerTimeoutMs := 2000
|
||||||
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
|
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
|
||||||
dialerTimeoutMs = uc.Timeout
|
dialerTimeoutMs = uc.Timeout
|
||||||
@@ -572,7 +583,8 @@ func (uc *UpstreamConfig) ping() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uc *UpstreamConfig) isControlD() bool {
|
// IsControlD reports whether this is a ControlD upstream.
|
||||||
|
func (uc *UpstreamConfig) IsControlD() bool {
|
||||||
domain := uc.Domain
|
domain := uc.Domain
|
||||||
if domain == "" {
|
if domain == "" {
|
||||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
@@ -22,6 +23,8 @@ func TestLoadConfig(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, "info", cfg.Service.LogLevel)
|
assert.Equal(t, "info", cfg.Service.LogLevel)
|
||||||
assert.Equal(t, "/path/to/log.log", cfg.Service.LogPath)
|
assert.Equal(t, "/path/to/log.log", cfg.Service.LogPath)
|
||||||
|
assert.Equal(t, false, *cfg.Service.DnsWatchdogEnabled)
|
||||||
|
assert.Equal(t, time.Duration(20*time.Second), *cfg.Service.DnsWatchdogInvterval)
|
||||||
|
|
||||||
assert.Len(t, cfg.Network, 2)
|
assert.Len(t, cfg.Network, 2)
|
||||||
assert.Contains(t, cfg.Network, "0")
|
assert.Contains(t, cfg.Network, "0")
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
# - Non-cgo ctrld binary.
|
# - Non-cgo ctrld binary.
|
||||||
#
|
#
|
||||||
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
||||||
FROM golang:1.20-bullseye as base
|
FROM golang:bullseye as base
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
|||||||
@@ -252,6 +252,35 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus
|
|||||||
- Required: no
|
- Required: no
|
||||||
- Default: ""
|
- Default: ""
|
||||||
|
|
||||||
|
### dns_watchdog_enabled
|
||||||
|
Checking DNS changes to network interfaces and reverting to ctrld's own settings.
|
||||||
|
|
||||||
|
The DNS watchdog process only runs on Windows and MacOS.
|
||||||
|
|
||||||
|
- Type: boolean
|
||||||
|
- Required: no
|
||||||
|
- Default: true
|
||||||
|
|
||||||
|
### dns_watchdog_interval
|
||||||
|
Time duration between each DNS watchdog iteration.
|
||||||
|
|
||||||
|
A duration string is a possibly signed sequence of decimal numbers, each with optional fraction and a unit suffix,
|
||||||
|
such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||||
|
|
||||||
|
If the time duration is non-positive, default value will be used.
|
||||||
|
|
||||||
|
- Type: time duration string
|
||||||
|
- Required: no
|
||||||
|
- Default: 20s
|
||||||
|
|
||||||
|
### refetch_time
|
||||||
|
Time in seconds between each iteration that reloads custom config if changed.
|
||||||
|
|
||||||
|
The value must be a positive number, any invalid value will be ignored and default value will be used.
|
||||||
|
- Type: number
|
||||||
|
- Required: no
|
||||||
|
- Default: 3600
|
||||||
|
|
||||||
## Upstream
|
## Upstream
|
||||||
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
||||||
|
|
||||||
@@ -336,7 +365,7 @@ The protocol that `ctrld` will use to send DNS requests to upstream.
|
|||||||
|
|
||||||
- Type: string
|
- Type: string
|
||||||
- Required: yes
|
- Required: yes
|
||||||
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os`
|
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`
|
||||||
|
|
||||||
### ip_stack
|
### ip_stack
|
||||||
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.
|
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.
|
||||||
|
|||||||
BIN
docs/ctrldsplash.png
Normal file
BIN
docs/ctrldsplash.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 458 KiB |
2
doh.go
2
doh.go
@@ -147,7 +147,7 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
|
|||||||
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
|
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
|
||||||
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
|
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
|
||||||
switch {
|
switch {
|
||||||
case uc.isControlD():
|
case uc.IsControlD():
|
||||||
dohHeader = newControlDHeaders(ci)
|
dohHeader = newControlDHeaders(ci)
|
||||||
case uc.isNextDNS():
|
case uc.isNextDNS():
|
||||||
dohHeader = newNextDNSHeaders(ci)
|
dohHeader = newNextDNSHeaders(ci)
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
//go:build qf
|
|
||||||
|
|
||||||
package ctrld
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type doqResolver struct {
|
|
||||||
uc *UpstreamConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
|
||||||
return nil, errors.New("DoQ is not supported")
|
|
||||||
}
|
|
||||||
2
dot.go
2
dot.go
@@ -18,7 +18,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
|||||||
// dns.controld.dev first. By using a dialer with custom resolver,
|
// dns.controld.dev first. By using a dialer with custom resolver,
|
||||||
// we ensure that we can always resolve the bootstrap domain
|
// we ensure that we can always resolve the bootstrap domain
|
||||||
// regardless of the machine DNS status.
|
// regardless of the machine DNS status.
|
||||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
|
||||||
dnsTyp := uint16(0)
|
dnsTyp := uint16(0)
|
||||||
if msg != nil && len(msg.Question) > 0 {
|
if msg != nil && len(msg.Question) > 0 {
|
||||||
dnsTyp = msg.Question[0].Qtype
|
dnsTyp = msg.Question[0].Qtype
|
||||||
|
|||||||
@@ -122,8 +122,8 @@ func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan
|
|||||||
bo := backoff.NewBackoff("mdns probe", func(format string, args ...any) {}, time.Second*30)
|
bo := backoff.NewBackoff("mdns probe", func(format string, args ...any) {}, time.Second*30)
|
||||||
for {
|
for {
|
||||||
err := m.probe(conns, remoteAddr)
|
err := m.probe(conns, remoteAddr)
|
||||||
if isErrNetUnreachableOrInvalid(err) {
|
if shouldStopProbing(err) {
|
||||||
ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: network unreachable or invalid", remoteAddr)
|
ctrld.ProxyLogger.Load().Warn().Msgf("stop probing %q: %v", remoteAddr, err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -165,7 +165,7 @@ func (m *mdns) readLoop(conn *net.UDPConn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ip, name string
|
var ip, name string
|
||||||
rrs := make([]dns.RR, 0, len(msg.Answer)+len(msg.Extra))
|
var rrs []dns.RR
|
||||||
rrs = append(rrs, msg.Answer...)
|
rrs = append(rrs, msg.Answer...)
|
||||||
rrs = append(rrs, msg.Extra...)
|
rrs = append(rrs, msg.Extra...)
|
||||||
for _, rr := range rrs {
|
for _, rr := range rrs {
|
||||||
@@ -273,10 +273,14 @@ func multicastInterfaces() ([]net.Interface, error) {
|
|||||||
return interfaces, nil
|
return interfaces, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isErrNetUnreachableOrInvalid(err error) bool {
|
// shouldStopProbing reports whether ctrld should stop probing mdns.
|
||||||
|
func shouldStopProbing(err error) bool {
|
||||||
var se *os.SyscallError
|
var se *os.SyscallError
|
||||||
if errors.As(err, &se) {
|
if errors.As(err, &se) {
|
||||||
return se.Err == syscall.ENETUNREACH || se.Err == syscall.EINVAL
|
switch se.Err {
|
||||||
|
case syscall.ENETUNREACH, syscall.EINVAL, syscall.EPERM:
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,14 +26,15 @@ const (
|
|||||||
apiDomainDev = "api.controld.dev"
|
apiDomainDev = "api.controld.dev"
|
||||||
resolverDataURLCom = "https://api.controld.com/utility"
|
resolverDataURLCom = "https://api.controld.com/utility"
|
||||||
resolverDataURLDev = "https://api.controld.dev/utility"
|
resolverDataURLDev = "https://api.controld.dev/utility"
|
||||||
InvalidConfigCode = 40401
|
InvalidConfigCode = 40402
|
||||||
)
|
)
|
||||||
|
|
||||||
// ResolverConfig represents Control D resolver data.
|
// ResolverConfig represents Control D resolver data.
|
||||||
type ResolverConfig struct {
|
type ResolverConfig struct {
|
||||||
DOH string `json:"doh"`
|
DOH string `json:"doh"`
|
||||||
Ctrld struct {
|
Ctrld struct {
|
||||||
CustomConfig string `json:"custom_config"`
|
CustomConfig string `json:"custom_config"`
|
||||||
|
CustomLastUpdate int64 `json:"custom_last_update"`
|
||||||
} `json:"ctrld"`
|
} `json:"ctrld"`
|
||||||
Exclude []string `json:"exclude"`
|
Exclude []string `json:"exclude"`
|
||||||
UID string `json:"uid"`
|
UID string `json:"uid"`
|
||||||
@@ -76,17 +77,28 @@ func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, e
|
|||||||
req.ClientID = clientID
|
req.ClientID = clientID
|
||||||
}
|
}
|
||||||
body, _ := json.Marshal(req)
|
body, _ := json.Marshal(req)
|
||||||
return postUtilityAPI(version, cdDev, bytes.NewReader(body))
|
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchResolverUID fetch resolver uid from provision token.
|
// FetchResolverUID fetch resolver uid from provision token.
|
||||||
func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) {
|
func FetchResolverUID(pt, version string, cdDev bool) (*ResolverConfig, error) {
|
||||||
hostname, _ := os.Hostname()
|
hostname, _ := os.Hostname()
|
||||||
body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname})
|
body, _ := json.Marshal(utilityOrgRequest{ProvToken: pt, Hostname: hostname})
|
||||||
return postUtilityAPI(version, cdDev, bytes.NewReader(body))
|
return postUtilityAPI(version, cdDev, false, bytes.NewReader(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig, error) {
|
// UpdateCustomLastFailed calls API to mark custom config is bad.
|
||||||
|
func UpdateCustomLastFailed(rawUID, version string, cdDev, lastUpdatedFailed bool) (*ResolverConfig, error) {
|
||||||
|
uid, clientID := ParseRawUID(rawUID)
|
||||||
|
req := utilityRequest{UID: uid}
|
||||||
|
if clientID != "" {
|
||||||
|
req.ClientID = clientID
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(req)
|
||||||
|
return postUtilityAPI(version, cdDev, true, bytes.NewReader(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reader) (*ResolverConfig, error) {
|
||||||
apiUrl := resolverDataURLCom
|
apiUrl := resolverDataURLCom
|
||||||
if cdDev {
|
if cdDev {
|
||||||
apiUrl = resolverDataURLDev
|
apiUrl = resolverDataURLDev
|
||||||
@@ -98,6 +110,9 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig
|
|||||||
q := req.URL.Query()
|
q := req.URL.Query()
|
||||||
q.Set("platform", "ctrld")
|
q.Set("platform", "ctrld")
|
||||||
q.Set("version", version)
|
q.Set("version", version)
|
||||||
|
if lastUpdatedFailed {
|
||||||
|
q.Set("custom_last_failed", "1")
|
||||||
|
}
|
||||||
req.URL.RawQuery = q.Encode()
|
req.URL.RawQuery = q.Encode()
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ start_service() {
|
|||||||
procd_set_param stdout 1 # forward stdout of the command to logd
|
procd_set_param stdout 1 # forward stdout of the command to logd
|
||||||
procd_set_param stderr 1 # same for stderr
|
procd_set_param stderr 1 # same for stderr
|
||||||
procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop
|
procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop
|
||||||
|
procd_set_param term_timeout 10
|
||||||
procd_close_instance
|
procd_close_instance
|
||||||
echo "${name} has been started"
|
echo "${name} has been started"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
@@ -98,6 +99,11 @@ func IsOldOpenwrt() bool {
|
|||||||
return cmd == ""
|
return cmd == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WaitProcessExited reports whether the "ctrld stop" command have to wait until ctrld process exited.
|
||||||
|
func WaitProcessExited() bool {
|
||||||
|
return Name() == openwrt.Name
|
||||||
|
}
|
||||||
|
|
||||||
var routerPlatform atomic.Pointer[router]
|
var routerPlatform atomic.Pointer[router]
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
@@ -159,6 +165,16 @@ func HomeDir() (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return filepath.Dir(exe), nil
|
return filepath.Dir(exe), nil
|
||||||
|
case edgeos.Name:
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// Using binary directory as home dir if it is located in /config.
|
||||||
|
// Otherwise, fallback to old behavior for compatibility.
|
||||||
|
if strings.HasPrefix(exe, "/config/") {
|
||||||
|
return filepath.Dir(exe), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,11 +49,15 @@ func (s *merlinSvc) Platform() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *merlinSvc) configPath() string {
|
func (s *merlinSvc) configPath() string {
|
||||||
path, err := os.Executable()
|
bin := s.Config.Executable
|
||||||
if err != nil {
|
if bin == "" {
|
||||||
return ""
|
path, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
bin = path
|
||||||
}
|
}
|
||||||
return path + ".startup"
|
return bin + ".startup"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *merlinSvc) template() *template.Template {
|
func (s *merlinSvc) template() *template.Template {
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
package ctrld
|
package ctrld
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func dnsFns() []dnsFn {
|
func dnsFns() []dnsFn {
|
||||||
@@ -20,40 +17,23 @@ func dnsFromAdapter() []string {
|
|||||||
}
|
}
|
||||||
ns := make([]string, 0, len(aas)*2)
|
ns := make([]string, 0, len(aas)*2)
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
do := func(addr windows.SocketAddress) {
|
addressMap := make(map[string]struct{})
|
||||||
sa, err := addr.Sockaddr.Sockaddr()
|
for _, aa := range aas {
|
||||||
if err != nil {
|
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
|
||||||
return
|
addressMap[a.Address.IP().String()] = struct{}{}
|
||||||
}
|
}
|
||||||
var ip net.IP
|
|
||||||
switch sa := sa.(type) {
|
|
||||||
case *syscall.SockaddrInet4:
|
|
||||||
ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
|
|
||||||
case *syscall.SockaddrInet6:
|
|
||||||
ip = make(net.IP, net.IPv6len)
|
|
||||||
copy(ip, sa.Addr[:])
|
|
||||||
if ip[0] == 0xfe && ip[1] == 0xc0 {
|
|
||||||
// Ignore these fec0/10 ones. Windows seems to
|
|
||||||
// populate them as defaults on its misc rando
|
|
||||||
// interfaces.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
|
|
||||||
}
|
|
||||||
if ip.IsLoopback() || seen[ip.String()] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
seen[ip.String()] = true
|
|
||||||
ns = append(ns, ip.String())
|
|
||||||
}
|
}
|
||||||
for _, aa := range aas {
|
for _, aa := range aas {
|
||||||
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
||||||
do(dns.Address)
|
ip := dns.Address.IP()
|
||||||
}
|
if ip == nil || ip.IsLoopback() || seen[ip.String()] {
|
||||||
for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next {
|
continue
|
||||||
do(gw.Address)
|
}
|
||||||
|
if _, ok := addressMap[ip.String()]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[ip.String()] = true
|
||||||
|
ns = append(ns, ip.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ns
|
return ns
|
||||||
|
|||||||
95
resolver.go
95
resolver.go
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -30,18 +31,49 @@ const (
|
|||||||
ResolverTypePrivate = "private"
|
ResolverTypePrivate = "private"
|
||||||
)
|
)
|
||||||
|
|
||||||
const bootstrapDNS = "76.76.2.22"
|
const (
|
||||||
|
controldBootstrapDns = "76.76.2.22"
|
||||||
|
controldPublicDns = "76.76.2.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||||
|
|
||||||
// or is the Resolver used for ResolverTypeOS.
|
// or is the Resolver used for ResolverTypeOS.
|
||||||
var or = &osResolver{nameservers: defaultNameservers()}
|
var or = &osResolver{nameservers: defaultNameservers()}
|
||||||
|
|
||||||
// defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver.
|
// defaultNameservers returns OS nameservers plus ControlD public DNS.
|
||||||
func defaultNameservers() []string {
|
func defaultNameservers() []string {
|
||||||
ns := nameservers()
|
ns := nameservers()
|
||||||
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
|
|
||||||
return ns
|
return ns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitializeOsResolver initializes OS resolver using the current system DNS settings.
|
||||||
|
// It returns the nameservers that is going to be used by the OS resolver.
|
||||||
|
//
|
||||||
|
// It's the caller's responsibility to ensure the system DNS is in a clean state before
|
||||||
|
// calling this function.
|
||||||
|
func InitializeOsResolver() []string {
|
||||||
|
or.nameservers = or.nameservers[:0]
|
||||||
|
for _, ns := range defaultNameservers() {
|
||||||
|
if testNameserver(ns) {
|
||||||
|
or.nameservers = append(or.nameservers, ns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
or.nameservers = append(or.nameservers, controldPublicDnsWithPort)
|
||||||
|
return or.nameservers
|
||||||
|
}
|
||||||
|
|
||||||
|
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
|
||||||
|
func testNameserver(addr string) bool {
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.SetQuestion(".", dns.TypeNS)
|
||||||
|
client := new(dns.Client)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_, _, err := client.ExchangeContext(ctx, msg, addr)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
// Resolver is the interface that wraps the basic DNS operations.
|
// Resolver is the interface that wraps the basic DNS operations.
|
||||||
//
|
//
|
||||||
// Resolve resolves the DNS query, return the result and the corresponding error.
|
// Resolve resolves the DNS query, return the result and the corresponding error.
|
||||||
@@ -76,8 +108,9 @@ type osResolver struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type osResolverResult struct {
|
type osResolverResult struct {
|
||||||
answer *dns.Msg
|
answer *dns.Msg
|
||||||
err error
|
err error
|
||||||
|
isControlDPublicDNS bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||||
@@ -103,19 +136,34 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
|||||||
go func(server string) {
|
go func(server string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||||
ch <- &osResolverResult{answer: answer, err: err}
|
ch <- &osResolverResult{answer: answer, err: err, isControlDPublicDNS: server == controldPublicDnsWithPort}
|
||||||
}(server)
|
}(server)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
nonSuccessAnswer *dns.Msg
|
||||||
|
controldSuccessAnswer *dns.Msg
|
||||||
|
)
|
||||||
errs := make([]error, 0, numServers)
|
errs := make([]error, 0, numServers)
|
||||||
for res := range ch {
|
for res := range ch {
|
||||||
if res.err == nil {
|
switch {
|
||||||
cancel()
|
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
|
||||||
return res.answer, res.err
|
if res.isControlDPublicDNS {
|
||||||
|
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
|
||||||
|
} else {
|
||||||
|
cancel()
|
||||||
|
return res.answer, nil
|
||||||
|
}
|
||||||
|
case res.answer != nil:
|
||||||
|
nonSuccessAnswer = res.answer
|
||||||
}
|
}
|
||||||
errs = append(errs, res.err)
|
errs = append(errs, res.err)
|
||||||
}
|
}
|
||||||
|
for _, answer := range []*dns.Msg{controldSuccessAnswer, nonSuccessAnswer} {
|
||||||
|
if answer != nil {
|
||||||
|
return answer, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, errors.Join(errs...)
|
return nil, errors.Join(errs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +173,7 @@ type legacyResolver struct {
|
|||||||
|
|
||||||
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||||
// See comment in (*dotResolver).resolve method.
|
// See comment in (*dotResolver).resolve method.
|
||||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
|
||||||
dnsTyp := uint16(0)
|
dnsTyp := uint16(0)
|
||||||
if msg != nil && len(msg.Question) > 0 {
|
if msg != nil && len(msg.Question) > 0 {
|
||||||
dnsTyp = msg.Question[0].Qtype
|
dnsTyp = msg.Question[0].Qtype
|
||||||
@@ -163,7 +211,7 @@ func LookupIP(domain string) []string {
|
|||||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||||
resolver := &osResolver{nameservers: nameservers()}
|
resolver := &osResolver{nameservers: nameservers()}
|
||||||
if withBootstrapDNS {
|
if withBootstrapDNS {
|
||||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
resolver.nameservers = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, resolver.nameservers...)
|
||||||
}
|
}
|
||||||
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
|
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
|
||||||
timeoutMs := 2000
|
timeoutMs := 2000
|
||||||
@@ -239,7 +287,7 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
|
|||||||
// - Input servers.
|
// - Input servers.
|
||||||
func NewBootstrapResolver(servers ...string) Resolver {
|
func NewBootstrapResolver(servers ...string) Resolver {
|
||||||
resolver := &osResolver{nameservers: nameservers()}
|
resolver := &osResolver{nameservers: nameservers()}
|
||||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
resolver.nameservers = append([]string{controldPublicDnsWithPort}, resolver.nameservers...)
|
||||||
for _, ns := range servers {
|
for _, ns := range servers {
|
||||||
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
|
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
|
||||||
}
|
}
|
||||||
@@ -266,11 +314,11 @@ func NewPrivateResolver() Resolver {
|
|||||||
// - Direct listener that has ctrld as an upstream (e.g: dnsmasq).
|
// - Direct listener that has ctrld as an upstream (e.g: dnsmasq).
|
||||||
//
|
//
|
||||||
// causing the query always succeed.
|
// causing the query always succeed.
|
||||||
if sliceContains(resolveConfNss, host) {
|
if slices.Contains(resolveConfNss, host) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Ignoring local RFC 1918 addresses.
|
// Ignoring local RFC 1918 addresses.
|
||||||
if sliceContains(localRfc1918Addrs, host) {
|
if slices.Contains(localRfc1918Addrs, host) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip := net.ParseIP(host)
|
ip := net.ParseIP(host)
|
||||||
@@ -322,20 +370,3 @@ func newDialer(dnsAddress string) *net.Dialer {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(cuonglm): use slices.Contains once upgrading to go1.21
|
|
||||||
// sliceContains reports whether v is present in s.
|
|
||||||
func sliceContains[S ~[]E, E comparable](s S, v E) bool {
|
|
||||||
return sliceIndex(s, v) >= 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// sliceIndex returns the index of the first occurrence of v in s,
|
|
||||||
// or -1 if not present.
|
|
||||||
func sliceIndex[S ~[]E, E comparable](s S, v E) int {
|
|
||||||
for i := range s {
|
|
||||||
if v == s[i] {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package ctrld
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -28,6 +30,57 @@ func Test_osResolver_Resolve(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) {
|
||||||
|
ns := make([]string, 0, 2)
|
||||||
|
servers := make([]*dns.Server, 0, 2)
|
||||||
|
successHandler := dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetRcode(msg, dns.RcodeSuccess)
|
||||||
|
w.WriteMsg(m)
|
||||||
|
})
|
||||||
|
nonSuccessHandlerWithRcode := func(rcode int) dns.HandlerFunc {
|
||||||
|
return dns.HandlerFunc(func(w dns.ResponseWriter, msg *dns.Msg) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetRcode(msg, rcode)
|
||||||
|
w.WriteMsg(m)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
handlers := []dns.Handler{
|
||||||
|
nonSuccessHandlerWithRcode(dns.RcodeRefused),
|
||||||
|
nonSuccessHandlerWithRcode(dns.RcodeNameError),
|
||||||
|
successHandler,
|
||||||
|
}
|
||||||
|
for i := range handlers {
|
||||||
|
pc, err := net.ListenPacket("udp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s, addr, err := runLocalPacketConnTestServer(t, pc, handlers[i])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
ns = append(ns, addr)
|
||||||
|
servers = append(servers, s)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
for _, server := range servers {
|
||||||
|
server.Shutdown()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
resolver := &osResolver{nameservers: ns}
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.SetQuestion(".", dns.TypeNS)
|
||||||
|
answer, err := resolver.Resolve(context.Background(), msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if answer.Rcode != dns.RcodeSuccess {
|
||||||
|
t.Errorf("unexpected return code: %s", dns.RcodeToString[answer.Rcode])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -51,3 +104,33 @@ func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
server := &dns.Server{
|
||||||
|
PacketConn: pc,
|
||||||
|
ReadTimeout: time.Hour,
|
||||||
|
WriteTimeout: time.Hour,
|
||||||
|
Handler: handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
waitLock := sync.Mutex{}
|
||||||
|
waitLock.Lock()
|
||||||
|
server.NotifyStartedFunc = waitLock.Unlock
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(server)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, closer := pc.LocalAddr().String(), pc
|
||||||
|
go func() {
|
||||||
|
if err := server.ActivateAndServe(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
closer.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitLock.Lock()
|
||||||
|
return server, addr, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ var sampleConfigContent = `
|
|||||||
[service]
|
[service]
|
||||||
log_level = "info"
|
log_level = "info"
|
||||||
log_path = "/path/to/log.log"
|
log_path = "/path/to/log.log"
|
||||||
|
dns_watchdog_enabled = false
|
||||||
|
dns_watchdog_interval = "20s"
|
||||||
|
|
||||||
[network.0]
|
[network.0]
|
||||||
name = "Home Wifi"
|
name = "Home Wifi"
|
||||||
|
|||||||
Reference in New Issue
Block a user