diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 266c880..0bf85f2 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "io" "net" "net/netip" "runtime" @@ -14,11 +15,11 @@ import ( "sync" "time" - "tailscale.com/net/netmon" - "github.com/miekg/dns" "golang.org/x/sync/errgroup" + "tailscale.com/net/captivedetection" "tailscale.com/net/netaddr" + "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" "github.com/Control-D-Inc/ctrld" @@ -494,12 +495,21 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse { answer, err := resolve1(n, upstreamConfig, msg) if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") - if errNetworkError(err) { + isNetworkErr := errNetworkError(err) + if isNetworkErr { p.um.increaseFailureCount(upstreams[n]) if p.um.isDown(upstreams[n]) { go p.um.checkUpstream(upstreams[n], upstreamConfig) } } + if cdUID != "" && (isNetworkErr || err == io.EOF) { + p.captivePortalMu.Lock() + if !p.captivePortalCheckWasRun { + p.captivePortalCheckWasRun = true + go p.performCaptivePortalDetection() + } + p.captivePortalMu.Unlock() + } // For timeout error (i.e: context deadline exceed), force re-bootstrapping. var e net.Error if errors.As(err, &e) && e.Timeout() { @@ -585,6 +595,9 @@ func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstre } func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { + if p.captivePortalDetected.Load() { + return nil // always use OS resolver if behind captive portal. + } upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) for _, upstream := range upstreams { upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix) @@ -888,6 +901,31 @@ func (p *prog) selfUninstallCoolOfPeriod() { p.selfUninstallMu.Unlock() } +// performCaptivePortalDetection check if ctrld is running behind a captive portal. +func (p *prog) performCaptivePortalDetection() { + mainLog.Load().Warn().Msg("Performing captive portal detection") + d := captivedetection.NewDetector(logf) + found := true + var resetDnsOnce sync.Once + for found { + time.Sleep(2 * time.Second) + found = d.Detect(context.Background(), netmon.NewStatic(), nil, 0) + if found { + resetDnsOnce.Do(func() { + mainLog.Load().Warn().Msg("found captive portal, leaking query to OS resolver") + p.resetDNS() + }) + } + p.captivePortalDetected.Store(found) + } + + p.captivePortalMu.Lock() + p.captivePortalCheckWasRun = false + p.captivePortalMu.Unlock() + p.setDNS() + mainLog.Load().Warn().Msg("captive portal login finished, stop leaking query") +} + // queryFromSelf reports whether the input IP is from device running ctrld. func queryFromSelf(ip string) bool { netIP := netip.MustParseAddr(ip) diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index d1ff0b8..711e966 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -105,6 +105,10 @@ type prog struct { loopMu sync.Mutex loop map[string]bool + captivePortalMu sync.Mutex + captivePortalCheckWasRun bool + captivePortalDetected atomic.Bool + started chan struct{} onStartedDone chan struct{} onStarted []func() @@ -240,6 +244,8 @@ func (p *prog) postRun() { ns := ctrld.InitializeOsResolver() mainLog.Load().Debug().Msgf("initialized OS resolver with nameservers: %v", ns) p.setDNS() + p.csSetDnsDone <- struct{}{} + close(p.csSetDnsDone) } } @@ -534,8 +540,6 @@ func (p *prog) setDNS() { setDnsOK := false defer func() { p.csSetDnsOk = setDnsOK - p.csSetDnsDone <- struct{}{} - close(p.csSetDnsDone) }() if cfg.Listener == nil {