Merge pull request #155 from Control-D-Inc/release-branch-v1.3.7

Release branch v1.3.7
This commit is contained in:
Cuong Manh Le
2024-05-31 15:04:47 +07:00
committed by GitHub
18 changed files with 433 additions and 261 deletions

View File

@@ -4,6 +4,8 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/Control-D-Inc/ctrld.svg)](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
[![Go Report Card](https://goreportcard.com/badge/github.com/Control-D-Inc/ctrld)](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
![ctrld spash image](/docs/ctrldsplash.png)
A highly configurable DNS forwarding proxy with support for:
- Multiple listeners for incoming queries
- Multiple upstreams with fallbacks

View File

@@ -18,15 +18,14 @@ import (
"path/filepath"
"reflect"
"runtime"
"runtime/debug"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/Masterminds/semver"
"github.com/cuonglm/osinfo"
"github.com/fsnotify/fsnotify"
"github.com/go-playground/validator/v10"
"github.com/kardianos/service"
"github.com/miekg/dns"
@@ -86,7 +85,7 @@ var rootCmd = &cobra.Command{
Use: "ctrld",
Short: strings.TrimLeft(rootShortDesc, "\n"),
Version: curVersion(),
PreRun: func(cmd *cobra.Command, args []string) {
PersistentPreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
}
@@ -127,9 +126,6 @@ func initCLI() {
Use: "run",
Short: "Run the DNS proxy server",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: func(cmd *cobra.Command, args []string) {
RunCobraCommand(cmd)
},
@@ -158,7 +154,6 @@ func initCLI() {
startCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "start",
@@ -187,11 +182,11 @@ func initCLI() {
return
}
status, err := s.Status()
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
status, _ := s.Status()
isCtrldRunning := status == service.StatusRunning
// If pin code was set, do not allow running start command.
if status == service.StatusRunning {
if isCtrldRunning {
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
os.Exit(deactivationPinInvalidExitCode)
}
@@ -255,13 +250,14 @@ func initCLI() {
// A buffer channel to gather log output from runCmd and report
// to user in case self-check process failed.
runCmdLogCh := make(chan string, 256)
if dir, err := userHomeDir(); err == nil {
setWorkingDirectory(sc, dir)
ud, err := userHomeDir()
sockDir := ud
if err == nil {
setWorkingDirectory(sc, ud)
if configPath == "" && writeDefaultConfig {
defaultConfigFile = filepath.Join(dir, defaultConfigFile)
defaultConfigFile = filepath.Join(ud, defaultConfigFile)
}
sc.Arguments = append(sc.Arguments, "--homedir="+dir)
sockDir := dir
sc.Arguments = append(sc.Arguments, "--homedir="+ud)
if d, err := socketDir(); err == nil {
sockDir = d
}
@@ -312,18 +308,11 @@ func initCLI() {
}
tasks := []task{
resetDnsTask(p, s),
{s.Stop, false},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true},
{func() error { return ensureUninstall(s) }, false},
{func() error {
// If ctrld is installed, we should not save current DNS settings, because:
//
// - The DNS settings was being set by ctrld already.
// - We could not determine the state of DNS settings before installing ctrld.
if isCtrldInstalled {
return nil
}
// Save current DNS so we can restore later.
withEachPhysicalInterfaces("", "save DNS settings", func(i *net.Interface) error {
return saveCurrentStaticDNS(i)
@@ -343,7 +332,7 @@ func initCLI() {
return
}
ok, status, err := selfCheckStatus(s)
ok, status, err := selfCheckStatus(s, ud, sockDir)
switch {
case ok && status == service.StatusRunning:
mainLog.Load().Notice().Msg("Service started")
@@ -381,7 +370,15 @@ func initCLI() {
uninstall(p, s)
os.Exit(1)
}
p.setDNS()
if cc := newSocketControlClient(s, sockDir); cc != nil {
if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK {
if iface == "auto" {
iface = defaultIfaceName()
}
logger := mainLog.Load().With().Str("iface", iface).Logger()
logger.Debug().Msg("setting DNS successfully")
}
}
}
},
}
@@ -401,12 +398,10 @@ func initCLI() {
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`)
routerCmd := &cobra.Command{
Use: "setup",
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: func(cmd *cobra.Command, _ []string) {
exe, err := os.Executable()
if err != nil {
@@ -434,7 +429,6 @@ func initCLI() {
stopCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "stop",
@@ -456,6 +450,23 @@ func initCLI() {
if doTasks([]task{{s.Stop, true}}) {
p.router.Cleanup()
p.resetDNS()
if router.WaitProcessExited() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
for {
select {
case <-ctx.Done():
mainLog.Load().Error().Msg("timeout while waiting for service to stop")
return
default:
}
time.Sleep(time.Second)
if status, _ := s.Status(); status == service.StatusStopped {
break
}
}
}
mainLog.Load().Notice().Msg("Service stopped")
}
},
@@ -466,14 +477,16 @@ func initCLI() {
restartCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "restart",
Short: "Restart the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := newService(&prog{}, svcConfig)
readConfig(false)
v.Unmarshal(&cfg)
p := &prog{router: router.New(&cfg, runInCdMode())}
s, err := newService(p, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
@@ -484,6 +497,7 @@ func initCLI() {
}
initLogging()
iface = runningIface(s)
tasks := []task{
{s.Stop, false},
{s.Start, true},
@@ -494,10 +508,12 @@ func initCLI() {
mainLog.Load().Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
return
}
if cc := newSocketControlClient(s, dir); cc == nil {
cc := newSocketControlClient(s, dir)
if cc == nil {
mainLog.Load().Notice().Msg("Service was not restarted")
os.Exit(1)
}
_, _ = cc.post(ifacePath, nil)
mainLog.Load().Notice().Msg("Service restarted")
}
},
@@ -505,7 +521,6 @@ func initCLI() {
reloadCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "reload",
@@ -551,9 +566,6 @@ func initCLI() {
Use: "status",
Short: "Show status of the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: func(cmd *cobra.Command, args []string) {
s, err := newService(&prog{}, svcConfig)
if err != nil {
@@ -581,14 +593,12 @@ func initCLI() {
if runtime.GOOS == "darwin" {
// On darwin, running status command without privileges may return wrong information.
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
}
}
uninstallCmd := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "uninstall",
@@ -623,9 +633,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
Use: "list",
Short: "List network interfaces of the host",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: func(cmd *cobra.Command, args []string) {
err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
fmt.Printf("Index : %d\n", i.Index)
@@ -686,7 +693,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
rootCmd.AddCommand(serviceCmd)
startCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "start",
@@ -704,7 +710,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
rootCmd.AddCommand(startCmdAlias)
stopCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "stop",
@@ -723,7 +728,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
restartCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "restart",
@@ -736,7 +740,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
reloadCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "reload",
@@ -751,16 +754,12 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
Use: "status",
Short: "Show status of the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: statusCmd.Run,
Run: statusCmd.Run,
}
rootCmd.AddCommand(statusCmdAlias)
uninstallCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "uninstall",
@@ -785,7 +784,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
Short: "List clients that ctrld discovered",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Run: func(cmd *cobra.Command, args []string) {
@@ -873,25 +871,31 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
ValidArgs: []string{upgradeChannelDev, upgradeChannelProd},
Args: cobra.MaximumNArgs(1),
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Run: func(cmd *cobra.Command, args []string) {
s, err := newService(&prog{}, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
}
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("service not installed")
return
}
bin, err := os.Executable()
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path")
}
sc := &service.Config{}
*sc = *svcConfig
sc.Executable = bin
readConfig(false)
v.Unmarshal(&cfg)
p := &prog{router: router.New(&cfg, runInCdMode())}
s, err := newService(p, sc)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
}
svcInstalled := true
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
svcInstalled = false
}
oldBin := bin + "_previous"
urlString := upgradeChannel[upgradeChannelDefault]
baseUrl := upgradeChannel[upgradeChannelDefault]
if len(args) > 0 {
channel := args[0]
switch channel {
@@ -899,12 +903,9 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
default:
mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev)
}
urlString = upgradeChannel[channel]
}
dlUrl := fmt.Sprintf("%s/%s-%s/ctrld", urlString, runtime.GOOS, runtime.GOARCH)
if runtime.GOOS == "windows" {
dlUrl += ".exe"
baseUrl = upgradeChannel[channel]
}
dlUrl := upgradeUrl(baseUrl)
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl)
resp, err := http.Get(dlUrl)
if err != nil {
@@ -923,18 +924,26 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
}
doRestart := func() bool {
if !svcInstalled {
return true
}
tasks := []task{
{s.Stop, false},
{s.Start, false},
}
if doTasks(tasks) {
if dir, err := socketDir(); err == nil {
return newSocketControlClient(s, dir) != nil
if cc := newSocketControlClient(s, dir); cc != nil {
_, _ = cc.post(ifacePath, nil)
return true
}
}
}
return false
}
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
if svcInstalled {
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
}
if doRestart() {
_ = os.Remove(oldBin)
_ = os.Chmod(bin, 0755)
@@ -1608,7 +1617,7 @@ func defaultIfaceName() string {
// - External testing, ensuring query could be sent from ctrld -> upstream.
//
// Self-check is considered success only if both tests are ok.
func selfCheckStatus(s service.Service) (bool, service.Status, error) {
func selfCheckStatus(s service.Service, homedir, sockDir string) (bool, service.Status, error) {
status, err := s.Status()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not get service status")
@@ -1618,117 +1627,80 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
if status != service.StatusRunning {
return false, status, nil
}
dir, err := socketDir()
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to check ctrld listener status: could not get home directory")
return false, status, err
// Skip self checks if set.
if skipSelfChecks {
return true, status, nil
}
mainLog.Load().Debug().Msg("waiting for ctrld listener to be ready")
cc := newSocketControlClient(s, dir)
cc := newSocketControlClient(s, sockDir)
if cc == nil {
return false, status, errors.New("could not connect to control server")
}
resp, err := cc.post(startedPath, nil)
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to connect to control server")
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
ctrld.SetConfigNameWithPath(v, "ctrld", homedir)
if configPath != "" {
v.SetConfigFile(configPath)
}
if err := v.ReadInConfig(); err != nil {
mainLog.Load().Error().Err(err).Msgf("failed to re-read configuration file: %s", v.ConfigFileUsed())
return false, status, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
mainLog.Load().Error().Msg("ctrld listener is not ready")
return false, status, errors.New("ctrld listener is not ready")
cfg = ctrld.Config{}
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Error().Err(err).Msg("failed to update new config")
return false, status, err
}
// Not a ctrld upstream, return status as-is.
if cfg.FirstUpstream().VerifyDomain() == "" {
selfCheckExternalDomain := cfg.FirstUpstream().VerifyDomain()
if selfCheckExternalDomain == "" {
// Nothing to do, return the status as-is.
return true, status, nil
}
mainLog.Load().Debug().Msg("ctrld listener is ready")
mainLog.Load().Debug().Msg("performing self-check")
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
bo.LogLongerThan = 500 * time.Millisecond
ctx := context.Background()
maxAttempts := 20
c := new(dns.Client)
var (
lcChanged map[string]*ctrld.ListenerConfig
ucChanged map[string]*ctrld.UpstreamConfig
mu sync.Mutex
)
if err := v.ReadInConfig(); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to read new config")
}
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to update new config")
}
domain := cfg.FirstUpstream().VerifyDomain()
if domain == "" {
// Nothing to do, return the status as-is.
return true, status, nil
}
watcher, err := fsnotify.NewWatcher()
if err != nil {
mainLog.Load().Error().Err(err).Msg("could not watch config change")
lc := cfg.FirstListener()
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
if err := selfCheckResolveDomain(context.TODO(), addr, "internal", selfCheckInternalTestDomain); err != nil {
return false, status, err
}
defer watcher.Close()
if err := selfCheckResolveDomain(context.TODO(), addr, "external", selfCheckExternalDomain); err != nil {
return false, status, err
}
return true, status, nil
}
// selfCheckResolveDomain performs DNS test query against ctrld listener.
func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain string) error {
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
bo.LogLongerThan = 500 * time.Millisecond
maxAttempts := 20
c := new(dns.Client)
v.OnConfigChange(func(in fsnotify.Event) {
mu.Lock()
defer mu.Unlock()
if err := v.UnmarshalKey("listener", &lcChanged); err != nil {
mainLog.Load().Error().Msgf("failed to unmarshal listener config: %v", err)
return
}
if err := v.UnmarshalKey("upstream", &ucChanged); err != nil {
mainLog.Load().Error().Msgf("failed to unmarshal upstream config: %v", err)
return
}
})
v.WatchConfig()
var (
lastAnswer *dns.Msg
lastErr error
internalTested bool
lastAnswer *dns.Msg
lastErr error
)
for i := 0; i < maxAttempts; i++ {
mu.Lock()
if lcChanged != nil {
cfg.Listener = lcChanged
}
if ucChanged != nil {
cfg.Upstream = ucChanged
}
mu.Unlock()
lc := cfg.FirstListener()
domain = cfg.FirstUpstream().VerifyDomain()
if !internalTested {
domain = selfCheckInternalTestDomain
}
if domain == "" {
continue
}
for i := 0; i < maxAttempts; i++ {
if domain == "" {
return errors.New("empty test domain")
}
m := new(dns.Msg)
m.SetQuestion(domain+".", dns.TypeA)
m.RecursionDesired = true
r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)))
r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, addr)
if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 {
internalTested = domain == selfCheckInternalTestDomain
if internalTested {
mainLog.Load().Debug().Msgf("internal self-check against %q succeeded", domain)
continue // internal domain test ok, continue with external test.
} else {
mainLog.Load().Debug().Msgf("external self-check against %q succeeded", domain)
}
return true, status, nil
mainLog.Load().Debug().Msgf("%s self-check against %q succeeded", scope, domain)
return nil
}
// Return early if this is a connection refused.
if errConnectionRefused(exErr) {
return false, status, exErr
return exErr
}
lastAnswer = r
lastErr = exErr
@@ -1741,8 +1713,6 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint)
}
}
lc := cfg.FirstListener()
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
marker := strings.Repeat("=", 32)
mainLog.Load().Debug().Msg(marker)
mainLog.Load().Debug().Msgf("listener address : %s", addr)
@@ -1753,9 +1723,8 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
for _, s := range strings.Split(lastAnswer.String(), "\n") {
mainLog.Load().Debug().Msgf("%s", s)
}
return false, status, errSelfCheckNoAnswer
}
return false, status, lastErr
return errSelfCheckNoAnswer
}
func userHomeDir() (string, error) {
@@ -2293,6 +2262,10 @@ func removeProvTokenFromArgs(sc *service.Config) {
// newSocketControlClient returns new control client after control server was started.
func newSocketControlClient(s service.Service, dir string) *controlClient {
// Return early if service is not running.
if status, err := s.Status(); err != nil || status != service.StatusRunning {
return nil
}
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
bo.LogLongerThan = 10 * time.Second
ctx := context.Background()
@@ -2302,28 +2275,21 @@ func newSocketControlClient(s service.Service, dir string) *controlClient {
defer timeout.Stop()
// The socket control server may not start yet, so attempt to ping
// it until we got a response. For each iteration, check ctrld status
// to make sure ctrld is still running.
// it until we got a response.
for {
curStatus, err := s.Status()
if err != nil {
return nil
}
if curStatus != service.StatusRunning {
return nil
}
if _, err := cc.post("/", nil); err == nil {
_, err := cc.post(startedPath, nil)
if err == nil {
// Server was started, stop pinging.
break
}
// The socket control server is not ready yet, backoff for waiting it to be ready.
bo.BackOff(ctx, err)
select {
case <-timeout.C:
return nil
default:
}
continue
}
return cc
@@ -2496,11 +2462,6 @@ func powershell(cmd string) ([]byte, error) {
// windowsHasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
func windowsHasLocalDnsServerRunning() bool {
if runtime.GOOS == "windows" {
out, _ := powershell("Get-WindowsFeature -Name DNS")
if !bytes.Contains(bytes.ToLower(out), []byte("installed")) {
return false
}
_, err := powershell("Get-Process -Name DNS")
return err == nil
}
@@ -2535,3 +2496,79 @@ func runInCdMode() bool {
}
return false
}
// goArm returns the GOARM value for the binary.
func goArm() string {
if runtime.GOARCH != "arm" {
return ""
}
if bi, ok := debug.ReadBuildInfo(); ok {
for _, setting := range bi.Settings {
if setting.Key == "GOARM" {
return setting.Value
}
}
}
// Use ARM v5 as a fallback, since it works on all others.
return "5"
}
// upgradeUrl returns the url for downloading new ctrld binary.
func upgradeUrl(baseUrl string) string {
dlPath := fmt.Sprintf("%s-%s/ctrld", runtime.GOOS, runtime.GOARCH)
if armVersion := goArm(); armVersion != "" {
dlPath = fmt.Sprintf("%s-%sv%s/ctrld", runtime.GOOS, runtime.GOARCH, armVersion)
}
dlUrl := fmt.Sprintf("%s/%s", baseUrl, dlPath)
if runtime.GOOS == "windows" {
dlUrl += ".exe"
}
return dlUrl
}
// runningIface returns the value of the iface variable used by ctrld process which is running.
func runningIface(s service.Service) string {
if sockDir, err := socketDir(); err == nil {
if cc := newSocketControlClient(s, sockDir); cc != nil {
resp, err := cc.post(ifacePath, nil)
if err != nil {
return ""
}
defer resp.Body.Close()
if buf, _ := io.ReadAll(resp.Body); len(buf) > 0 {
return string(buf)
}
}
}
return ""
}
// resetDnsNoLog performs resetting DNS with logging disable.
func resetDnsNoLog(p *prog) {
lvl := zerolog.GlobalLevel()
zerolog.SetGlobalLevel(zerolog.Disabled)
p.resetDNS()
zerolog.SetGlobalLevel(lvl)
}
// resetDnsTask returns a task which perform reset DNS operation.
func resetDnsTask(p *prog, s service.Service) task {
status, err := s.Status()
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
isCtrldRunning := status == service.StatusRunning
return task{func() error {
// Always reset DNS first, ensuring DNS setting is in a good state.
// resetDNS must use the "iface" value of current running ctrld
// process to reset what setDNS has done properly.
oldIface := iface
iface = "auto"
if isCtrldRunning {
iface = runningIface(s)
}
if isCtrldInstalled {
resetDnsNoLog(p)
}
iface = oldIface
return nil
}, false}
}

View File

@@ -10,6 +10,8 @@ import (
"sort"
"time"
"github.com/kardianos/service"
dto "github.com/prometheus/client_model/go"
"github.com/Control-D-Inc/ctrld"
@@ -22,6 +24,7 @@ const (
reloadPath = "/reload"
deactivationPath = "/deactivation"
cdPath = "/cd"
ifacePath = "/iface"
)
type controlServer struct {
@@ -179,6 +182,17 @@ func (p *prog) registerControlServerHandler() {
}
w.WriteHeader(http.StatusBadRequest)
}))
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
// p.setDNS is only called when running as a service
if !service.Interactive() {
<-p.csSetDnsDone
if p.csSetDnsOk {
w.Write([]byte(iface))
return
}
}
w.WriteHeader(http.StatusBadRequest)
}))
}
func jsonResponse(next http.Handler) http.Handler {

View File

@@ -210,12 +210,9 @@ func (p *prog) serveDNS(listenerNum string) error {
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
s, errCh := runDNSServer(addr, proto, handler)
defer s.Shutdown()
select {
case err := <-errCh:
return err
case <-time.After(5 * time.Second):
p.started <- struct{}{}
}
p.started <- struct{}{}
select {
case <-p.stopCh:
case <-ctx.Done():
@@ -752,20 +749,19 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha
Handler: handler,
}
waitLock := sync.Mutex{}
waitLock.Lock()
s.NotifyStartedFunc = waitLock.Unlock
startedCh := make(chan struct{})
s.NotifyStartedFunc = func() { sync.OnceFunc(func() { close(startedCh) })() }
errCh := make(chan error)
go func() {
defer close(errCh)
if err := s.ListenAndServe(); err != nil {
waitLock.Unlock()
s.NotifyStartedFunc()
mainLog.Load().Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
errCh <- err
}
}()
waitLock.Lock()
<-startedCh
return s, errCh
}

View File

@@ -105,6 +105,10 @@ func (p *prog) checkDnsLoop() {
for uid := range p.loop {
msg := loopTestMsg(uid)
uc := upstream[uid]
// Skipping upstream which is being marked as down.
if uc == nil {
continue
}
resolver, err := ctrld.NewResolver(uc)
if err != nil {
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)

View File

@@ -35,6 +35,7 @@ var (
nextdns string
cdUpstreamProto string
deactivationPin int64
skipSelfChecks bool
mainLog atomic.Pointer[zerolog.Logger]
consoleWriter zerolog.ConsoleWriter

View File

@@ -41,17 +41,12 @@ func setDNS(iface *net.Interface, nameservers []string) error {
// Configuring the Dns server to forward queries to ctrld instead.
if windowsHasLocalDnsServerRunning() {
file := absHomeDir(forwardersFilename)
if data, _ := os.ReadFile(file); len(data) > 0 {
if err := removeDnsServerForwarders(strings.Split(string(data), ",")); err != nil {
mainLog.Load().Error().Err(err).Msg("could not remove current forwarders settings")
} else {
mainLog.Load().Debug().Msg("removed current forwarders settings.")
}
}
oldForwardersContent, _ := os.ReadFile(file)
if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
}
if err := addDnsServerForwarders(nameservers); err != nil {
oldForwarders := strings.Split(string(oldForwardersContent), ",")
if err := addDnsServerForwarders(nameservers, oldForwarders); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
}
}
@@ -213,14 +208,32 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) {
return ns, nil
}
// addDnsServerForwarders adds given nameservers to DNS server forwarders list.
func addDnsServerForwarders(nameservers []string) error {
for _, ns := range nameservers {
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", ns)
if out, err := powershell(cmd); err != nil {
return fmt.Errorf("%w: %s", err, string(out))
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
// and also removing old forwarders if provided.
func addDnsServerForwarders(nameservers, old []string) error {
newForwardersMap := make(map[string]struct{})
newForwarders := make([]string, len(nameservers))
for i := range nameservers {
newForwardersMap[nameservers[i]] = struct{}{}
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
}
oldForwarders := old[:0]
for _, fwd := range old {
if _, ok := newForwardersMap[fwd]; !ok {
oldForwarders = append(oldForwarders, fwd)
}
}
// NOTE: It is important to add new forwarder before removing old one.
// Testing on Windows Server 2022 shows that removing forwarder1
// then adding forwarder2 sometimes ends up adding both of them
// to the forwarders list.
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
if len(oldForwarders) > 0 {
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
}
if out, err := powershell(cmd); err != nil {
return fmt.Errorf("%w: %s", err, string(out))
}
return nil
}

View File

@@ -69,6 +69,8 @@ type prog struct {
reloadDoneCh chan struct{}
logConn net.Conn
cs *controlServer
csSetDnsDone chan struct{}
csSetDnsOk bool
cfg *ctrld.Config
localUpstreams []string
@@ -194,9 +196,6 @@ func (p *prog) runWait() {
}
func (p *prog) preRun() {
if !service.Interactive() {
p.setDNS()
}
if runtime.GOOS == "darwin" {
p.onStopped = append(p.onStopped, func() {
if !service.Interactive() {
@@ -206,6 +205,15 @@ func (p *prog) preRun() {
}
}
func (p *prog) postRun() {
if !service.Interactive() {
p.resetDNS()
ns := ctrld.InitializeOsResolver()
mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns)
p.setDNS()
}
}
func (p *prog) setupUpstream(cfg *ctrld.Config) {
localUpstreams := make([]string, 0, len(cfg.Upstream))
ptrNameservers := make([]string, 0, len(cfg.Upstream))
@@ -249,6 +257,14 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
numListeners := len(p.cfg.Listener)
if !reload {
p.started = make(chan struct{}, numListeners)
if p.cs != nil {
p.csSetDnsDone = make(chan struct{}, 1)
p.registerControlServerHandler()
if err := p.cs.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not start control server")
}
mainLog.Load().Debug().Msgf("control server started: %s", p.cs.addr)
}
}
p.onStartedDone = make(chan struct{})
p.loop = make(map[string]bool)
@@ -381,12 +397,7 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
if p.logConn != nil {
_ = p.logConn.Close()
}
if p.cs != nil {
p.registerControlServerHandler()
if err := p.cs.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not start control server")
}
}
p.postRun()
}
wg.Wait()
}
@@ -430,17 +441,25 @@ func (p *prog) deAllocateIP() error {
}
func (p *prog) setDNS() {
setDnsOK := false
defer func() {
p.csSetDnsOk = setDnsOK
p.csSetDnsDone <- struct{}{}
close(p.csSetDnsDone)
}()
if cfg.Listener == nil {
return
}
if iface == "" {
return
}
runningIface := iface
// allIfaces tracks whether we should set DNS for all physical interfaces.
allIfaces := false
if iface == "auto" {
iface = defaultIfaceName()
// If iface is "auto", it means user does not specify "--iface" flag.
if runningIface == "auto" {
runningIface = defaultIfaceName()
// If runningIface is "auto", it means user does not specify "--iface" flag.
// In this case, ctrld has to set DNS for all physical interfaces, so
// thing will still work when user switch from one to the other.
allIfaces = requiredMultiNICsConfig()
@@ -449,8 +468,8 @@ func (p *prog) setDNS() {
if lc == nil {
return
}
logger := mainLog.Load().With().Str("iface", iface).Logger()
netIface, err := netInterface(iface)
logger := mainLog.Load().With().Str("iface", runningIface).Logger()
netIface, err := netInterface(runningIface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
return
@@ -484,6 +503,7 @@ func (p *prog) setDNS() {
logger.Error().Err(err).Msgf("could not set DNS for interface")
return
}
setDnsOK = true
logger.Debug().Msg("setting DNS successfully")
if shouldWatchResolvconf() {
servers := make([]netip.Addr, len(nameservers))
@@ -503,14 +523,15 @@ func (p *prog) resetDNS() {
if iface == "" {
return
}
runningIface := iface
allIfaces := false
if iface == "auto" {
iface = defaultIfaceName()
if runningIface == "auto" {
runningIface = defaultIfaceName()
// See corresponding comments in (*prog).setDNS function.
allIfaces = requiredMultiNICsConfig()
}
logger := mainLog.Load().With().Str("iface", iface).Logger()
netIface, err := netInterface(iface)
logger := mainLog.Load().With().Str("iface", runningIface).Logger()
netIface, err := netInterface(runningIface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
return

View File

@@ -1,7 +1,12 @@
package cli
import (
"bufio"
"bytes"
"io"
"os"
"os/exec"
"strings"
"github.com/kardianos/service"
@@ -24,12 +29,34 @@ func setDependencies(svc *service.Config) {
"After=network-online.target",
"Wants=NetworkManager-wait-online.service",
"After=NetworkManager-wait-online.service",
"Wants=systemd-networkd-wait-online.service",
"Wants=nss-lookup.target",
"After=nss-lookup.target",
}
if out, _ := exec.Command("networkctl", "--no-pager").CombinedOutput(); len(out) > 0 {
if wantsSystemDNetworkdWaitOnline(bytes.NewReader(out)) {
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
}
}
}
func setWorkingDirectory(svc *service.Config, dir string) {
svc.WorkingDirectory = dir
}
// wantsSystemDNetworkdWaitOnline reports whether "systemd-networkd-wait-online" service
// is required to be added to ctrld dependencies services.
// The input reader r is the output of "networkctl --no-pager" command.
func wantsSystemDNetworkdWaitOnline(r io.Reader) bool {
scanner := bufio.NewScanner(r)
// Skip header
scanner.Scan()
configured := false
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if len(fields) > 0 && fields[len(fields)-1] == "configured" {
configured = true
break
}
}
return configured
}

View File

@@ -0,0 +1,48 @@
package cli
import (
"io"
"strings"
"testing"
)
const (
networkctlUnmanagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
1 lo loopback carrier unmanaged
2 wlp0s20f3 wlan routable unmanaged
3 tailscale0 none routable unmanaged
4 br-9ac33145e060 bridge no-carrier unmanaged
5 docker0 bridge no-carrier unmanaged
5 links listed.
`
networkctlManagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
1 lo loopback carrier unmanaged
2 wlp0s20f3 wlan routable configured
3 tailscale0 none routable unmanaged
4 br-9ac33145e060 bridge no-carrier unmanaged
5 docker0 bridge no-carrier unmanaged
5 links listed.
`
)
func Test_wantsSystemDNetworkdWaitOnline(t *testing.T) {
tests := []struct {
name string
r io.Reader
required bool
}{
{"unmanaged", strings.NewReader(networkctlUnmanagedOutput), false},
{"managed", strings.NewReader(networkctlManagedOutput), true},
{"empty", strings.NewReader(""), false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
if required := wantsSystemDNetworkdWaitOnline(tc.r); required != tc.required {
t.Errorf("wants %v got %v", tc.required, required)
}
})
}
}

View File

@@ -21,7 +21,7 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
}
switch {
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
return &procd{&sysV{s}}, nil
return &procd{sysV: &sysV{s}, svcConfig: c}, nil
case router.IsGLiNet():
return &sysV{s}, nil
case s.Platform() == "unix-systemv":
@@ -89,18 +89,24 @@ func (s *sysV) Status() (service.Status, error) {
// like old GL.iNET Opal router.
type procd struct {
*sysV
svcConfig *service.Config
}
func (s *procd) Status() (service.Status, error) {
if !s.installed() {
return service.StatusUnknown, service.ErrNotInstalled
}
exe, err := os.Executable()
if err != nil {
return service.StatusUnknown, nil
bin := s.svcConfig.Executable
if bin == "" {
exe, err := os.Executable()
if err != nil {
return service.StatusUnknown, nil
}
bin = exe
}
// Looking for something like "/sbin/ctrld run ".
shellCmd := fmt.Sprintf("ps | grep -q %q", exe+" [r]un ")
shellCmd := fmt.Sprintf("ps | grep -q %q", bin+" [r]un ")
if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil {
return service.StatusStopped, nil
}

View File

@@ -336,7 +336,7 @@ The protocol that `ctrld` will use to send DNS requests to upstream.
- Type: string
- Required: yes
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os`
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`
### ip_stack
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.

BIN
docs/ctrldsplash.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 458 KiB

View File

@@ -18,6 +18,7 @@ start_service() {
procd_set_param stdout 1 # forward stdout of the command to logd
procd_set_param stderr 1 # same for stderr
procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop
procd_set_param term_timeout 10
procd_close_instance
echo "${name} has been started"
}

View File

@@ -98,6 +98,11 @@ func IsOldOpenwrt() bool {
return cmd == ""
}
// WaitProcessExited reports whether the "ctrld stop" command have to wait until ctrld process exited.
func WaitProcessExited() bool {
return Name() == openwrt.Name
}
var routerPlatform atomic.Pointer[router]
type router struct {

View File

@@ -49,11 +49,15 @@ func (s *merlinSvc) Platform() string {
}
func (s *merlinSvc) configPath() string {
path, err := os.Executable()
if err != nil {
return ""
bin := s.Config.Executable
if bin == "" {
path, err := os.Executable()
if err != nil {
return ""
}
bin = path
}
return path + ".startup"
return bin + ".startup"
}
func (s *merlinSvc) template() *template.Template {

View File

@@ -1,12 +1,9 @@
package ctrld
import (
"net"
"syscall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"golang.org/x/sys/windows"
)
func dnsFns() []dnsFn {
@@ -20,40 +17,23 @@ func dnsFromAdapter() []string {
}
ns := make([]string, 0, len(aas)*2)
seen := make(map[string]bool)
do := func(addr windows.SocketAddress) {
sa, err := addr.Sockaddr.Sockaddr()
if err != nil {
return
addressMap := make(map[string]struct{})
for _, aa := range aas {
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
addressMap[a.Address.IP().String()] = struct{}{}
}
var ip net.IP
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
case *syscall.SockaddrInet6:
ip = make(net.IP, net.IPv6len)
copy(ip, sa.Addr[:])
if ip[0] == 0xfe && ip[1] == 0xc0 {
// Ignore these fec0/10 ones. Windows seems to
// populate them as defaults on its misc rando
// interfaces.
return
}
default:
return
}
if ip.IsLoopback() || seen[ip.String()] {
return
}
seen[ip.String()] = true
ns = append(ns, ip.String())
}
for _, aa := range aas {
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
do(dns.Address)
}
for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next {
do(gw.Address)
ip := dns.Address.IP()
if ip == nil || ip.IsLoopback() || seen[ip.String()] {
continue
}
if _, ok := addressMap[ip.String()]; ok {
continue
}
seen[ip.String()] = true
ns = append(ns, ip.String())
}
}
return ns

View File

@@ -35,13 +35,26 @@ const bootstrapDNS = "76.76.2.22"
// or is the Resolver used for ResolverTypeOS.
var or = &osResolver{nameservers: defaultNameservers()}
// defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver.
// defaultNameservers returns nameservers used by the OS.
// If no nameservers can be found, ctrld bootstrap nameserver will be used.
func defaultNameservers() []string {
ns := nameservers()
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
if len(ns) == 0 {
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
}
return ns
}
// InitializeOsResolver initializes OS resolver using the current system DNS settings.
// It returns the nameservers that is going to be used by the OS resolver.
//
// It's the caller's responsibility to ensure the system DNS is in a clean state before
// calling this function.
func InitializeOsResolver() []string {
or.nameservers = defaultNameservers()
return or.nameservers
}
// Resolver is the interface that wraps the basic DNS operations.
//
// Resolve resolves the DNS query, return the result and the corresponding error.