Files
ctrld/doh.go
Cuong Manh Le ebcc545547 all: improving DoH query performance
Previously, each DoH query used the net/http default transport with its
DialContext function re-assigned. This had some problems:

 - The first query to server will be slow.
 - Sharing the default transport across all upstreams can cause a race
   condition when multiple queries are sent to multiple DoH upstreams.

This commit fixes those issues by initializing a separate transport for
each DoH upstream, then warming up the transport by doing a test query.
Later queries can take advantage of this and re-use the connection.
2023-01-20 21:32:14 +07:00

91 lines
2.6 KiB
Go

package ctrld
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/miekg/dns"
)
// newDohResolver builds a dohResolver from the given upstream configuration.
//
// The endpoint and transport are copied from uc. When uc.Type is
// resolverTypeDOH3, the resolver additionally gets a QUIC dial function that
// rewrites the dial address "<uc.Domain>:443" to "<uc.BootstrapIP>:443" if a
// bootstrap IP is configured, so the upstream can be reached without a DNS
// lookup.
func newDohResolver(uc *UpstreamConfig) *dohResolver {
	r := &dohResolver{
		endpoint:  uc.Endpoint,
		isDoH3:    uc.Type == resolverTypeDOH3,
		transport: uc.transport,
	}
	if !r.isDoH3 {
		return r
	}

	r.doh3DialFunc = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
		// Remember the address as requested: it is still passed as the host
		// argument of quic.DialEarlyContext even if addr is rewritten below.
		host := addr
		Log(ctx, ProxyLog.Debug(), "debug dial context D0H3 %s - %s", addr, bootstrapDNS)
		// if we have a bootstrap ip set, use it to avoid DNS lookup
		if uc.BootstrapIP != "" && addr == fmt.Sprintf("%s:443", uc.Domain) {
			addr = fmt.Sprintf("%s:443", uc.BootstrapIP)
			Log(ctx, ProxyLog.Debug(), "sending doh3 request to: %s", addr)
		}
		remoteAddr, err := net.ResolveUDPAddr("udp", addr)
		if err != nil {
			return nil, err
		}
		// NOTE(review): the local socket is bound to the IPv4 wildcard
		// address; confirm whether IPv6 upstreams need to be supported here.
		udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
		if err != nil {
			return nil, err
		}
		conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg)
		if err != nil {
			// No connection is handed back to the caller on failure, so the
			// socket opened above must be released here or it would leak.
			// The close error is ignored: the dial error is the one to report.
			_ = udpConn.Close()
			return nil, err
		}
		return conn, nil
	}
	return r
}
// dohResolver resolves DNS messages by sending them to an HTTP endpoint,
// either through a regular HTTP transport or, when isDoH3 is set, through an
// HTTP/3 round tripper.
type dohResolver struct {
	// endpoint is the upstream URL; Resolve appends the "dns" query
	// parameter to it.
	endpoint string

	// isDoH3 selects the HTTP/3 round tripper in Resolve.
	isDoH3 bool

	// doh3DialFunc is installed as the Dial hook of the HTTP/3 round tripper.
	// It is only set by newDohResolver when isDoH3 is true.
	doh3DialFunc func(ctx context.Context, addr string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error)

	// transport is the HTTP transport used for requests when isDoH3 is false.
	transport *http.Transport
}
// Resolve sends msg to the upstream endpoint as an HTTP GET request and
// returns the decoded DNS answer.
//
// The packed message is encoded with unpadded URL-safe base64 and appended to
// r.endpoint as the "dns" query parameter. For DoH3 resolvers a fresh HTTP/3
// round tripper using r.doh3DialFunc is created for the request and closed
// before returning; otherwise r.transport is used.
//
// An error is returned when the message cannot be packed, the request cannot
// be created or performed, the body cannot be read, the status code is not
// 200, or the response body cannot be unpacked. The returned message is nil
// whenever the error is non-nil.
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
	data, err := msg.Pack()
	if err != nil {
		return nil, err
	}

	enc := base64.RawURLEncoding.EncodeToString(data)
	target := fmt.Sprintf("%s?dns=%s", r.endpoint, enc)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return nil, fmt.Errorf("could not create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/dns-message")
	req.Header.Set("Accept", "application/dns-message")

	c := http.Client{}
	switch {
	case r.isDoH3:
		rt := &http3.RoundTripper{Dial: r.doh3DialFunc}
		defer rt.Close()
		c.Transport = rt
	case r.transport != nil:
		// Only install the transport when it is set: storing a nil
		// *http.Transport in the RoundTripper interface would make the
		// interface non-nil, so the client would not fall back to
		// http.DefaultTransport and would dereference the nil pointer.
		c.Transport = r.transport
	}

	resp, err := c.Do(req)
	if err != nil {
		return nil, fmt.Errorf("could not perform request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check because it is included in the
	// error reported for non-200 responses.
	buf, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("could not read message from response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("wrong response from DOH server, got: %s, status: %d", string(buf), resp.StatusCode)
	}

	answer := new(dns.Msg)
	if err := answer.Unpack(buf); err != nil {
		// Do not hand a partially decoded message back alongside an error.
		return nil, fmt.Errorf("could not unpack message from response: %w", err)
	}
	return answer, nil
}