package ctrld import ( "context" "crypto/tls" "errors" "net" "net/http" "runtime" "sync" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" ) func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper { if uc.Type != ResolverTypeDOH3 { return nil } rt := &http3.Transport{} rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool} logger := LoggerFromCtx(ctx) rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) { _, port, _ := net.SplitHostPort(addr) // if we have a bootstrap ip set, use it to avoid DNS lookup if uc.BootstrapIP != "" { addr = net.JoinHostPort(uc.BootstrapIP, port) Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr) udpConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, err } remoteAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } return quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg) } dialAddrs := make([]string, len(addrs)) for i := range addrs { dialAddrs[i] = net.JoinHostPort(addrs[i], port) } pd := &quicParallelDialer{} conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg) if err != nil { return nil, err } Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr()) return conn, err } runtime.SetFinalizer(rt, func(rt *http3.Transport) { rt.CloseIdleConnections() }) return rt } func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper { uc.ensureSetupTransport(ctx) return transportByIpStack(uc.IPStack, dnsType, uc.http3RoundTripper, uc.http3RoundTripper4, uc.http3RoundTripper6) } func (uc *UpstreamConfig) doqTransport(ctx context.Context, dnsType uint16) *doqConnPool { uc.ensureSetupTransport(ctx) return transportByIpStack(uc.IPStack, dnsType, uc.doqConnPool, uc.doqConnPool4, uc.doqConnPool6) } func (uc *UpstreamConfig) dotTransport(ctx context.Context, dnsType uint16) *dotConnPool { uc.ensureSetupTransport(ctx) return transportByIpStack(uc.IPStack, dnsType, uc.dotClientPool, uc.dotClientPool4, uc.dotClientPool6) } // Putting the code for quic parallel dialer here: // // - quic dialer is different with net.Dialer // - simplification for quic free version type parallelDialerResult struct { conn *quic.Conn err error } type quicParallelDialer struct{} // Dial performs parallel dialing to the given address list. func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) { if len(addrs) == 0 { return nil, errors.New("empty addresses") } ctx, cancel := context.WithCancel(ctx) defer cancel() done := make(chan struct{}) defer close(done) ch := make(chan *parallelDialerResult, len(addrs)) var wg sync.WaitGroup wg.Add(len(addrs)) go func() { wg.Wait() close(ch) }() for _, addr := range addrs { go func(addr string) { defer wg.Done() remoteAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { ch <- ¶llelDialerResult{conn: nil, err: err} return } udpConn, err := net.ListenUDP("udp", nil) if err != nil { ch <- ¶llelDialerResult{conn: nil, err: err} return } conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg) select { case ch <- ¶llelDialerResult{conn: conn, err: err}: case <-done: if conn != nil { conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "") } if udpConn != nil { udpConn.Close() } } }(addr) } errs := make([]error, 0, len(addrs)) for res := range ch { if res.err == nil { cancel() return res.conn, res.err } errs = append(errs, res.err) } return nil, errors.Join(errs...) } func (uc *UpstreamConfig) newDOQConnPool(ctx context.Context, addrs []string) *doqConnPool { if uc.Type != ResolverTypeDOQ { return nil } return newDOQConnPool(ctx, uc, addrs) } func (uc *UpstreamConfig) newDOTClientPool(ctx context.Context, addrs []string) *dotConnPool { if uc.Type != ResolverTypeDOT { return nil } return newDOTClientPool(ctx, uc, addrs) }