diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index b48813d..173d1f7 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -103,6 +103,10 @@ func initCLI() { if err := v.Unmarshal(&cfg); err != nil { log.Fatalf("failed to unmarshal config: %v", err) } + // Wait for network up. + if !netUp() { + log.Fatal("network is not up yet") + } processCDFlags() if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil { log.Fatalf("invalid config: %v", err) diff --git a/cmd/ctrld/net.go b/cmd/ctrld/net.go index f98feb9..96450c0 100644 --- a/cmd/ctrld/net.go +++ b/cmd/ctrld/net.go @@ -1,20 +1,41 @@ package main import ( + "context" + "fmt" "net" "sync" + "time" + + "tailscale.com/logtail/backoff" ) -const controldIPv6Test = "ipv6.controld.io" +const ( + controldIPv6Test = "ipv6.controld.io" + controldIPv4Test = "ipv4.controld.io" +) var ( stackOnce sync.Once ipv6Enabled bool canListenIPv6Local bool + hasNetworkUp bool ) func probeStack() { - if _, err := net.Dial("tcp6", controldIPv6Test); err == nil { + logf := func(format string, args ...any) { + fmt.Printf(format, args...) + } + b := backoff.NewBackoff("probeStack", logf, time.Minute) + for { + if _, err := net.Dial("tcp", net.JoinHostPort(controldIPv4Test, "80")); err == nil { + hasNetworkUp = true + break + } else { + b.BackOff(context.Background(), err) + } + } + if _, err := net.Dial("tcp6", net.JoinHostPort(controldIPv6Test, "80")); err == nil { ipv6Enabled = true } if ln, err := net.Listen("tcp6", "[::1]:53"); err == nil { @@ -23,6 +44,11 @@ func probeStack() { } } +func netUp() bool { + stackOnce.Do(probeStack) + return hasNetworkUp +} + func supportsIPv6() bool { stackOnce.Do(probeStack) return ipv6Enabled diff --git a/internal/controld/config.go b/internal/controld/config.go index 3ca1c5b..51a1353 100644 --- a/internal/controld/config.go +++ b/internal/controld/config.go @@ -2,8 +2,10 @@ package controld import ( "bytes" + "context" "encoding/json" "fmt" + "net" "net/http" "time" ) @@ -13,6 +15,20 @@ const ( InvalidConfigCode = 40401 ) +const bootstrapDNS = "76.76.2.0:53" + +var dialer = &net.Dialer{ + Resolver: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{ + Timeout: 10 * time.Second, + } + return d.DialContext(ctx, "udp", bootstrapDNS) + }, + }, +} + // ResolverConfig represents Control D resolver data. type ResolverConfig struct { DOH string `json:"doh"` @@ -52,7 +68,14 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) { q.Set("platform", "ctrld") req.URL.RawQuery = q.Encode() req.Header.Add("Content-Type", "application/json") - client := http.Client{Timeout: 5 * time.Second} + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, addr) + } + client := http.Client{ + Timeout: 10 * time.Second, + Transport: transport, + } resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("client.Do: %w", err)