diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 22497a7..1ba562e 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "encoding/base64" "errors" "fmt" @@ -14,12 +15,15 @@ import ( "runtime" "strconv" "strings" + "time" "github.com/go-playground/validator/v10" "github.com/kardianos/service" + "github.com/miekg/dns" "github.com/pelletier/go-toml/v2" "github.com/spf13/cobra" "github.com/spf13/viper" + "tailscale.com/logtail/backoff" "tailscale.com/net/interfaces" "github.com/Control-D-Inc/ctrld" @@ -27,6 +31,8 @@ import ( ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) +const selfCheckFQDN = "verify.controld.com" + var ( v = viper.NewWithOptions(viper.KeyDelimiter("::")) defaultConfigWritten = false @@ -234,8 +240,24 @@ func initCLI() { {s.Start, true}, } if doTasks(tasks) { + status, err := s.Status() + if err != nil { + mainLog.Warn().Err(err).Msg("could not get service status") + return + } + + status = selfCheckStatus(status) + switch status { + case service.StatusRunning: + mainLog.Info().Msg("Service started") + default: + mainLog.Error().Msg("Service did not start, please check system/service log for details error") + if runtime.GOOS == "linux" { + prog.resetDNS() + } + os.Exit(1) + } prog.setDNS() - mainLog.Info().Msg("Service started") } }, } @@ -549,7 +571,7 @@ func processCDFlags() { iface = "auto" } logger := mainLog.With().Str("mode", "cd").Logger() - logger.Info().Msg("fetching Controld-D configuration") + logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) resolverConfig, err := controld.FetchResolverConfig(cdUID) if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { s, err := service.New(&prog{}, svcConfig) @@ -681,3 +703,27 @@ func defaultIfaceName() string { } return dri } + +func selfCheckStatus(status service.Status) service.Status { + c := new(dns.Client) + lc := cfg.Listener["0"] + bo := backoff.NewBackoff("self-check", logf, 10*time.Second) + bo.LogLongerThan = 500 * time.Millisecond + ctx := context.Background() + err := errors.New("query failed") + maxAttempts := 10 + mainLog.Debug().Msg("Performing self-check") + for i := 0; i < maxAttempts; i++ { + m := new(dns.Msg) + m.SetQuestion(selfCheckFQDN+".", dns.TypeA) + m.RecursionDesired = true + r, _, _ := c.ExchangeContext(ctx, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))) + if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 { + mainLog.Debug().Msgf("self-check against %q succeeded", selfCheckFQDN) + return status + } + bo.BackOff(ctx, err) + } + mainLog.Debug().Msgf("self-check against %q failed", selfCheckFQDN) + return service.StatusUnknown +} diff --git a/cmd/ctrld/os_linux.go b/cmd/ctrld/os_linux.go index 970de78..a29c7f4 100644 --- a/cmd/ctrld/os_linux.go +++ b/cmd/ctrld/os_linux.go @@ -23,10 +23,6 @@ import ( "github.com/Control-D-Inc/ctrld/internal/resolvconffile" ) -var logf = func(format string, args ...any) { - mainLog.Debug().Msgf(format, args...) -} - // allocate loopback ip // sudo ip a add 127.0.0.2/24 dev lo func allocateIP(ip string) error { diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 2046be4..e3d6239 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -16,6 +16,10 @@ import ( ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ) +var logf = func(format string, args ...any) { + mainLog.Debug().Msgf(format, args...) +} + var errWindowsAddrInUse = syscall.Errno(0x2740) var svcConfig = &service.Config{