all: enhanced TLS certificate verification error messages

Added more descriptive error messages for TLS certificate verification
failures across DoH, DoT, DoQ, and DoH3 protocols. The error messages
now include:

- Certificate subject information
- Issuer organization details
- Common name of the certificate

This helps users and developers better understand certificate validation
failures by providing specific details about the untrusted certificate,
rather than just a generic "unknown authority" message.

Example error message change:
Before: "certificate signed by unknown authority"
After: "certificate signed by unknown authority: TestCA, TestOrg, TestIssuerOrg"
This commit is contained in:
Cuong Manh Le
2025-06-06 20:19:44 +07:00
committed by Cuong Manh Le
parent 628c4302aa
commit a20fbf95de
5 changed files with 519 additions and 3 deletions

51
doh.go
View File

@@ -2,6 +2,7 @@ package ctrld
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
@@ -120,6 +121,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
resp, err = c.Do(req.Clone(retryCtx))
}
if err != nil {
err = wrapUrlError(err)
if r.isDoH3 {
if closer, ok := c.Transport.(io.Closer); ok {
closer.Close()
@@ -208,3 +210,52 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header {
}
return header
}
// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer.
// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output.
// If no certificate-related information is available, it simply returns the original error unmodified.
func wrapCertificateVerificationError(err error) error {
var tlsErr *tls.CertificateVerificationError
if errors.As(err, &tlsErr) {
if len(tlsErr.UnverifiedCertificates) > 0 {
cert := tlsErr.UnverifiedCertificates[0]
// Extract a more user-friendly issuer name
var issuer string
var organization string
if len(cert.Issuer.Organization) > 0 {
organization = cert.Issuer.Organization[0]
issuer = organization
} else if cert.Issuer.CommonName != "" {
issuer = cert.Issuer.CommonName
} else {
issuer = cert.Issuer.String()
}
// Get the organization from the subject field as well
if len(cert.Subject.Organization) > 0 {
organization = cert.Subject.Organization[0]
}
// Extract the subject information
subjectCN := cert.Subject.CommonName
if subjectCN == "" && len(cert.Subject.Organization) > 0 {
subjectCN = cert.Subject.Organization[0]
}
return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer)
}
}
return err
}
// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context.
func wrapUrlError(err error) error {
var urlErr *url.Error
if errors.As(err, &urlErr) {
var tlsErr *tls.CertificateVerificationError
if errors.As(urlErr.Err, &tlsErr) {
urlErr.Err = wrapCertificateVerificationError(tlsErr)
return urlErr
}
}
return err
}

View File

@@ -1,8 +1,22 @@
package ctrld
import (
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"net"
"net/http"
"net/http/httptest"
"net/url"
"runtime"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/quic-go/quic-go/http3"
)
func Test_dohOsHeaderValue(t *testing.T) {
@@ -21,3 +35,232 @@ func Test_dohOsHeaderValue(t *testing.T) {
t.Fatalf("missing decoding value for: %q", runtime.GOOS)
}
}
func Test_wrapUrlError(t *testing.T) {
tests := []struct {
name string
err error
wantErr string
}{
{
name: "No wrapping for non-URL errors",
err: errors.New("plain error"),
wantErr: "plain error",
},
{
name: "URL error without TLS error",
err: &url.Error{
Op: "Get",
URL: "https://example.com",
Err: errors.New("underlying error"),
},
wantErr: "Get \"https://example.com\": underlying error",
},
{
name: "TLS error with missing unverified certificate data",
err: &url.Error{
Op: "Get",
URL: "https://example.com",
Err: &tls.CertificateVerificationError{
UnverifiedCertificates: nil,
Err: &x509.UnknownAuthorityError{},
},
},
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority`,
},
{
name: "TLS error with valid certificate data",
err: &url.Error{
Op: "Get",
URL: "https://example.com",
Err: &tls.CertificateVerificationError{
UnverifiedCertificates: []*x509.Certificate{
{
Subject: pkix.Name{
CommonName: "BadSubjectCN",
Organization: []string{"BadSubjectOrg"},
},
Issuer: pkix.Name{
CommonName: "BadIssuerCN",
Organization: []string{"BadIssuerOrg"},
},
},
},
Err: &x509.UnknownAuthorityError{},
},
},
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority: BadSubjectCN, BadSubjectOrg, BadIssuerOrg`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotErr := wrapUrlError(tt.err)
if gotErr.Error() != tt.wantErr {
t.Errorf("wrapCertificateVerificationError() error = %v, want %v", gotErr, tt.wantErr)
}
})
}
}
func Test_ClientCertificateVerificationError(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/dns-message")
})
tlsServer, cert := testTLSServer(t, handler)
tlsServerUrl, err := url.Parse(tlsServer.URL)
if err != nil {
t.Fatal(err)
}
quicServer := newTestQUICServer(t)
http3Server := newTestHTTP3Server(t, handler)
tests := []struct {
name string
uc *UpstreamConfig
}{
{
"doh",
&UpstreamConfig{
Name: "doh",
Type: ResolverTypeDOH,
Endpoint: tlsServer.URL,
Timeout: 1000,
},
},
{
"doh3",
&UpstreamConfig{
Name: "doh3",
Type: ResolverTypeDOH3,
Endpoint: http3Server.addr,
Timeout: 5000,
},
},
{
"doq",
&UpstreamConfig{
Name: "doq",
Type: ResolverTypeDOQ,
Endpoint: quicServer.addr,
Timeout: 5000,
},
},
{
"dot",
&UpstreamConfig{
Name: "dot",
Type: ResolverTypeDOT,
Endpoint: net.JoinHostPort(tlsServerUrl.Hostname(), tlsServerUrl.Port()),
Timeout: 1000,
},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init()
tc.uc.SetupBootstrapIP()
r, err := NewResolver(tc.uc)
if err != nil {
t.Fatal(err)
}
msg := new(dns.Msg)
msg.SetQuestion("verify.controld.com.", dns.TypeA)
msg.RecursionDesired = true
_, err = r.Resolve(context.Background(), msg)
// Verify the error contains the expected certificate information
if err == nil {
t.Fatal("expected certificate verification error, got nil")
}
// You can check the error contains information about the test certificate
if !strings.Contains(err.Error(), cert.Issuer.CommonName) {
t.Fatalf("error should contain issuer information %q, got: %v", cert.Issuer.CommonName, err)
}
})
}
}
// testTLSServer creates an HTTPS test server with a self-signed certificate
// returns the server and its certificate for verification testing
// testTLSServer creates an HTTPS test server with a self-signed certificate
func testTLSServer(t *testing.T, handler http.Handler) (*httptest.Server, *x509.Certificate) {
t.Helper()
testCert := generateTestCertificate(t)
// Create a test server
server := httptest.NewUnstartedServer(handler)
server.TLS = &tls.Config{
Certificates: []tls.Certificate{testCert.tlsCert},
}
server.StartTLS()
// Add cleanup
t.Cleanup(server.Close)
return server, testCert.cert
}
// testHTTP3Server represents a structure for an HTTP/3 test server with its server instance, TLS certificate, and address.
type testHTTP3Server struct {
server *http3.Server
cert *x509.Certificate
addr string
}
// newTestHTTP3Server creates and starts a test HTTP/3 server with a given handler and returns the server instance.
func newTestHTTP3Server(t *testing.T, handler http.Handler) *testHTTP3Server {
t.Helper()
testCert := generateTestCertificate(t)
// First create a listener to get the actual port
udpAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
t.Fatalf("failed to create UDP listener: %v", err)
}
// Get the actual address
actualAddr := udpConn.LocalAddr().String()
// Create TLS config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{testCert.tlsCert},
NextProtos: []string{"h3"}, // HTTP/3 protocol identifier
}
// Create HTTP/3 server
server := &http3.Server{
Handler: handler,
TLSConfig: tlsConfig,
}
// Start the server with the existing UDP connection
go func() {
if err := server.Serve(udpConn); err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Logf("HTTP/3 server error: %v", err)
}
}()
h3Server := &testHTTP3Server{
server: server,
cert: testCert.cert,
addr: actualAddr,
}
// Add cleanup
t.Cleanup(func() {
server.Close()
udpConn.Close()
})
// Wait a bit for the server to be ready
time.Sleep(100 * time.Millisecond)
return h3Server
}

2
doq.go
View File

@@ -43,7 +43,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
continue
}
if err != nil {
return nil, err
return nil, wrapCertificateVerificationError(err)
}
return answer, nil
}

223
doq_test.go Normal file
View File

@@ -0,0 +1,223 @@
// test_helpers.go
package ctrld
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
// testCertificate represents a test certificate with its components
type testCertificate struct {
cert *x509.Certificate
tlsCert tls.Certificate
template *x509.Certificate
}
// generateTestCertificate creates a self-signed certificate for testing
func generateTestCertificate(t *testing.T) *testCertificate {
t.Helper()
// Generate private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate private key: %v", err)
}
// Create certificate template
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
CommonName: "Test CA",
},
Issuer: pkix.Name{
Organization: []string{"Test Issuer Org"},
CommonName: "Test Issuer CA",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"localhost"},
}
// Create certificate
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatalf("failed to create certificate: %v", err)
}
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
t.Fatalf("failed to parse certificate: %v", err)
}
// Create TLS certificate
tlsCert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: privateKey,
}
return &testCertificate{
cert: cert,
tlsCert: tlsCert,
template: template,
}
}
// testQUICServer is a structure representing a test QUIC server for handling connections and streams.
// listener is the QUIC listener used to accept incoming connections.
// cert is the x509 certificate used by the server for authentication.
// addr is the address on which the test server is running.
type testQUICServer struct {
listener *quic.Listener
cert *x509.Certificate
addr string
}
// newTestQUICServer creates and initializes a test QUIC server with TLS configuration and starts accepting connections.
func newTestQUICServer(t *testing.T) *testQUICServer {
t.Helper()
testCert := generateTestCertificate(t)
// Create TLS config
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{testCert.tlsCert},
NextProtos: []string{"doq"},
}
// Create QUIC listener
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
if err != nil {
t.Fatalf("failed to create QUIC listener: %v", err)
}
server := &testQUICServer{
listener: listener,
cert: testCert.cert,
addr: listener.Addr().String(),
}
// Start handling connections
go server.serve(t)
// Add cleanup
t.Cleanup(func() {
listener.Close()
})
return server
}
// serve handles incoming connections on the QUIC listener and delegates them to connection handlers in separate goroutines.
func (s *testQUICServer) serve(t *testing.T) {
for {
conn, err := s.listener.Accept(context.Background())
if err != nil {
// Check if the error is due to the listener being closed
if strings.Contains(err.Error(), "server closed") {
return
}
t.Logf("failed to accept connection: %v", err)
continue
}
go s.handleConnection(t, conn)
}
}
// handleConnection manages an individual QUIC connection by accepting and handling incoming streams in separate goroutines.
func (s *testQUICServer) handleConnection(t *testing.T, conn quic.Connection) {
for {
stream, err := conn.AcceptStream(context.Background())
if err != nil {
return
}
go s.handleStream(t, stream)
}
}
// handleStream processes a single QUIC stream, reads DNS messages, generates a response, and sends it back to the client.
func (s *testQUICServer) handleStream(t *testing.T, stream quic.Stream) {
defer stream.Close()
// Read length (2 bytes)
lenBuf := make([]byte, 2)
_, err := stream.Read(lenBuf)
if err != nil {
t.Logf("failed to read message length: %v", err)
return
}
msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1])
// Read message
msgBuf := make([]byte, msgLen)
_, err = stream.Read(msgBuf)
if err != nil {
t.Logf("failed to read message: %v", err)
return
}
// Parse DNS message
msg := new(dns.Msg)
if err := msg.Unpack(msgBuf); err != nil {
t.Logf("failed to unpack DNS message: %v", err)
return
}
// Create response
response := new(dns.Msg)
response.SetReply(msg)
response.Authoritative = true
// Add a test answer
if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA {
response.Answer = append(response.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: msg.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.ParseIP("192.0.2.1"), // TEST-NET-1 address
})
}
// Pack response
respBytes, err := response.Pack()
if err != nil {
t.Logf("failed to pack response: %v", err)
return
}
// Write length
respLen := uint16(len(respBytes))
_, err = stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)})
if err != nil {
t.Logf("failed to write response length: %v", err)
return
}
// Write response
_, err = stream.Write(respBytes)
if err != nil {
t.Logf("failed to write response: %v", err)
return
}
}

3
dot.go
View File

@@ -23,7 +23,6 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: tcpNet,
@@ -39,5 +38,5 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
}
answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint)
return answer, err
return answer, wrapCertificateVerificationError(err)
}