diff --git a/dot.go b/dot.go
index 74f5ece..e8049bb 100644
--- a/dot.go
+++ b/dot.go
@@ -57,7 +57,7 @@ type dotConnPool struct {
 }
 
 type dotConn struct {
-	conn     net.Conn
+	conn     *tls.Conn
 	lastUsed time.Time
 	refCount int
 	mu       sync.Mutex
@@ -114,13 +114,6 @@ func (p *dotConnPool) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
 		return nil, wrapCertificateVerificationError(err)
 	}
 
-	// Set deadline
-	deadline, ok := ctx.Deadline()
-	if !ok {
-		deadline = time.Now().Add(5 * time.Second)
-	}
-	_ = conn.SetDeadline(deadline)
-
 	client := dns.Client{Net: "tcp-tls"}
 	answer, _, err := client.ExchangeWithConnContext(ctx, msg, &dns.Conn{Conn: conn})
 	isGood := err == nil
@@ -145,7 +138,7 @@ func (p *dotConnPool) getConn(ctx context.Context) (net.Conn, string, error) {
 	// Try to reuse an existing connection
 	for addr, dotConn := range p.conns {
 		dotConn.mu.Lock()
-		if dotConn.refCount == 0 && dotConn.conn != nil {
+		if dotConn.refCount == 0 && dotConn.conn != nil && isAlive(dotConn.conn) {
 			dotConn.refCount++
 			dotConn.lastUsed = time.Now()
 			conn := dotConn.conn
@@ -202,7 +195,7 @@ func (p *dotConnPool) putConn(addr string, conn net.Conn, isGood bool) {
 }
 
 // dialConn creates a new TCP/TLS connection.
-func (p *dotConnPool) dialConn(ctx context.Context) (string, net.Conn, error) {
+func (p *dotConnPool) dialConn(ctx context.Context) (string, *tls.Conn, error) {
 	logger := LoggerFromCtx(ctx)
 
 	var endpoint string
@@ -224,7 +217,7 @@ func (p *dotConnPool) dialConn(ctx context.Context) (string, net.Conn, error) {
 	// Try bootstrap IPs in parallel
 	if len(p.addrs) > 0 {
 		type result struct {
-			conn net.Conn
+			conn *tls.Conn
 			addr string
 			err  error
 		}
@@ -316,3 +309,34 @@ func (p *dotConnPool) CloseIdleConnections() {
 		delete(p.conns, addr)
 	}
 }
+
+// isAlive reports whether an idle connection still looks usable.
+//
+// It probes c with a read bounded by a 1ms deadline. Only a timeout counts
+// as alive: nothing was received and no close or failure was reported. Any
+// other error, io.EOF included, means the connection is closed or broken.
+// A read that succeeds has already consumed a byte, since Read cannot peek,
+// so whatever reads from c next would start mid-stream; that connection is
+// reported as not alive as well.
+func isAlive(c *tls.Conn) bool {
+	// Without a deadline the probe could block, so give up if it cannot be set.
+	if err := c.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
+		return false
+	}
+
+	one := make([]byte, 1)
+	n, err := c.Read(one)
+
+	// Clear the deadline so the probe does not affect later reads.
+	if resetErr := c.SetReadDeadline(time.Time{}); resetErr != nil {
+		return false
+	}
+
+	// A byte that was read is gone from the stream and cannot be put back.
+	if n > 0 {
+		return false
+	}
+
+	var netErr net.Error
+	return errors.As(err, &netErr) && netErr.Timeout()
+}