diff --git a/doq.go b/doq.go index eb7ed1c..9d0bdd9 100644 --- a/doq.go +++ b/doq.go @@ -42,11 +42,12 @@ const doqPoolSize = 16 // doqConnPool manages a pool of QUIC connections for DoQ queries using a buffered channel. type doqConnPool struct { - uc *UpstreamConfig - addrs []string - port string - tlsConfig *tls.Config - conns chan *doqConn + uc *UpstreamConfig + addrs []string + port string + tlsConfig *tls.Config + quicConfig *quic.Config + conns chan *doqConn } type doqConn struct { @@ -65,12 +66,17 @@ func newDOQConnPool(uc *UpstreamConfig, addrs []string) *doqConnPool { ServerName: uc.Domain, } + quicConfig := &quic.Config{ + KeepAlivePeriod: 15 * time.Second, + } + pool := &doqConnPool{ - uc: uc, - addrs: addrs, - port: port, - tlsConfig: tlsConfig, - conns: make(chan *doqConn, doqPoolSize), + uc: uc, + addrs: addrs, + port: port, + tlsConfig: tlsConfig, + quicConfig: quicConfig, + conns: make(chan *doqConn, doqPoolSize), } // Use SetFinalizer here because we need to call a method on the pool itself. @@ -85,12 +91,17 @@ func newDOQConnPool(uc *UpstreamConfig, addrs []string) *doqConnPool { // Resolve performs a DNS query using a pooled QUIC connection. func (p *doqConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { - // Retry logic for io.EOF errors (as per original implementation) + // Retry logic for transient errors: io.EOF (connection reset) and + // IdleTimeoutError (stale pooled connection timed out). for range 5 { answer, err := p.doResolve(ctx, msg) if err == io.EOF { continue } + var idleErr *quic.IdleTimeoutError + if errors.As(err, &idleErr) { + continue + } if err != nil { return nil, wrapCertificateVerificationError(err) } @@ -226,7 +237,7 @@ func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) udpConn.Close() return "", nil, err } - conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, p.tlsConfig, nil) + conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, p.tlsConfig, p.quicConfig) if err != nil { udpConn.Close() return "", nil, err @@ -241,7 +252,7 @@ func (p *doqConnPool) dialConn(ctx context.Context) (string, *quic.Conn, error) } pd := &quicParallelDialer{} - conn, err := pd.Dial(ctx, dialAddrs, p.tlsConfig, nil) + conn, err := pd.Dial(ctx, dialAddrs, p.tlsConfig, p.quicConfig) if err != nil { return "", nil, err }