diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index e67e02c..76af611 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -48,7 +48,6 @@ func (p *prog) run() { for n := range p.cfg.Upstream { uc := p.cfg.Upstream[n] uc.Init() - if uc.BootstrapIP == "" { // resolve it manually and set the bootstrap ip c := new(dns.Client) @@ -71,6 +70,7 @@ func (p *prog) run() { } } } + uc.SetupTransport() } for listenerNum := range p.cfg.Listener { diff --git a/config.go b/config.go index f660006..faf8720 100644 --- a/config.go +++ b/config.go @@ -1,12 +1,17 @@ package ctrld import ( + "context" + "fmt" "net" + "net/http" "net/url" "strings" + "time" "github.com/Control-D-Inc/ctrld/internal/dnsrcode" "github.com/go-playground/validator/v10" + "github.com/miekg/dns" "github.com/spf13/viper" ) @@ -75,12 +80,13 @@ type NetworkConfig struct { // UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to. type UpstreamConfig struct { - Name string `mapstructure:"name" toml:"name"` - Type string `mapstructure:"type" toml:"type" validate:"oneof=doh doh3 dot doq os legacy"` - Endpoint string `mapstructure:"endpoint" toml:"endpoint" validate:"required_unless=Type os"` - BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip"` - Domain string `mapstructure:"-" toml:"-"` - Timeout int `mapstructure:"timeout" toml:"timeout" validate:"gte=0"` + Name string `mapstructure:"name" toml:"name"` + Type string `mapstructure:"type" toml:"type" validate:"oneof=doh doh3 dot doq os legacy"` + Endpoint string `mapstructure:"endpoint" toml:"endpoint" validate:"required_unless=Type os"` + BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip"` + Domain string `mapstructure:"-" toml:"-"` + Timeout int `mapstructure:"timeout" toml:"timeout" validate:"gte=0"` + transport *http.Transport `mapstructure:"-" toml:"-"` } // ListenerConfig specifies the networks configuration that ctrld will run on. @@ -125,6 +131,41 @@ func (uc *UpstreamConfig) Init() { } } +// SetupTransport initializes the network transport used to connect to upstream server. +// For now, only DoH upstream is supported. +func (uc *UpstreamConfig) SetupTransport() { + if uc.Type != resolverTypeDOH { + return + } + uc.transport = http.DefaultTransport.(*http.Transport).Clone() + uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 10 * time.Second, + } + Log(ctx, ProxyLog.Debug(), "debug dial context %s - %s - %s", addr, network, bootstrapDNS) + // if we have a bootstrap ip set, use it to avoid DNS lookup + if uc.BootstrapIP != "" && addr == fmt.Sprintf("%s:443", uc.Domain) { + addr = fmt.Sprintf("%s:443", uc.BootstrapIP) + Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr) + } + return dialer.DialContext(ctx, network, addr) + } + + // Warming up the transport by querying a test packet. + dnsResolver, err := NewResolver(uc) + if err != nil { + ProxyLog.Error().Err(err).Msgf("failed to create resolver for upstream: %s", uc.Name) + return + } + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + msg.MsgHdr.RecursionDesired = true + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _ = dnsResolver.Resolve(ctx, msg) +} + // Init initialized necessary values for an ListenerConfig. func (lc *ListenerConfig) Init() { if lc.Policy != nil { diff --git a/doh.go b/doh.go index f3e3810..78adbfa 100644 --- a/doh.go +++ b/doh.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "time" "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/http3" @@ -16,20 +15,11 @@ import ( ) func newDohResolver(uc *UpstreamConfig) *dohResolver { - http.DefaultTransport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer := &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 10 * time.Second, - } - Log(ctx, ProxyLog.Debug(), "debug dial context %s - %s - %s", addr, network, bootstrapDNS) - // if we have a bootstrap ip set, use it to avoid DNS lookup - if uc.BootstrapIP != "" && addr == fmt.Sprintf("%s:443", uc.Domain) { - addr = fmt.Sprintf("%s:443", uc.BootstrapIP) - Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr) - } - return dialer.DialContext(ctx, network, addr) + r := &dohResolver{ + endpoint: uc.Endpoint, + isDoH3: uc.Type == resolverTypeDOH3, + transport: uc.transport, } - r := &dohResolver{endpoint: uc.Endpoint, isDoH3: uc.Type == resolverTypeDOH3} if r.isDoH3 { r.doh3DialFunc = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { host := addr @@ -57,6 +47,7 @@ type dohResolver struct { endpoint string isDoH3 bool doh3DialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + transport *http.Transport } func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { @@ -73,7 +64,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro req.Header.Set("Content-Type", "application/dns-message") req.Header.Set("Accept", "application/dns-message") - c := http.Client{} + c := http.Client{Transport: r.transport} if r.isDoH3 { c.Transport = &http3.RoundTripper{} c.Transport.(*http3.RoundTripper).Dial = r.doh3DialFunc