From a20fbf95de48dcee637980f080d619b2f73c8fa0 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 6 Jun 2025 20:19:44 +0700 Subject: [PATCH] all: enhanced TLS certificate verification error messages Added more descriptive error messages for TLS certificate verification failures across DoH, DoT, DoQ, and DoH3 protocols. The error messages now include: - Certificate subject information - Issuer organization details - Common name of the certificate This helps users and developers better understand certificate validation failures by providing specific details about the untrusted certificate, rather than just a generic "unknown authority" message. Example error message change: Before: "certificate signed by unknown authority" After: "certificate signed by unknown authority: TestCA, TestOrg, TestIssuerOrg" --- doh.go | 51 +++++++++++ doh_test.go | 243 ++++++++++++++++++++++++++++++++++++++++++++++++++++ doq.go | 2 +- doq_test.go | 223 +++++++++++++++++++++++++++++++++++++++++++++++ dot.go | 3 +- 5 files changed, 519 insertions(+), 3 deletions(-) create mode 100644 doq_test.go diff --git a/doh.go b/doh.go index 73b2764..3459cb8 100644 --- a/doh.go +++ b/doh.go @@ -2,6 +2,7 @@ package ctrld import ( "context" + "crypto/tls" "encoding/base64" "errors" "fmt" @@ -120,6 +121,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro resp, err = c.Do(req.Clone(retryCtx)) } if err != nil { + err = wrapUrlError(err) if r.isDoH3 { if closer, ok := c.Transport.(io.Closer); ok { closer.Close() @@ -208,3 +210,52 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header { } return header } + +// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer. +// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output. +// If no certificate-related information is available, it simply returns the original error unmodified. +func wrapCertificateVerificationError(err error) error { + var tlsErr *tls.CertificateVerificationError + if errors.As(err, &tlsErr) { + if len(tlsErr.UnverifiedCertificates) > 0 { + cert := tlsErr.UnverifiedCertificates[0] + // Extract a more user-friendly issuer name + var issuer string + var organization string + if len(cert.Issuer.Organization) > 0 { + organization = cert.Issuer.Organization[0] + issuer = organization + } else if cert.Issuer.CommonName != "" { + issuer = cert.Issuer.CommonName + } else { + issuer = cert.Issuer.String() + } + + // Get the organization from the subject field as well + if len(cert.Subject.Organization) > 0 { + organization = cert.Subject.Organization[0] + } + + // Extract the subject information + subjectCN := cert.Subject.CommonName + if subjectCN == "" && len(cert.Subject.Organization) > 0 { + subjectCN = cert.Subject.Organization[0] + } + return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer) + } + } + return err +} + +// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context. +func wrapUrlError(err error) error { + var urlErr *url.Error + if errors.As(err, &urlErr) { + var tlsErr *tls.CertificateVerificationError + if errors.As(urlErr.Err, &tlsErr) { + urlErr.Err = wrapCertificateVerificationError(tlsErr) + return urlErr + } + } + return err +} diff --git a/doh_test.go b/doh_test.go index 8d3e011..92fa79f 100644 --- a/doh_test.go +++ b/doh_test.go @@ -1,8 +1,22 @@ package ctrld import ( + "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "net" + "net/http" + "net/http/httptest" + "net/url" "runtime" + "strings" "testing" + "time" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go/http3" ) func Test_dohOsHeaderValue(t *testing.T) { @@ -21,3 +35,232 @@ func Test_dohOsHeaderValue(t *testing.T) { t.Fatalf("missing decoding value for: %q", runtime.GOOS) } } + +func Test_wrapUrlError(t *testing.T) { + tests := []struct { + name string + err error + wantErr string + }{ + { + name: "No wrapping for non-URL errors", + err: errors.New("plain error"), + wantErr: "plain error", + }, + { + name: "URL error without TLS error", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: errors.New("underlying error"), + }, + wantErr: "Get \"https://example.com\": underlying error", + }, + { + name: "TLS error with missing unverified certificate data", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: &tls.CertificateVerificationError{ + UnverifiedCertificates: nil, + Err: &x509.UnknownAuthorityError{}, + }, + }, + wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority`, + }, + { + name: "TLS error with valid certificate data", + err: &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: &tls.CertificateVerificationError{ + UnverifiedCertificates: []*x509.Certificate{ + { + Subject: pkix.Name{ + CommonName: "BadSubjectCN", + Organization: []string{"BadSubjectOrg"}, + }, + Issuer: pkix.Name{ + CommonName: "BadIssuerCN", + Organization: []string{"BadIssuerOrg"}, + }, + }, + }, + Err: &x509.UnknownAuthorityError{}, + }, + }, + wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority: BadSubjectCN, BadSubjectOrg, BadIssuerOrg`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := wrapUrlError(tt.err) + if gotErr.Error() != tt.wantErr { + t.Errorf("wrapCertificateVerificationError() error = %v, want %v", gotErr, tt.wantErr) + } + }) + } +} + +func Test_ClientCertificateVerificationError(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/dns-message") + }) + tlsServer, cert := testTLSServer(t, handler) + tlsServerUrl, err := url.Parse(tlsServer.URL) + if err != nil { + t.Fatal(err) + } + quicServer := newTestQUICServer(t) + http3Server := newTestHTTP3Server(t, handler) + + tests := []struct { + name string + uc *UpstreamConfig + }{ + { + "doh", + &UpstreamConfig{ + Name: "doh", + Type: ResolverTypeDOH, + Endpoint: tlsServer.URL, + Timeout: 1000, + }, + }, + { + "doh3", + &UpstreamConfig{ + Name: "doh3", + Type: ResolverTypeDOH3, + Endpoint: http3Server.addr, + Timeout: 5000, + }, + }, + { + "doq", + &UpstreamConfig{ + Name: "doq", + Type: ResolverTypeDOQ, + Endpoint: quicServer.addr, + Timeout: 5000, + }, + }, + { + "dot", + &UpstreamConfig{ + Name: "dot", + Type: ResolverTypeDOT, + Endpoint: net.JoinHostPort(tlsServerUrl.Hostname(), tlsServerUrl.Port()), + Timeout: 1000, + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tc.uc.Init() + tc.uc.SetupBootstrapIP() + r, err := NewResolver(tc.uc) + if err != nil { + t.Fatal(err) + } + msg := new(dns.Msg) + msg.SetQuestion("verify.controld.com.", dns.TypeA) + msg.RecursionDesired = true + _, err = r.Resolve(context.Background(), msg) + // Verify the error contains the expected certificate information + if err == nil { + t.Fatal("expected certificate verification error, got nil") + } + + // You can check the error contains information about the test certificate + if !strings.Contains(err.Error(), cert.Issuer.CommonName) { + t.Fatalf("error should contain issuer information %q, got: %v", cert.Issuer.CommonName, err) + } + }) + } +} + +// testTLSServer creates an HTTPS test server with a self-signed certificate +// returns the server and its certificate for verification testing +// testTLSServer creates an HTTPS test server with a self-signed certificate +func testTLSServer(t *testing.T, handler http.Handler) (*httptest.Server, *x509.Certificate) { + t.Helper() + + testCert := generateTestCertificate(t) + + // Create a test server + server := httptest.NewUnstartedServer(handler) + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + } + server.StartTLS() + + // Add cleanup + t.Cleanup(server.Close) + + return server, testCert.cert +} + +// testHTTP3Server represents a structure for an HTTP/3 test server with its server instance, TLS certificate, and address. +type testHTTP3Server struct { + server *http3.Server + cert *x509.Certificate + addr string +} + +// newTestHTTP3Server creates and starts a test HTTP/3 server with a given handler and returns the server instance. +func newTestHTTP3Server(t *testing.T, handler http.Handler) *testHTTP3Server { + t.Helper() + + testCert := generateTestCertificate(t) + + // First create a listener to get the actual port + udpAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + t.Fatalf("failed to create UDP listener: %v", err) + } + + // Get the actual address + actualAddr := udpConn.LocalAddr().String() + + // Create TLS config + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + NextProtos: []string{"h3"}, // HTTP/3 protocol identifier + } + + // Create HTTP/3 server + server := &http3.Server{ + Handler: handler, + TLSConfig: tlsConfig, + } + + // Start the server with the existing UDP connection + go func() { + if err := server.Serve(udpConn); err != nil && !errors.Is(err, http.ErrServerClosed) { + t.Logf("HTTP/3 server error: %v", err) + } + }() + + h3Server := &testHTTP3Server{ + server: server, + cert: testCert.cert, + addr: actualAddr, + } + + // Add cleanup + t.Cleanup(func() { + server.Close() + udpConn.Close() + }) + + // Wait a bit for the server to be ready + time.Sleep(100 * time.Millisecond) + + return h3Server +} diff --git a/doq.go b/doq.go index 3c3f9e8..0903411 100644 --- a/doq.go +++ b/doq.go @@ -43,7 +43,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls. continue } if err != nil { - return nil, err + return nil, wrapCertificateVerificationError(err) } return answer, nil } diff --git a/doq_test.go b/doq_test.go new file mode 100644 index 0000000..430a22a --- /dev/null +++ b/doq_test.go @@ -0,0 +1,223 @@ +// test_helpers.go +package ctrld + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "strings" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go" +) + +// testCertificate represents a test certificate with its components +type testCertificate struct { + cert *x509.Certificate + tlsCert tls.Certificate + template *x509.Certificate +} + +// generateTestCertificate creates a self-signed certificate for testing +func generateTestCertificate(t *testing.T) *testCertificate { + t.Helper() + + // Generate private key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate private key: %v", err) + } + + // Create certificate template + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + CommonName: "Test CA", + }, + Issuer: pkix.Name{ + Organization: []string{"Test Issuer Org"}, + CommonName: "Test Issuer CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + } + + // Create certificate + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + t.Fatalf("failed to parse certificate: %v", err) + } + + // Create TLS certificate + tlsCert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: privateKey, + } + + return &testCertificate{ + cert: cert, + tlsCert: tlsCert, + template: template, + } +} + +// testQUICServer is a structure representing a test QUIC server for handling connections and streams. +// listener is the QUIC listener used to accept incoming connections. +// cert is the x509 certificate used by the server for authentication. +// addr is the address on which the test server is running. +type testQUICServer struct { + listener *quic.Listener + cert *x509.Certificate + addr string +} + +// newTestQUICServer creates and initializes a test QUIC server with TLS configuration and starts accepting connections. +func newTestQUICServer(t *testing.T) *testQUICServer { + t.Helper() + + testCert := generateTestCertificate(t) + + // Create TLS config + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{testCert.tlsCert}, + NextProtos: []string{"doq"}, + } + + // Create QUIC listener + listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil) + if err != nil { + t.Fatalf("failed to create QUIC listener: %v", err) + } + + server := &testQUICServer{ + listener: listener, + cert: testCert.cert, + addr: listener.Addr().String(), + } + + // Start handling connections + go server.serve(t) + + // Add cleanup + t.Cleanup(func() { + listener.Close() + }) + + return server +} + +// serve handles incoming connections on the QUIC listener and delegates them to connection handlers in separate goroutines. +func (s *testQUICServer) serve(t *testing.T) { + for { + conn, err := s.listener.Accept(context.Background()) + if err != nil { + // Check if the error is due to the listener being closed + if strings.Contains(err.Error(), "server closed") { + return + } + t.Logf("failed to accept connection: %v", err) + continue + } + + go s.handleConnection(t, conn) + } +} + +// handleConnection manages an individual QUIC connection by accepting and handling incoming streams in separate goroutines. +func (s *testQUICServer) handleConnection(t *testing.T, conn quic.Connection) { + for { + stream, err := conn.AcceptStream(context.Background()) + if err != nil { + return + } + + go s.handleStream(t, stream) + } +} + +// handleStream processes a single QUIC stream, reads DNS messages, generates a response, and sends it back to the client. +func (s *testQUICServer) handleStream(t *testing.T, stream quic.Stream) { + defer stream.Close() + + // Read length (2 bytes) + lenBuf := make([]byte, 2) + _, err := stream.Read(lenBuf) + if err != nil { + t.Logf("failed to read message length: %v", err) + return + } + msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1]) + + // Read message + msgBuf := make([]byte, msgLen) + _, err = stream.Read(msgBuf) + if err != nil { + t.Logf("failed to read message: %v", err) + return + } + + // Parse DNS message + msg := new(dns.Msg) + if err := msg.Unpack(msgBuf); err != nil { + t.Logf("failed to unpack DNS message: %v", err) + return + } + + // Create response + response := new(dns.Msg) + response.SetReply(msg) + response.Authoritative = true + + // Add a test answer + if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA { + response.Answer = append(response.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: msg.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP("192.0.2.1"), // TEST-NET-1 address + }) + } + + // Pack response + respBytes, err := response.Pack() + if err != nil { + t.Logf("failed to pack response: %v", err) + return + } + + // Write length + respLen := uint16(len(respBytes)) + _, err = stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)}) + if err != nil { + t.Logf("failed to write response length: %v", err) + return + } + + // Write response + _, err = stream.Write(respBytes) + if err != nil { + t.Logf("failed to write response: %v", err) + return + } +} diff --git a/dot.go b/dot.go index 67d1ff8..295134c 100644 --- a/dot.go +++ b/dot.go @@ -23,7 +23,6 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if msg != nil && len(msg.Question) > 0 { dnsTyp = msg.Question[0].Qtype } - tcpNet, _ := r.uc.netForDNSType(dnsTyp) dnsClient := &dns.Client{ Net: tcpNet, @@ -39,5 +38,5 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro } answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint) - return answer, err + return answer, wrapCertificateVerificationError(err) }