diff --git a/config.go b/config.go index 1d2fd20..95d05ec 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,7 @@ package ctrld import ( "context" + "crypto/tls" "net" "net/http" "net/url" @@ -10,6 +11,8 @@ import ( "github.com/Control-D-Inc/ctrld/internal/dnsrcode" "github.com/go-playground/validator/v10" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" "github.com/miekg/dns" "github.com/spf13/viper" ) @@ -79,13 +82,14 @@ type NetworkConfig struct { // UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to. type UpstreamConfig struct { - Name string `mapstructure:"name" toml:"name"` - Type string `mapstructure:"type" toml:"type" validate:"oneof=doh doh3 dot doq os legacy"` - Endpoint string `mapstructure:"endpoint" toml:"endpoint" validate:"required_unless=Type os"` - BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip"` - Domain string `mapstructure:"-" toml:"-"` - Timeout int `mapstructure:"timeout" toml:"timeout" validate:"gte=0"` - transport *http.Transport `mapstructure:"-" toml:"-"` + Name string `mapstructure:"name" toml:"name"` + Type string `mapstructure:"type" toml:"type" validate:"oneof=doh doh3 dot doq os legacy"` + Endpoint string `mapstructure:"endpoint" toml:"endpoint" validate:"required_unless=Type os"` + BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip"` + Domain string `mapstructure:"-" toml:"-"` + Timeout int `mapstructure:"timeout" toml:"timeout" validate:"gte=0"` + transport *http.Transport `mapstructure:"-" toml:"-"` + http3RoundTripper *http3.RoundTripper `mapstructure:"-" toml:"-"` } // ListenerConfig specifies the networks configuration that ctrld will run on. @@ -133,9 +137,15 @@ func (uc *UpstreamConfig) Init() { // SetupTransport initializes the network transport used to connect to upstream server. // For now, only DoH upstream is supported. func (uc *UpstreamConfig) SetupTransport() { - if uc.Type != resolverTypeDOH { - return + switch uc.Type { + case resolverTypeDOH: + uc.setupDOHTransport() + case resolverTypeDOH3: + uc.setupDOH3Transport() } +} + +func (uc *UpstreamConfig) setupDOHTransport() { uc.transport = http.DefaultTransport.(*http.Transport).Clone() uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { dialer := &net.Dialer{ @@ -153,6 +163,40 @@ func (uc *UpstreamConfig) SetupTransport() { return dialer.DialContext(ctx, network, addr) } + uc.pingUpstream() +} + +func (uc *UpstreamConfig) setupDOH3Transport() { + uc.http3RoundTripper = &http3.RoundTripper{} + uc.http3RoundTripper.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + host := addr + ProxyLog.Debug().Msgf("debug dial context D0H3 %s - %s", addr, bootstrapDNS) + // if we have a bootstrap ip set, use it to avoid DNS lookup + if uc.BootstrapIP != "" { + if _, port, _ := net.SplitHostPort(addr); port != "" { + addr = net.JoinHostPort(uc.BootstrapIP, port) + } + ProxyLog.Debug().Msgf("sending doh3 request to: %s", addr) + } + remoteAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + localAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0} + if strings.Index(uc.BootstrapIP, ":") > -1 { + localAddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0} + } + udpConn, err := net.ListenUDP("udp", localAddr) + if err != nil { + return nil, err + } + return quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg) + } + + uc.pingUpstream() +} + +func (uc *UpstreamConfig) pingUpstream() { // Warming up the transport by querying a test packet. dnsResolver, err := NewResolver(uc) if err != nil { diff --git a/doh.go b/doh.go index 78adbfa..511e28c 100644 --- a/doh.go +++ b/doh.go @@ -2,52 +2,30 @@ package ctrld import ( "context" - "crypto/tls" "encoding/base64" "fmt" "io" - "net" "net/http" - "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/http3" "github.com/miekg/dns" ) func newDohResolver(uc *UpstreamConfig) *dohResolver { r := &dohResolver{ - endpoint: uc.Endpoint, - isDoH3: uc.Type == resolverTypeDOH3, - transport: uc.transport, - } - if r.isDoH3 { - r.doh3DialFunc = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - host := addr - Log(ctx, ProxyLog.Debug(), "debug dial context D0H3 %s - %s", addr, bootstrapDNS) - // if we have a bootstrap ip set, use it to avoid DNS lookup - if uc.BootstrapIP != "" && addr == fmt.Sprintf("%s:443", uc.Domain) { - addr = fmt.Sprintf("%s:443", uc.BootstrapIP) - Log(ctx, ProxyLog.Debug(), "sending doh3 request to: %s", addr) - } - remoteAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - return quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg) - } + endpoint: uc.Endpoint, + isDoH3: uc.Type == resolverTypeDOH3, + transport: uc.transport, + http3RoundTripper: uc.http3RoundTripper, } return r } type dohResolver struct { - endpoint string - isDoH3 bool - doh3DialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) - transport *http.Transport + endpoint string + isDoH3 bool + transport *http.Transport + http3RoundTripper *http3.RoundTripper } func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { @@ -66,9 +44,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro c := http.Client{Transport: r.transport} if r.isDoH3 { - c.Transport = &http3.RoundTripper{} - c.Transport.(*http3.RoundTripper).Dial = r.doh3DialFunc - defer c.Transport.(*http3.RoundTripper).Close() + c.Transport = r.http3RoundTripper } resp, err := c.Do(req) if err != nil {