diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 5ec372c..2f17f82 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -91,7 +91,7 @@ func (p *prog) run() { mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap IP for upstream.%s", n) } uc.SetCertPool(rootCertPool) - uc.SetupTransport() + go uc.Ping() } go p.watchLinkState() diff --git a/config.go b/config.go index 6fa54e4..ac6904b 100644 --- a/config.go +++ b/config.go @@ -5,13 +5,16 @@ import ( "crypto/tls" "crypto/x509" "errors" + "io" "math/rand" "net" "net/http" "net/url" "os" + "runtime" "strings" "sync" + "sync/atomic" "time" "github.com/go-playground/validator/v10" @@ -154,11 +157,12 @@ type UpstreamConfig struct { SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"` g singleflight.Group - mu sync.Mutex + rebootstrap atomic.Bool bootstrapIPs []string bootstrapIPs4 []string bootstrapIPs6 []string transport *http.Transport + transportOnce sync.Once transport4 *http.Transport transport6 *http.Transport http3RoundTripper http.RoundTripper @@ -306,20 +310,11 @@ func (uc *UpstreamConfig) ReBootstrap() { } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { ProxyLog.Debug().Msg("re-bootstrapping upstream ip") - uc.setupTransportWithoutPingUpstream() + uc.rebootstrap.Store(true) return true, nil }) } -func (uc *UpstreamConfig) setupTransportWithoutPingUpstream() { - switch uc.Type { - case ResolverTypeDOH: - uc.setupDOHTransportWithoutPingUpstream() - case ResolverTypeDOH3: - uc.setupDOH3TransportWithoutPingUpstream() - } -} - // SetupTransport initializes the network transport used to connect to upstream server. // For now, only DoH upstream is supported. func (uc *UpstreamConfig) SetupTransport() { @@ -332,14 +327,31 @@ func (uc *UpstreamConfig) SetupTransport() { } func (uc *UpstreamConfig) setupDOHTransport() { - uc.setupDOHTransportWithoutPingUpstream() - go uc.pingUpstream() + switch uc.IPStack { + case IpStackBoth, "": + uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + case IpStackV4: + uc.transport = uc.newDOHTransport(uc.bootstrapIPs4) + case IpStackV6: + uc.transport = uc.newDOHTransport(uc.bootstrapIPs6) + case IpStackSplit: + uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4) + if hasIPv6() { + uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6) + } else { + uc.transport6 = uc.transport4 + } + uc.transport = uc.newDOHTransport(uc.bootstrapIPs) + } } func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { transport := http.DefaultTransport.(*http.Transport).Clone() - transport.IdleConnTimeout = 5 * time.Second - transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} + transport.MaxIdleConnsPerHost = 100 + transport.TLSClientConfig = &tls.Config{ + RootCAs: uc.certPool, + ClientSessionCache: tls.NewLRUClientSessionCache(0), + } dialerTimeoutMs := 2000 if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs { @@ -368,44 +380,39 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport { Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", conn.RemoteAddr()) return conn, nil } + runtime.SetFinalizer(transport, func(transport *http.Transport) { + transport.CloseIdleConnections() + }) return transport } -func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() { - uc.mu.Lock() - defer uc.mu.Unlock() - switch uc.IPStack { - case IpStackBoth, "": - uc.transport = uc.newDOHTransport(uc.bootstrapIPs) - case IpStackV4: - uc.transport = uc.newDOHTransport(uc.bootstrapIPs4) - case IpStackV6: - uc.transport = uc.newDOHTransport(uc.bootstrapIPs6) - case IpStackSplit: - uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4) - if hasIPv6() { - uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6) - } else { - uc.transport6 = uc.transport4 - } - - uc.transport = uc.newDOHTransport(uc.bootstrapIPs) - } -} - -func (uc *UpstreamConfig) pingUpstream() { - // Warming up the transport by querying a test packet. - dnsResolver, err := NewResolver(uc) - if err != nil { - ProxyLog.Error().Err(err).Msgf("failed to create resolver for upstream: %s", uc.Name) +// Ping warms up the connection to DoH/DoH3 upstream. +func (uc *UpstreamConfig) Ping() { + switch uc.Type { + case ResolverTypeDOH, ResolverTypeDOH3: + default: return } - msg := new(dns.Msg) - msg.SetQuestion(".", dns.TypeNS) - msg.MsgHdr.RecursionDesired = true - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _, _ = dnsResolver.Resolve(ctx, msg) + + ping := func(t http.RoundTripper) { + if t == nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + req, _ := http.NewRequestWithContext(ctx, "HEAD", uc.Endpoint, nil) + resp, _ := t.RoundTrip(req) + if resp == nil { + return + } + defer resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + } + + for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} { + ping(uc.dohTransport(typ)) + ping(uc.doh3Transport(typ)) + } } func (uc *UpstreamConfig) isControlD() bool { @@ -424,8 +431,12 @@ func (uc *UpstreamConfig) isControlD() bool { } func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper { - uc.mu.Lock() - defer uc.mu.Unlock() + uc.transportOnce.Do(func() { + uc.SetupTransport() + }) + if uc.rebootstrap.CompareAndSwap(true, false) { + uc.SetupTransport() + } switch uc.IPStack { case IpStackBoth, IpStackV4, IpStackV6: return uc.transport diff --git a/config_quic.go b/config_quic.go index 085476e..32d338e 100644 --- a/config_quic.go +++ b/config_quic.go @@ -19,8 +19,24 @@ import ( ) func (uc *UpstreamConfig) setupDOH3Transport() { - uc.setupDOH3TransportWithoutPingUpstream() - go uc.pingUpstream() + switch uc.IPStack { + case IpStackBoth, "": + uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + case IpStackV4: + uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4) + case IpStackV6: + uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) + case IpStackSplit: + uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if ctrldnet.IPv6Available(ctx) { + uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) + } else { + uc.http3RoundTripper6 = uc.http3RoundTripper4 + } + uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) + } } func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { @@ -58,32 +74,13 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper { return rt } -func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() { - uc.mu.Lock() - defer uc.mu.Unlock() - switch uc.IPStack { - case IpStackBoth, "": - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) - case IpStackV4: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4) - case IpStackV6: - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6) - case IpStackSplit: - uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if ctrldnet.IPv6Available(ctx) { - uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6) - } else { - uc.http3RoundTripper6 = uc.http3RoundTripper4 - } - uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs) - } -} - func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { - uc.mu.Lock() - defer uc.mu.Unlock() + uc.transportOnce.Do(func() { + uc.SetupTransport() + }) + if uc.rebootstrap.CompareAndSwap(true, false) { + uc.SetupTransport() + } switch uc.IPStack { case IpStackBoth, IpStackV4, IpStackV6: return uc.http3RoundTripper diff --git a/config_quic_free.go b/config_quic_free.go index a4b1bdd..a674a1b 100644 --- a/config_quic_free.go +++ b/config_quic_free.go @@ -6,5 +6,4 @@ import "net/http" func (uc *UpstreamConfig) setupDOH3Transport() {} -func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {} func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { return nil }