refactor: replace Unix socket log communication with HTTP-based system

Replace the legacy Unix socket log communication between `ctrld start` and
`ctrld run` with a modern HTTP-based system for better reliability and
maintainability.

Benefits:
- More reliable communication protocol using standard HTTP
- Better error handling and connection management
- Cleaner separation of concerns with dedicated endpoints
- Easier to test and debug with HTTP-based communication
- More maintainable code with proper abstraction layers

This change maintains backward compatibility while providing a more robust
foundation for inter-process communication between ctrld commands.
This commit is contained in:
Cuong Manh Le
2025-09-12 18:22:02 +07:00
committed by Cuong Manh Le
parent a04babbbc3
commit 56f8113bb0
7 changed files with 976 additions and 137 deletions

View File

@@ -234,22 +234,21 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
sockDir = d
}
sockPath := filepath.Join(sockDir, ctrldLogUnixSock)
if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil {
if conn, err := net.Dial(addr.Network(), addr.String()); err == nil {
lc := &logConn{conn: conn}
consoleWriter = newHumanReadableZapCore(io.MultiWriter(os.Stdout, lc), consoleWriterLevel)
p.logConn = lc
} else {
if !errors.Is(err, os.ErrNotExist) {
p.Warn().Err(err).Msg("Unable to create log ipc connection")
}
hlc := newHTTPLogClient(sockPath)
// Test if HTTP log server is available
if err := hlc.Ping(); err != nil {
if !errConnectionRefused(err) {
p.Warn().Err(err).Msg("Unable to ping log server")
}
} else {
p.Warn().Err(err).Msgf("Unable to resolve socket address: %s", sockPath)
// Server is available, use HTTP log client
consoleWriter = newHumanReadableZapCore(io.MultiWriter(os.Stdout, hlc), consoleWriterLevel)
p.logConn = hlc
}
notifyExitToLogServer := func() {
if p.logConn != nil {
_, _ = p.logConn.Write([]byte(msgExit))
_ = p.logConn.Close()
}
}
@@ -1354,7 +1353,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, notifyFunc func(), fatal bool) (
break
}
logMsg(il.Info().Err(err), n, "error listening on address: %s", addr)
logMsg(il.Debug().Err(err), n, "error listening on address: %s", addr)
if !check.IP && !check.Port {
if fatal {

View File

@@ -10,7 +10,6 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/kardianos/service"
@@ -104,11 +103,10 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
writeDefaultConfig := !noConfigStart && configBase64 == ""
logServerStarted := make(chan struct{})
// A buffer channel to gather log output from runCmd and report
// to user in case self-check process failed.
runCmdLogCh := make(chan string, 256)
stopLogCh := make(chan struct{})
ud, err := userHomeDir()
sockDir := ud
var logServerSocketPath string
if err != nil {
logger.Warn().Err(err).Msg("Failed to get user home directory")
logger.Warn().Msg("Log server did not start")
@@ -122,29 +120,17 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
if d, err := socketDir(); err == nil {
sockDir = d
}
sockPath := filepath.Join(sockDir, ctrldLogUnixSock)
_ = os.Remove(sockPath)
logServerSocketPath = filepath.Join(sockDir, ctrldLogUnixSock)
_ = os.Remove(logServerSocketPath)
go func() {
defer func() {
close(runCmdLogCh)
_ = os.Remove(sockPath)
}()
defer os.Remove(logServerSocketPath)
close(logServerStarted)
if conn := runLogServer(sockPath); conn != nil {
// Enough buffer for log message, we don't produce
// such long log message, but just in case.
buf := make([]byte, 1024)
for {
n, err := conn.Read(buf)
if err != nil {
return
}
msg := string(buf[:n])
if _, _, found := strings.Cut(msg, msgExit); found {
cancel()
}
runCmdLogCh <- msg
}
// Start HTTP log server
if err := httpLogServer(logServerSocketPath, stopLogCh); err != nil && err != http.ErrServerClosed {
logger.Warn().Err(err).Msg("Failed to serve HTTP log server")
return
}
}()
}
@@ -270,19 +256,29 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
case ok && status == service.StatusRunning:
logger.Notice().Msg("Service started")
default:
marker := bytes.Repeat([]byte("="), 32)
marker := append(bytes.Repeat([]byte("="), 32), '\n')
// If ctrld service is not running, emitting log obtained from ctrld process.
if status != service.StatusRunning || ctx.Err() != nil {
logger.Error().Msg("Ctrld service may not have started due to an error or misconfiguration, service log:")
_, _ = logger.Write(marker)
haveLog := false
for msg := range runCmdLogCh {
_, _ = logger.Write([]byte(strings.ReplaceAll(msg, msgExit, "")))
haveLog = true
}
// If we're unable to get log from "ctrld run", notice users about it.
if !haveLog {
logger.Write([]byte(`<no log output is obtained from ctrld process>"`))
// Wait for log collection to complete
<-stopLogCh
// Retrieve logs from HTTP server if available
if logServerSocketPath != "" {
hlc := newHTTPLogClient(logServerSocketPath)
logs, err := hlc.GetLogs()
if err != nil {
logger.Warn().Err(err).Msg("Failed to get logs from HTTP log server")
}
if len(logs) == 0 {
logger.Write([]byte(`<no log output is obtained from ctrld process>`))
} else {
logger.Write(logs)
}
} else {
logger.Write([]byte(`<no log output from HTTP log server>`))
}
}
// Report any error if occurred.

View File

@@ -1,67 +0,0 @@
package cli
import (
"net"
"time"
)
// logConn wraps a net.Conn, override the Write behavior.
// runCmd uses this wrapper, so as long as startCmd finished,
// ctrld log won't be flushed with un-necessary write errors.
// This prevents log pollution when the parent process closes the connection
type logConn struct {
conn net.Conn
}
// Read delegates to the underlying connection
// This maintains normal read behavior for the wrapped connection
func (lc *logConn) Read(b []byte) (n int, err error) {
return lc.conn.Read(b)
}
// Close delegates to the underlying connection
// This ensures proper cleanup of the wrapped connection
func (lc *logConn) Close() error {
return lc.conn.Close()
}
// LocalAddr delegates to the underlying connection
// This provides access to local address information
func (lc *logConn) LocalAddr() net.Addr {
return lc.conn.LocalAddr()
}
// RemoteAddr delegates to the underlying connection
// This provides access to remote address information
func (lc *logConn) RemoteAddr() net.Addr {
return lc.conn.RemoteAddr()
}
// SetDeadline delegates to the underlying connection
// This maintains timeout functionality for the wrapped connection
func (lc *logConn) SetDeadline(t time.Time) error {
return lc.conn.SetDeadline(t)
}
// SetReadDeadline delegates to the underlying connection
// This maintains read timeout functionality for the wrapped connection
func (lc *logConn) SetReadDeadline(t time.Time) error {
return lc.conn.SetReadDeadline(t)
}
// SetWriteDeadline delegates to the underlying connection
// This maintains write timeout functionality for the wrapped connection
func (lc *logConn) SetWriteDeadline(t time.Time) error {
return lc.conn.SetWriteDeadline(t)
}
// Write performs writes with underlying net.Conn, ignore any errors happen.
// "ctrld run" command use this wrapper to report errors to "ctrld start".
// If no error occurred, "ctrld start" may finish before "ctrld run" attempt
// to close the connection, so ignore errors conservatively here, prevent
// un-necessary error "write to closed connection" flushed to ctrld log.
// This prevents log pollution when the parent process closes the connection prematurely
func (lc *logConn) Write(b []byte) (int, error) {
_, _ = lc.conn.Write(b)
return len(b), nil
}

172
cmd/cli/http_log.go Normal file
View File

@@ -0,0 +1,172 @@
package cli
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"sync"
)
// HTTP log server endpoint constants
const (
httpLogEndpointPing = "/ping"
httpLogEndpointLogs = "/logs"
httpLogEndpointExit = "/exit"
)
// httpLogClient sends logs to an HTTP server via POST requests.
// This replaces the logConn functionality with HTTP-based communication.
type httpLogClient struct {
baseURL string
client *http.Client
}
// newHTTPLogClient creates a new HTTP log client
func newHTTPLogClient(sockPath string) *httpLogClient {
return &httpLogClient{
baseURL: "http://unix",
client: &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
},
}
}
// Write sends log data to the HTTP server via POST request
func (hlc *httpLogClient) Write(b []byte) (int, error) {
// Send log data via HTTP POST to /logs endpoint
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointLogs, "text/plain", bytes.NewReader(b))
if err != nil {
// Ignore errors to prevent log pollution, just like the original logConn
return len(b), nil
}
resp.Body.Close()
return len(b), nil
}
// Ping tests if the HTTP log server is available
func (hlc *httpLogClient) Ping() error {
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointPing)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// Close sends exit signal to the HTTP server
func (hlc *httpLogClient) Close() error {
// Send exit signal via HTTP POST with empty body
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// GetLogs retrieves all collected logs from the HTTP server
func (hlc *httpLogClient) GetLogs() ([]byte, error) {
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointLogs)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNoContent {
return []byte{}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// httpLogServer starts an HTTP server listening on unix socket to collect logs from runCmd.
func httpLogServer(sockPath string, stopLogCh chan struct{}) error {
addr, err := net.ResolveUnixAddr("unix", sockPath)
if err != nil {
return fmt.Errorf("invalid log sock path: %w", err)
}
ln, err := net.ListenUnix("unix", addr)
if err != nil {
return fmt.Errorf("could not listen log socket: %w", err)
}
defer ln.Close()
// Create a log writer to store all logs
logWriter := newLogWriter()
// Use a sync.Once to ensure channel is only closed once
var channelClosed sync.Once
mux := http.NewServeMux()
mux.HandleFunc(httpLogEndpointPing, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc(httpLogEndpointLogs, func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
// POST /logs - Store log data
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
// Store log data in log writer
logWriter.Write(body)
w.WriteHeader(http.StatusOK)
case http.MethodGet:
// GET /logs - Retrieve all logs
// Get all logs from the log writer
logWriter.mu.Lock()
logs := logWriter.buf.Bytes()
logWriter.mu.Unlock()
if len(logs) == 0 {
w.WriteHeader(http.StatusNoContent)
return
}
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write(logs)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
mux.HandleFunc(httpLogEndpointExit, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Close the stop channel to signal completion (only once)
channelClosed.Do(func() {
close(stopLogCh)
})
w.WriteHeader(http.StatusOK)
})
server := &http.Server{Handler: mux}
return server.Serve(ln)
}

758
cmd/cli/http_log_test.go Normal file
View File

@@ -0,0 +1,758 @@
package cli
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
func TestHTTPLogServer(t *testing.T) {
// Create a temporary socket path
tmpDir := t.TempDir()
sockPath := filepath.Join(tmpDir, "test.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait a bit for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
t.Run("Ping endpoint", func(t *testing.T) {
resp, err := client.Get("http://unix" + httpLogEndpointPing)
if err != nil {
t.Fatalf("Failed to ping server: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
})
t.Run("Ping endpoint wrong method", func(t *testing.T) {
resp, err := client.Post("http://unix"+httpLogEndpointPing, "text/plain", bytes.NewReader([]byte("test")))
if err != nil {
t.Fatalf("Failed to send POST to ping: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405, got %d", resp.StatusCode)
}
})
t.Run("Log endpoint", func(t *testing.T) {
testLog := "test log message"
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(testLog)))
if err != nil {
t.Fatalf("Failed to send log: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check if log was stored by retrieving it
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer logsResp.Body.Close()
if logsResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
}
body, err := io.ReadAll(logsResp.Body)
if err != nil {
t.Fatalf("Failed to read logs: %v", err)
}
if !strings.Contains(string(body), testLog) {
t.Errorf("Expected log '%s' not found in stored logs", testLog)
}
})
t.Run("Log endpoint wrong method", func(t *testing.T) {
// Test unsupported method (PUT) on /logs endpoint
req, err := http.NewRequest("PUT", "http://unix"+httpLogEndpointLogs, bytes.NewReader([]byte("test")))
if err != nil {
t.Fatalf("Failed to create PUT request: %v", err)
}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to send PUT to logs: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405, got %d", resp.StatusCode)
}
})
t.Run("Exit endpoint", func(t *testing.T) {
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
if err != nil {
t.Fatalf("Failed to send exit: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check if channel is closed by trying to read from it
select {
case _, ok := <-stopLogCh:
if ok {
t.Error("Expected channel to be closed, but it's still open")
}
case <-time.After(1 * time.Second):
t.Error("Timeout waiting for channel closure")
}
})
t.Run("Exit endpoint wrong method", func(t *testing.T) {
resp, err := client.Get("http://unix" + httpLogEndpointExit)
if err != nil {
t.Fatalf("Failed to send GET to exit: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405, got %d", resp.StatusCode)
}
})
t.Run("Multiple log messages", func(t *testing.T) {
logs := []string{"log1", "log2", "log3"}
for _, log := range logs {
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
if err != nil {
t.Fatalf("Failed to send log '%s': %v", log, err)
}
resp.Body.Close()
}
// Check if all logs were stored by retrieving them
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer logsResp.Body.Close()
if logsResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
}
body, err := io.ReadAll(logsResp.Body)
if err != nil {
t.Fatalf("Failed to read logs: %v", err)
}
logContent := string(body)
for i, expectedLog := range logs {
if !strings.Contains(logContent, expectedLog) {
t.Errorf("Log %d: expected '%s' not found in stored logs", i, expectedLog)
}
}
})
t.Run("Large log message", func(t *testing.T) {
largeLog := strings.Repeat("a", 1024*10) // 10KB log message
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(largeLog)))
if err != nil {
t.Fatalf("Failed to send large log: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check if large log was stored by retrieving it
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer logsResp.Body.Close()
if logsResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
}
body, err := io.ReadAll(logsResp.Body)
if err != nil {
t.Fatalf("Failed to read logs: %v", err)
}
if !strings.Contains(string(body), largeLog) {
t.Error("Large log message was not stored correctly")
}
})
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogServerInvalidSocketPath(t *testing.T) {
// Test with invalid socket path
invalidPath := "/invalid/path/that/does/not/exist.sock"
stopLogCh := make(chan struct{})
err := httpLogServer(invalidPath, stopLogCh)
if err == nil {
t.Error("Expected error for invalid socket path")
}
if !strings.Contains(err.Error(), "could not listen log socket") {
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
}
}
func TestHTTPLogServerSocketInUse(t *testing.T) {
// Create a temporary socket path
tmpDir := t.TempDir()
sockPath := filepath.Join(tmpDir, "test.sock")
// Create the first server
stopLogCh1 := make(chan struct{})
serverErr1 := make(chan error, 1)
go func() {
serverErr1 <- httpLogServer(sockPath, stopLogCh1)
}()
// Wait for first server to start
time.Sleep(100 * time.Millisecond)
// Try to create a second server on the same socket
stopLogCh2 := make(chan struct{})
err := httpLogServer(sockPath, stopLogCh2)
if err == nil {
t.Error("Expected error when socket is already in use")
}
if !strings.Contains(err.Error(), "could not listen log socket") {
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
}
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogServerConcurrentRequests(t *testing.T) {
// Create a temporary socket path
tmpDir := t.TempDir()
sockPath := filepath.Join(tmpDir, "test.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
// Send concurrent requests
numRequests := 10
done := make(chan bool, numRequests)
for i := 0; i < numRequests; i++ {
go func(i int) {
defer func() { done <- true }()
logMsg := fmt.Sprintf("concurrent log %d", i)
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
if err != nil {
t.Errorf("Failed to send concurrent log %d: %v", i, err)
return
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for request %d, got %d", i, resp.StatusCode)
}
}(i)
}
// Wait for all requests to complete
for i := 0; i < numRequests; i++ {
select {
case <-done:
// Request completed
case <-time.After(5 * time.Second):
t.Errorf("Timeout waiting for concurrent request %d", i)
}
}
// Check if all logs were stored by retrieving them
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer logsResp.Body.Close()
if logsResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
}
body, err := io.ReadAll(logsResp.Body)
if err != nil {
t.Fatalf("Failed to read logs: %v", err)
}
logContent := string(body)
// Verify all logs were stored
for i := 0; i < numRequests; i++ {
expectedLog := fmt.Sprintf("concurrent log %d", i)
if !strings.Contains(logContent, expectedLog) {
t.Errorf("Log '%s' was not stored", expectedLog)
}
}
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogServerErrorHandling(t *testing.T) {
// Create a temporary socket path
tmpDir := t.TempDir()
sockPath := filepath.Join(tmpDir, "test.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
t.Run("Invalid request body", func(t *testing.T) {
// Test with malformed request - this will fail at HTTP level, not server level
// The server will return 400 Bad Request for invalid body
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", strings.NewReader(""))
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()
// Empty body should still be processed successfully
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
})
// Clean up
os.Remove(sockPath)
}
func BenchmarkHTTPLogServer(b *testing.B) {
// Create a temporary socket path
tmpDir := b.TempDir()
sockPath := filepath.Join(tmpDir, "bench.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
go func() {
httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
// Benchmark log sending
b.ResetTimer()
for i := 0; i < b.N; i++ {
logMsg := fmt.Sprintf("benchmark log %d", i)
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
if err != nil {
b.Fatalf("Failed to send log: %v", err)
}
resp.Body.Close()
}
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogClient(t *testing.T) {
// Create a temporary socket path
tmpDir := t.TempDir()
sockPath := filepath.Join(tmpDir, "test.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP log client
client := newHTTPLogClient(sockPath)
t.Run("Ping server", func(t *testing.T) {
err := client.Ping()
if err != nil {
t.Errorf("Ping failed: %v", err)
}
})
t.Run("Write logs", func(t *testing.T) {
testLog := "test log message from client"
n, err := client.Write([]byte(testLog))
if err != nil {
t.Errorf("Write failed: %v", err)
}
if n != len(testLog) {
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
}
// Check if log was stored by retrieving it
logs, err := client.GetLogs()
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
if !strings.Contains(string(logs), testLog) {
t.Errorf("Expected log '%s' not found in stored logs", testLog)
}
})
t.Run("Close client", func(t *testing.T) {
err := client.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
// Check if channel is closed (signaling completion)
select {
case _, ok := <-stopLogCh:
if ok {
t.Error("Expected channel to be closed, but it's still open")
}
case <-time.After(1 * time.Second):
t.Error("Timeout waiting for channel closure")
}
})
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogClientServerUnavailable(t *testing.T) {
// Create client with non-existent socket
sockPath := "/non/existent/socket.sock"
client := newHTTPLogClient(sockPath)
t.Run("Ping unavailable server", func(t *testing.T) {
err := client.Ping()
if err == nil {
t.Error("Expected ping to fail for unavailable server")
}
})
t.Run("Write to unavailable server", func(t *testing.T) {
testLog := "test log message"
n, err := client.Write([]byte(testLog))
if err != nil {
t.Errorf("Write should not return error (ignores errors): %v", err)
}
if n != len(testLog) {
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
}
})
t.Run("Close unavailable server", func(t *testing.T) {
err := client.Close()
if err == nil {
t.Error("Expected close to fail for unavailable server")
}
})
}
func BenchmarkHTTPLogClient(b *testing.B) {
// Create a temporary socket path
tmpDir := b.TempDir()
sockPath := filepath.Join(tmpDir, "bench.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
go func() {
httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP log client
client := newHTTPLogClient(sockPath)
// Benchmark client writes
b.ResetTimer()
for i := 0; i < b.N; i++ {
logMsg := fmt.Sprintf("benchmark write %d", i)
client.Write([]byte(logMsg))
}
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogServerWithLogWriter(t *testing.T) {
// Create a temporary socket path
tmpDir := t.TempDir()
sockPath := filepath.Join(tmpDir, "test.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait a bit for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
t.Run("Store and retrieve logs", func(t *testing.T) {
// Send multiple log messages
logs := []string{"log message 1", "log message 2", "log message 3"}
for _, log := range logs {
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
if err != nil {
t.Fatalf("Failed to send log '%s': %v", log, err)
}
resp.Body.Close()
}
// Retrieve all logs
resp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read logs response: %v", err)
}
logContent := string(body)
for _, log := range logs {
if !strings.Contains(logContent, log) {
t.Errorf("Expected log '%s' not found in retrieved logs", log)
}
}
})
t.Run("Empty logs endpoint", func(t *testing.T) {
// Create a new server for this test
tmpDir2 := t.TempDir()
sockPath2 := filepath.Join(tmpDir2, "test2.sock")
stopLogCh2 := make(chan struct{})
go func() {
httpLogServer(sockPath2, stopLogCh2)
}()
time.Sleep(100 * time.Millisecond)
client2 := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath2)
},
},
}
resp, err := client2.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
t.Errorf("Expected status 204, got %d", resp.StatusCode)
}
os.Remove(sockPath2)
})
t.Run("Channel closure on exit", func(t *testing.T) {
// Send exit signal
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
if err != nil {
t.Fatalf("Failed to send exit: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check if channel is closed by trying to read from it
select {
case _, ok := <-stopLogCh:
if ok {
t.Error("Expected channel to be closed, but it's still open")
}
case <-time.After(1 * time.Second):
t.Error("Timeout waiting for channel closure")
}
})
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogClientGetLogs(t *testing.T) {
// Create a temporary socket path
tmpDir := t.TempDir()
sockPath := filepath.Join(tmpDir, "test.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
go func() {
httpLogServer(sockPath, stopLogCh)
}()
// Wait a bit for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP log client
client := newHTTPLogClient(sockPath)
t.Run("Get logs from client", func(t *testing.T) {
// Send some logs
testLogs := []string{"client log 1", "client log 2", "client log 3"}
for _, log := range testLogs {
client.Write([]byte(log + "\n"))
}
// Retrieve logs using client method
logs, err := client.GetLogs()
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
logContent := string(logs)
for _, log := range testLogs {
if !strings.Contains(logContent, log) {
t.Errorf("Expected log '%s' not found in retrieved logs", log)
}
}
})
t.Run("Get empty logs", func(t *testing.T) {
// Create a new client for empty logs test
tmpDir2 := t.TempDir()
sockPath2 := filepath.Join(tmpDir2, "test2.sock")
stopLogCh2 := make(chan struct{})
go func() {
httpLogServer(sockPath2, stopLogCh2)
}()
time.Sleep(100 * time.Millisecond)
client2 := newHTTPLogClient(sockPath2)
logs, err := client2.GetLogs()
if err != nil {
t.Fatalf("Failed to get empty logs: %v", err)
}
if len(logs) != 0 {
t.Errorf("Expected empty logs, got %d bytes", len(logs))
}
os.Remove(sockPath2)
})
// Clean up
os.Remove(sockPath)
}

View File

@@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"math/rand"
"net"
@@ -91,7 +92,7 @@ type prog struct {
apiReloadCh chan *ctrld.Config
apiForceReloadCh chan struct{}
apiForceReloadGroup singleflight.Group
logConn net.Conn
logConn io.WriteCloser
cs *controlServer
logger atomic.Pointer[ctrld.Logger]
csSetDnsDone chan struct{}
@@ -1148,28 +1149,6 @@ func randomPort() int {
return n
}
// runLogServer starts a unix listener, use by startCmd to gather log from runCmd.
func runLogServer(sockPath string) net.Conn {
addr, err := net.ResolveUnixAddr("unix", sockPath)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Invalid log sock path")
return nil
}
ln, err := net.ListenUnix("unix", addr)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Could not listen log socket")
return nil
}
defer ln.Close()
server, err := ln.Accept()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Could not accept connection")
return nil
}
return server
}
func errAddrInUse(err error) bool {
var opErr *net.OpError
if errors.As(err, &opErr) {

4
log.go
View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"os"
"time"
"go.uber.org/zap"
@@ -244,7 +245,8 @@ func (l *Logger) GetLogger() *Logger {
// Write implements io.Writer to allow direct writing to the logger
func (l *Logger) Write(p []byte) (n int, err error) {
l.Info().Msg(string(p))
stdoutSyncer := zapcore.AddSync(os.Stdout)
stdoutSyncer.Write(p)
return len(p), nil
}