Files
ctrld/config_quic.go
2025-10-09 19:12:06 +07:00

162 lines
4.1 KiB
Go

package ctrld
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"runtime"
"sync"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
)
func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) {
switch uc.IPStack {
case IpStackBoth, "":
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
case IpStackV4:
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
case IpStackV6:
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
case IpStackSplit:
uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
if HasIPv6(ctx) {
uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
} else {
uc.http3RoundTripper6 = uc.http3RoundTripper4
}
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
}
}
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
rt := &http3.Transport{}
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
logger := LoggerFromCtx(ctx)
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
_, port, _ := net.SplitHostPort(addr)
// if we have a bootstrap ip set, use it to avoid DNS lookup
if uc.BootstrapIP != "" {
addr = net.JoinHostPort(uc.BootstrapIP, port)
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
}
dialAddrs := make([]string, len(addrs))
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
pd := &quicParallelDialer{}
conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg)
if err != nil {
return nil, err
}
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr())
return conn, err
}
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
rt.CloseIdleConnections()
})
return rt
}
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
uc.transportOnce.Do(func() {
uc.SetupTransport(ctx)
})
if uc.rebootstrap.CompareAndSwap(true, false) {
uc.SetupTransport(ctx)
}
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.http3RoundTripper
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.http3RoundTripper4
default:
return uc.http3RoundTripper6
}
}
return uc.http3RoundTripper
}
// Putting the code for quic parallel dialer here:
//
// - quic dialer is different with net.Dialer
// - simplification for quic free version
type parallelDialerResult struct {
conn *quic.Conn
err error
}
type quicParallelDialer struct{}
// Dial performs parallel dialing to the given address list.
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
if len(addrs) == 0 {
return nil, errors.New("empty addresses")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
done := make(chan struct{})
defer close(done)
ch := make(chan *parallelDialerResult, len(addrs))
var wg sync.WaitGroup
wg.Add(len(addrs))
go func() {
wg.Wait()
close(ch)
}()
for _, addr := range addrs {
go func(addr string) {
defer wg.Done()
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
ch <- &parallelDialerResult{conn: nil, err: err}
return
}
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
ch <- &parallelDialerResult{conn: nil, err: err}
return
}
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
select {
case ch <- &parallelDialerResult{conn: conn, err: err}:
case <-done:
if conn != nil {
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
}
if udpConn != nil {
udpConn.Close()
}
}
}(addr)
}
errs := make([]error, 0, len(addrs))
for res := range ch {
if res.err == nil {
cancel()
return res.conn, res.err
}
errs = append(errs, res.err)
}
return nil, errors.Join(errs...)
}