mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
224 lines
5.5 KiB
Go
224 lines
5.5 KiB
Go
// test_helpers.go
|
|
package ctrld
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"math/big"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/quic-go/quic-go"
|
|
)
|
|
|
|
// testCertificate represents a test certificate with its components
|
|
type testCertificate struct {
|
|
cert *x509.Certificate
|
|
tlsCert tls.Certificate
|
|
template *x509.Certificate
|
|
}
|
|
|
|
// generateTestCertificate creates a self-signed certificate for testing
|
|
func generateTestCertificate(t *testing.T) *testCertificate {
|
|
t.Helper()
|
|
|
|
// Generate private key
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("failed to generate private key: %v", err)
|
|
}
|
|
|
|
// Create certificate template
|
|
template := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"Test Org"},
|
|
CommonName: "Test CA",
|
|
},
|
|
Issuer: pkix.Name{
|
|
Organization: []string{"Test Issuer Org"},
|
|
CommonName: "Test Issuer CA",
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(time.Hour),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
|
DNSNames: []string{"localhost"},
|
|
}
|
|
|
|
// Create certificate
|
|
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
|
|
if err != nil {
|
|
t.Fatalf("failed to create certificate: %v", err)
|
|
}
|
|
|
|
cert, err := x509.ParseCertificate(derBytes)
|
|
if err != nil {
|
|
t.Fatalf("failed to parse certificate: %v", err)
|
|
}
|
|
|
|
// Create TLS certificate
|
|
tlsCert := tls.Certificate{
|
|
Certificate: [][]byte{derBytes},
|
|
PrivateKey: privateKey,
|
|
}
|
|
|
|
return &testCertificate{
|
|
cert: cert,
|
|
tlsCert: tlsCert,
|
|
template: template,
|
|
}
|
|
}
|
|
|
|
// testQUICServer is a structure representing a test QUIC server for handling connections and streams.
|
|
// listener is the QUIC listener used to accept incoming connections.
|
|
// cert is the x509 certificate used by the server for authentication.
|
|
// addr is the address on which the test server is running.
|
|
type testQUICServer struct {
|
|
listener *quic.Listener
|
|
cert *x509.Certificate
|
|
addr string
|
|
}
|
|
|
|
// newTestQUICServer creates and initializes a test QUIC server with TLS configuration and starts accepting connections.
|
|
func newTestQUICServer(t *testing.T) *testQUICServer {
|
|
t.Helper()
|
|
|
|
testCert := generateTestCertificate(t)
|
|
|
|
// Create TLS config
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{testCert.tlsCert},
|
|
NextProtos: []string{"doq"},
|
|
}
|
|
|
|
// Create QUIC listener
|
|
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create QUIC listener: %v", err)
|
|
}
|
|
|
|
server := &testQUICServer{
|
|
listener: listener,
|
|
cert: testCert.cert,
|
|
addr: listener.Addr().String(),
|
|
}
|
|
|
|
// Start handling connections
|
|
go server.serve(t)
|
|
|
|
// Add cleanup
|
|
t.Cleanup(func() {
|
|
listener.Close()
|
|
})
|
|
|
|
return server
|
|
}
|
|
|
|
// serve handles incoming connections on the QUIC listener and delegates them to connection handlers in separate goroutines.
|
|
func (s *testQUICServer) serve(t *testing.T) {
|
|
for {
|
|
conn, err := s.listener.Accept(context.Background())
|
|
if err != nil {
|
|
// Check if the error is due to the listener being closed
|
|
if strings.Contains(err.Error(), "server closed") {
|
|
return
|
|
}
|
|
t.Logf("failed to accept connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
go s.handleConnection(t, conn)
|
|
}
|
|
}
|
|
|
|
// handleConnection manages an individual QUIC connection by accepting and handling incoming streams in separate goroutines.
|
|
func (s *testQUICServer) handleConnection(t *testing.T, conn *quic.Conn) {
|
|
for {
|
|
stream, err := conn.AcceptStream(context.Background())
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
go s.handleStream(t, stream)
|
|
}
|
|
}
|
|
|
|
// handleStream processes a single QUIC stream, reads DNS messages, generates a response, and sends it back to the client.
|
|
func (s *testQUICServer) handleStream(t *testing.T, stream *quic.Stream) {
|
|
defer stream.Close()
|
|
|
|
// Read length (2 bytes)
|
|
lenBuf := make([]byte, 2)
|
|
_, err := stream.Read(lenBuf)
|
|
if err != nil {
|
|
t.Logf("failed to read message length: %v", err)
|
|
return
|
|
}
|
|
msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1])
|
|
|
|
// Read message
|
|
msgBuf := make([]byte, msgLen)
|
|
_, err = stream.Read(msgBuf)
|
|
if err != nil {
|
|
t.Logf("failed to read message: %v", err)
|
|
return
|
|
}
|
|
|
|
// Parse DNS message
|
|
msg := new(dns.Msg)
|
|
if err := msg.Unpack(msgBuf); err != nil {
|
|
t.Logf("failed to unpack DNS message: %v", err)
|
|
return
|
|
}
|
|
|
|
// Create response
|
|
response := new(dns.Msg)
|
|
response.SetReply(msg)
|
|
response.Authoritative = true
|
|
|
|
// Add a test answer
|
|
if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA {
|
|
response.Answer = append(response.Answer, &dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: msg.Question[0].Name,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
Ttl: 300,
|
|
},
|
|
A: net.ParseIP("192.0.2.1"), // TEST-NET-1 address
|
|
})
|
|
}
|
|
|
|
// Pack response
|
|
respBytes, err := response.Pack()
|
|
if err != nil {
|
|
t.Logf("failed to pack response: %v", err)
|
|
return
|
|
}
|
|
|
|
// Write length
|
|
respLen := uint16(len(respBytes))
|
|
_, err = stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)})
|
|
if err != nil {
|
|
t.Logf("failed to write response length: %v", err)
|
|
return
|
|
}
|
|
|
|
// Write response
|
|
_, err = stream.Write(respBytes)
|
|
if err != nil {
|
|
t.Logf("failed to write response: %v", err)
|
|
return
|
|
}
|
|
}
|