diff --git a/cmd/cli/commands.go b/cmd/cli/commands.go index 70d4467..8396c19 100644 --- a/cmd/cli/commands.go +++ b/cmd/cli/commands.go @@ -561,6 +561,17 @@ func initStopCmd() *cobra.Command { } initLogging() + + status, err := s.Status() + if errors.Is(err, service.ErrNotInstalled) { + mainLog.Load().Warn().Msg("service not installed") + return + } + if status == service.StatusStopped { + mainLog.Load().Warn().Msg("service is already stopped") + return + } + if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) { os.Exit(deactivationPinInvalidExitCode) } diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 8390680..fd49764 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -578,7 +578,9 @@ func (p *prog) metricsEnabled() bool { func (p *prog) Stop(s service.Service) error { p.stopDnsWatchers() mainLog.Load().Debug().Msg("dns watchers stopped") - mainLog.Load().Info().Msg("Service stopped") + defer func() { + mainLog.Load().Info().Msg("Service stopped") + }() close(p.stopCh) if err := p.deAllocateIP(); err != nil { mainLog.Load().Error().Err(err).Msg("de-allocate ip failed") diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index 6e3bd82..c4df5a5 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -68,9 +68,9 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error { // Then proceed with existing actions, e.g. setting failure actions actions := []mgr.RecoveryAction{ - {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds - {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds - {Type: mgr.ServiceRestart, Delay: time.Second * 2}, // 2 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds + {Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds } // Set the recovery actions (3 restarts, reset period = 120). diff --git a/resolver.go b/resolver.go index 49b81af..c26560b 100644 --- a/resolver.go +++ b/resolver.go @@ -8,7 +8,6 @@ import ( "net" "net/netip" "slices" - "strings" "sync" "sync/atomic" "time" @@ -212,6 +211,11 @@ type osResolverResult struct { lan bool } +type publicResponse struct { + answer *dns.Msg + server string +} + // Resolve resolves DNS queries using pre-configured nameservers. // Query is sent to all nameservers concurrently, and the first // success response will be returned. @@ -257,33 +261,37 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } logAnswer := func(server string) { - if before, _, found := strings.Cut(server, ":"); found { - server = before + host, _, err := net.SplitHostPort(server) + if err != nil { + // If splitting fails, fallback to the original server string + host = server } - Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", server) + Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host) } var ( nonSuccessAnswer *dns.Msg nonSuccessServer string controldSuccessAnswer *dns.Msg - publicServerAnswer *dns.Msg - publicServer string + publicResponses []publicResponse ) errs := make([]error, 0, numServers) for res := range ch { switch { case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess: switch { - case res.server == controldPublicDnsWithPort: - controldSuccessAnswer = res.answer - case !res.lan && publicServerAnswer == nil: - publicServerAnswer = res.answer - publicServer = res.server - default: - Log(ctx, ProxyLogger.Load().Debug(), "got LAN answer from: %s", res.server) + case res.lan: + // Always prefer LAN responses immediately + Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server) cancel() logAnswer(res.server) return res.answer, nil + case res.server == controldPublicDnsWithPort: + controldSuccessAnswer = res.answer + case !res.lan: + publicResponses = append(publicResponses, publicResponse{ + answer: res.answer, + server: res.server, + }) } case res.answer != nil: nonSuccessAnswer = res.answer @@ -293,10 +301,12 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error } errs = append(errs, res.err) } - if publicServerAnswer != nil { - Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", publicServer) - logAnswer(publicServer) - return publicServerAnswer, nil + + if len(publicResponses) > 0 { + resp := publicResponses[0] + Log(ctx, ProxyLogger.Load().Debug(), "got public answer from: %s", resp.server) + logAnswer(resp.server) + return resp.answer, nil } if controldSuccessAnswer != nil { Log(ctx, ProxyLogger.Load().Debug(), "got ControlD answer from: %s", controldPublicDnsWithPort)