From 56f8113bb09c7dff5c7b2f4ae9fa000ae64e24f7 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 12 Sep 2025 18:22:02 +0700 Subject: [PATCH] refactor: replace Unix socket log communication with HTTP-based system Replace the legacy raw byte-stream log protocol that `ctrld start` and `ctrld run` spoke over a Unix socket with HTTP requests served over that same Unix socket, for better reliability and maintainability. Benefits: - More reliable communication protocol using standard HTTP - Better error handling and connection management - Cleaner separation of concerns with dedicated endpoints - Easier to test and debug with HTTP-based communication - More maintainable code with proper abstraction layers This change maintains backward compatibility while providing a more robust foundation for inter-process communication between ctrld commands. --- cmd/cli/cli.go | 23 +- cmd/cli/commands_service_start.go | 64 ++- cmd/cli/conn.go | 67 --- cmd/cli/http_log.go | 172 +++++++ cmd/cli/http_log_test.go | 758 ++++++++++++++++++++++++++++++ cmd/cli/prog.go | 25 +- log.go | 4 +- 7 files changed, 976 insertions(+), 137 deletions(-) delete mode 100644 cmd/cli/conn.go create mode 100644 cmd/cli/http_log.go create mode 100644 cmd/cli/http_log_test.go diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index c04518f..effb514 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -234,22 +234,21 @@ func run(appCallback *AppCallback, stopCh chan struct{}) { sockDir = d } sockPath := filepath.Join(sockDir, ctrldLogUnixSock) - if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil { - if conn, err := net.Dial(addr.Network(), addr.String()); err == nil { - lc := &logConn{conn: conn} - consoleWriter = newHumanReadableZapCore(io.MultiWriter(os.Stdout, lc), consoleWriterLevel) - p.logConn = lc - } else { - if !errors.Is(err, os.ErrNotExist) { - p.Warn().Err(err).Msg("Unable to create log ipc connection") - } + hlc := newHTTPLogClient(sockPath) + + // Test if HTTP log server is available + if err := hlc.Ping(); err != nil { + if 
!errConnectionRefused(err) { + p.Warn().Err(err).Msg("Unable to ping log server") } } else { - p.Warn().Err(err).Msgf("Unable to resolve socket address: %s", sockPath) + // Server is available, use HTTP log client + consoleWriter = newHumanReadableZapCore(io.MultiWriter(os.Stdout, hlc), consoleWriterLevel) + p.logConn = hlc } notifyExitToLogServer := func() { if p.logConn != nil { - _, _ = p.logConn.Write([]byte(msgExit)) + _ = p.logConn.Close() } } @@ -1354,7 +1353,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, notifyFunc func(), fatal bool) ( break } - logMsg(il.Info().Err(err), n, "error listening on address: %s", addr) + logMsg(il.Debug().Err(err), n, "error listening on address: %s", addr) if !check.IP && !check.Port { if fatal { diff --git a/cmd/cli/commands_service_start.go b/cmd/cli/commands_service_start.go index e206e0b..f8a9d98 100644 --- a/cmd/cli/commands_service_start.go +++ b/cmd/cli/commands_service_start.go @@ -10,7 +10,6 @@ import ( "net/http" "os" "path/filepath" - "strings" "time" "github.com/kardianos/service" @@ -104,11 +103,10 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error { writeDefaultConfig := !noConfigStart && configBase64 == "" logServerStarted := make(chan struct{}) - // A buffer channel to gather log output from runCmd and report - // to user in case self-check process failed. 
- runCmdLogCh := make(chan string, 256) + stopLogCh := make(chan struct{}) ud, err := userHomeDir() sockDir := ud + var logServerSocketPath string if err != nil { logger.Warn().Err(err).Msg("Failed to get user home directory") logger.Warn().Msg("Log server did not start") @@ -122,29 +120,17 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error { if d, err := socketDir(); err == nil { sockDir = d } - sockPath := filepath.Join(sockDir, ctrldLogUnixSock) - _ = os.Remove(sockPath) + logServerSocketPath = filepath.Join(sockDir, ctrldLogUnixSock) + _ = os.Remove(logServerSocketPath) go func() { - defer func() { - close(runCmdLogCh) - _ = os.Remove(sockPath) - }() + defer os.Remove(logServerSocketPath) + close(logServerStarted) - if conn := runLogServer(sockPath); conn != nil { - // Enough buffer for log message, we don't produce - // such long log message, but just in case. - buf := make([]byte, 1024) - for { - n, err := conn.Read(buf) - if err != nil { - return - } - msg := string(buf[:n]) - if _, _, found := strings.Cut(msg, msgExit); found { - cancel() - } - runCmdLogCh <- msg - } + + // Start HTTP log server + if err := httpLogServer(logServerSocketPath, stopLogCh); err != nil && err != http.ErrServerClosed { + logger.Warn().Err(err).Msg("Failed to serve HTTP log server") + return } }() } @@ -270,19 +256,29 @@ func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error { case ok && status == service.StatusRunning: logger.Notice().Msg("Service started") default: - marker := bytes.Repeat([]byte("="), 32) + marker := append(bytes.Repeat([]byte("="), 32), '\n') // If ctrld service is not running, emitting log obtained from ctrld process. 
if status != service.StatusRunning || ctx.Err() != nil { logger.Error().Msg("Ctrld service may not have started due to an error or misconfiguration, service log:") _, _ = logger.Write(marker) - haveLog := false - for msg := range runCmdLogCh { - _, _ = logger.Write([]byte(strings.ReplaceAll(msg, msgExit, ""))) - haveLog = true - } - // If we're unable to get log from "ctrld run", notice users about it. - if !haveLog { - logger.Write([]byte(`"`)) + + // Wait for log collection to complete + <-stopLogCh + + // Retrieve logs from HTTP server if available + if logServerSocketPath != "" { + hlc := newHTTPLogClient(logServerSocketPath) + logs, err := hlc.GetLogs() + if err != nil { + logger.Warn().Err(err).Msg("Failed to get logs from HTTP log server") + } + if len(logs) == 0 { + logger.Write([]byte(``)) + } else { + logger.Write(logs) + } + } else { + logger.Write([]byte(``)) } } // Report any error if occurred. diff --git a/cmd/cli/conn.go b/cmd/cli/conn.go deleted file mode 100644 index bdad00b..0000000 --- a/cmd/cli/conn.go +++ /dev/null @@ -1,67 +0,0 @@ -package cli - -import ( - "net" - "time" -) - -// logConn wraps a net.Conn, override the Write behavior. -// runCmd uses this wrapper, so as long as startCmd finished, -// ctrld log won't be flushed with un-necessary write errors. 
-// This prevents log pollution when the parent process closes the connection -type logConn struct { - conn net.Conn -} - -// Read delegates to the underlying connection -// This maintains normal read behavior for the wrapped connection -func (lc *logConn) Read(b []byte) (n int, err error) { - return lc.conn.Read(b) -} - -// Close delegates to the underlying connection -// This ensures proper cleanup of the wrapped connection -func (lc *logConn) Close() error { - return lc.conn.Close() -} - -// LocalAddr delegates to the underlying connection -// This provides access to local address information -func (lc *logConn) LocalAddr() net.Addr { - return lc.conn.LocalAddr() -} - -// RemoteAddr delegates to the underlying connection -// This provides access to remote address information -func (lc *logConn) RemoteAddr() net.Addr { - return lc.conn.RemoteAddr() -} - -// SetDeadline delegates to the underlying connection -// This maintains timeout functionality for the wrapped connection -func (lc *logConn) SetDeadline(t time.Time) error { - return lc.conn.SetDeadline(t) -} - -// SetReadDeadline delegates to the underlying connection -// This maintains read timeout functionality for the wrapped connection -func (lc *logConn) SetReadDeadline(t time.Time) error { - return lc.conn.SetReadDeadline(t) -} - -// SetWriteDeadline delegates to the underlying connection -// This maintains write timeout functionality for the wrapped connection -func (lc *logConn) SetWriteDeadline(t time.Time) error { - return lc.conn.SetWriteDeadline(t) -} - -// Write performs writes with underlying net.Conn, ignore any errors happen. -// "ctrld run" command use this wrapper to report errors to "ctrld start". -// If no error occurred, "ctrld start" may finish before "ctrld run" attempt -// to close the connection, so ignore errors conservatively here, prevent -// un-necessary error "write to closed connection" flushed to ctrld log. 
-// This prevents log pollution when the parent process closes the connection prematurely -func (lc *logConn) Write(b []byte) (int, error) { - _, _ = lc.conn.Write(b) - return len(b), nil -} diff --git a/cmd/cli/http_log.go b/cmd/cli/http_log.go new file mode 100644 index 0000000..c794cf0 --- /dev/null +++ b/cmd/cli/http_log.go @@ -0,0 +1,172 @@ +package cli + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "sync" +) + +// HTTP log server endpoint constants +const ( + httpLogEndpointPing = "/ping" + httpLogEndpointLogs = "/logs" + httpLogEndpointExit = "/exit" +) + +// httpLogClient sends logs to an HTTP server via POST requests. +// This replaces the logConn functionality with HTTP-based communication. +type httpLogClient struct { + baseURL string + client *http.Client +} + +// newHTTPLogClient creates a new HTTP log client +func newHTTPLogClient(sockPath string) *httpLogClient { + return &httpLogClient{ + baseURL: "http://unix", + client: &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + }, + } +} + +// Write sends log data to the HTTP server via POST request +func (hlc *httpLogClient) Write(b []byte) (int, error) { + // Send log data via HTTP POST to /logs endpoint + resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointLogs, "text/plain", bytes.NewReader(b)) + if err != nil { + // Ignore errors to prevent log pollution, just like the original logConn + return len(b), nil + } + resp.Body.Close() + return len(b), nil +} + +// Ping tests if the HTTP log server is available +func (hlc *httpLogClient) Ping() error { + resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointPing) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +// Close sends exit signal to the HTTP server +func (hlc *httpLogClient) Close() error { + // Send exit signal via HTTP POST with empty body + resp, err := 
hlc.client.Post(hlc.baseURL+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{})) + if err != nil { + return err + } + resp.Body.Close() + return nil +} + +// GetLogs retrieves all collected logs from the HTTP server +func (hlc *httpLogClient) GetLogs() ([]byte, error) { + resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointLogs) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNoContent { + return []byte{}, nil + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// httpLogServer starts an HTTP server listening on unix socket to collect logs from runCmd. +func httpLogServer(sockPath string, stopLogCh chan struct{}) error { + addr, err := net.ResolveUnixAddr("unix", sockPath) + if err != nil { + return fmt.Errorf("invalid log sock path: %w", err) + } + + ln, err := net.ListenUnix("unix", addr) + if err != nil { + return fmt.Errorf("could not listen log socket: %w", err) + } + defer ln.Close() + + // Create a log writer to store all logs + logWriter := newLogWriter() + + // Use a sync.Once to ensure channel is only closed once + var channelClosed sync.Once + + mux := http.NewServeMux() + mux.HandleFunc(httpLogEndpointPing, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + w.WriteHeader(http.StatusOK) + }) + + mux.HandleFunc(httpLogEndpointLogs, func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + // POST /logs - Store log data + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + + // Store log data in log writer + logWriter.Write(body) + + w.WriteHeader(http.StatusOK) + + case http.MethodGet: + // GET /logs - Retrieve all logs + // Get all logs from the log 
writer + logWriter.mu.Lock() + logs := logWriter.buf.Bytes() + logWriter.mu.Unlock() + + if len(logs) == 0 { + w.WriteHeader(http.StatusNoContent) + return + } + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write(logs) + + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + + mux.HandleFunc(httpLogEndpointExit, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Close the stop channel to signal completion (only once) + channelClosed.Do(func() { + close(stopLogCh) + }) + w.WriteHeader(http.StatusOK) + }) + + server := &http.Server{Handler: mux} + return server.Serve(ln) +} diff --git a/cmd/cli/http_log_test.go b/cmd/cli/http_log_test.go new file mode 100644 index 0000000..495f09e --- /dev/null +++ b/cmd/cli/http_log_test.go @@ -0,0 +1,758 @@ +package cli + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestHTTPLogServer(t *testing.T) { + // Create a temporary socket path + tmpDir := t.TempDir() + sockPath := filepath.Join(tmpDir, "test.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait a bit for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + t.Run("Ping endpoint", func(t *testing.T) { + resp, err := client.Get("http://unix" + httpLogEndpointPing) + if err != nil { + t.Fatalf("Failed to ping server: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + 
t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + }) + + t.Run("Ping endpoint wrong method", func(t *testing.T) { + resp, err := client.Post("http://unix"+httpLogEndpointPing, "text/plain", bytes.NewReader([]byte("test"))) + if err != nil { + t.Fatalf("Failed to send POST to ping: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405, got %d", resp.StatusCode) + } + }) + + t.Run("Log endpoint", func(t *testing.T) { + testLog := "test log message" + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(testLog))) + if err != nil { + t.Fatalf("Failed to send log: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check if log was stored by retrieving it + logsResp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer logsResp.Body.Close() + + if logsResp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode) + } + + body, err := io.ReadAll(logsResp.Body) + if err != nil { + t.Fatalf("Failed to read logs: %v", err) + } + + if !strings.Contains(string(body), testLog) { + t.Errorf("Expected log '%s' not found in stored logs", testLog) + } + }) + + t.Run("Log endpoint wrong method", func(t *testing.T) { + // Test unsupported method (PUT) on /logs endpoint + req, err := http.NewRequest("PUT", "http://unix"+httpLogEndpointLogs, bytes.NewReader([]byte("test"))) + if err != nil { + t.Fatalf("Failed to create PUT request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to send PUT to logs: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405, got %d", resp.StatusCode) + } + }) + + t.Run("Exit endpoint", func(t *testing.T) 
{ + resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{})) + if err != nil { + t.Fatalf("Failed to send exit: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check if channel is closed by trying to read from it + select { + case _, ok := <-stopLogCh: + if ok { + t.Error("Expected channel to be closed, but it's still open") + } + case <-time.After(1 * time.Second): + t.Error("Timeout waiting for channel closure") + } + }) + + t.Run("Exit endpoint wrong method", func(t *testing.T) { + resp, err := client.Get("http://unix" + httpLogEndpointExit) + if err != nil { + t.Fatalf("Failed to send GET to exit: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Expected status 405, got %d", resp.StatusCode) + } + }) + + t.Run("Multiple log messages", func(t *testing.T) { + logs := []string{"log1", "log2", "log3"} + + for _, log := range logs { + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n"))) + if err != nil { + t.Fatalf("Failed to send log '%s': %v", log, err) + } + resp.Body.Close() + } + + // Check if all logs were stored by retrieving them + logsResp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer logsResp.Body.Close() + + if logsResp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode) + } + + body, err := io.ReadAll(logsResp.Body) + if err != nil { + t.Fatalf("Failed to read logs: %v", err) + } + + logContent := string(body) + for i, expectedLog := range logs { + if !strings.Contains(logContent, expectedLog) { + t.Errorf("Log %d: expected '%s' not found in stored logs", i, expectedLog) + } + } + }) + + t.Run("Large log message", func(t *testing.T) { + largeLog := 
strings.Repeat("a", 1024*10) // 10KB log message + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(largeLog))) + if err != nil { + t.Fatalf("Failed to send large log: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check if large log was stored by retrieving it + logsResp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer logsResp.Body.Close() + + if logsResp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode) + } + + body, err := io.ReadAll(logsResp.Body) + if err != nil { + t.Fatalf("Failed to read logs: %v", err) + } + + if !strings.Contains(string(body), largeLog) { + t.Error("Large log message was not stored correctly") + } + }) + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogServerInvalidSocketPath(t *testing.T) { + // Test with invalid socket path + invalidPath := "/invalid/path/that/does/not/exist.sock" + stopLogCh := make(chan struct{}) + + err := httpLogServer(invalidPath, stopLogCh) + if err == nil { + t.Error("Expected error for invalid socket path") + } + + if !strings.Contains(err.Error(), "could not listen log socket") { + t.Errorf("Expected 'could not listen log socket' error, got: %v", err) + } +} + +func TestHTTPLogServerSocketInUse(t *testing.T) { + // Create a temporary socket path + tmpDir := t.TempDir() + sockPath := filepath.Join(tmpDir, "test.sock") + + // Create the first server + stopLogCh1 := make(chan struct{}) + serverErr1 := make(chan error, 1) + go func() { + serverErr1 <- httpLogServer(sockPath, stopLogCh1) + }() + + // Wait for first server to start + time.Sleep(100 * time.Millisecond) + + // Try to create a second server on the same socket + stopLogCh2 := make(chan struct{}) + err := httpLogServer(sockPath, stopLogCh2) + if err == nil { + 
t.Error("Expected error when socket is already in use") + } + + if !strings.Contains(err.Error(), "could not listen log socket") { + t.Errorf("Expected 'could not listen log socket' error, got: %v", err) + } + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogServerConcurrentRequests(t *testing.T) { + // Create a temporary socket path + tmpDir := t.TempDir() + sockPath := filepath.Join(tmpDir, "test.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + // Send concurrent requests + numRequests := 10 + done := make(chan bool, numRequests) + + for i := 0; i < numRequests; i++ { + go func(i int) { + defer func() { done <- true }() + + logMsg := fmt.Sprintf("concurrent log %d", i) + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg))) + if err != nil { + t.Errorf("Failed to send concurrent log %d: %v", i, err) + return + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200 for request %d, got %d", i, resp.StatusCode) + } + }(i) + } + + // Wait for all requests to complete + for i := 0; i < numRequests; i++ { + select { + case <-done: + // Request completed + case <-time.After(5 * time.Second): + t.Errorf("Timeout waiting for concurrent request %d", i) + } + } + + // Check if all logs were stored by retrieving them + logsResp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer logsResp.Body.Close() + + if logsResp.StatusCode != http.StatusOK { + 
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode) + } + + body, err := io.ReadAll(logsResp.Body) + if err != nil { + t.Fatalf("Failed to read logs: %v", err) + } + + logContent := string(body) + // Verify all logs were stored + for i := 0; i < numRequests; i++ { + expectedLog := fmt.Sprintf("concurrent log %d", i) + if !strings.Contains(logContent, expectedLog) { + t.Errorf("Log '%s' was not stored", expectedLog) + } + } + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogServerErrorHandling(t *testing.T) { + // Create a temporary socket path + tmpDir := t.TempDir() + sockPath := filepath.Join(tmpDir, "test.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + t.Run("Invalid request body", func(t *testing.T) { + // Post an empty body: this is a well-formed request, not a malformed one + // The handler reads zero bytes and is expected to answer 200 OK, not 400 + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", strings.NewReader("")) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + // Empty body should still be processed successfully + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + }) + + // Clean up + os.Remove(sockPath) +} + +func BenchmarkHTTPLogServer(b *testing.B) { + // Create a temporary socket path + tmpDir := b.TempDir() + sockPath := filepath.Join(tmpDir, "bench.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a 
goroutine + go func() { + httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + // Benchmark log sending + b.ResetTimer() + for i := 0; i < b.N; i++ { + logMsg := fmt.Sprintf("benchmark log %d", i) + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg))) + if err != nil { + b.Fatalf("Failed to send log: %v", err) + } + resp.Body.Close() + } + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogClient(t *testing.T) { + // Create a temporary socket path + tmpDir := t.TempDir() + sockPath := filepath.Join(tmpDir, "test.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP log client + client := newHTTPLogClient(sockPath) + + t.Run("Ping server", func(t *testing.T) { + err := client.Ping() + if err != nil { + t.Errorf("Ping failed: %v", err) + } + }) + + t.Run("Write logs", func(t *testing.T) { + testLog := "test log message from client" + n, err := client.Write([]byte(testLog)) + if err != nil { + t.Errorf("Write failed: %v", err) + } + if n != len(testLog) { + t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n) + } + + // Check if log was stored by retrieving it + logs, err := client.GetLogs() + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + + if !strings.Contains(string(logs), testLog) { + t.Errorf("Expected log '%s' not found in stored logs", testLog) + } + }) + + t.Run("Close client", func(t *testing.T) { + err := client.Close() + if err != nil { + 
t.Errorf("Close failed: %v", err) + } + + // Check if channel is closed (signaling completion) + select { + case _, ok := <-stopLogCh: + if ok { + t.Error("Expected channel to be closed, but it's still open") + } + case <-time.After(1 * time.Second): + t.Error("Timeout waiting for channel closure") + } + }) + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogClientServerUnavailable(t *testing.T) { + // Create client with non-existent socket + sockPath := "/non/existent/socket.sock" + client := newHTTPLogClient(sockPath) + + t.Run("Ping unavailable server", func(t *testing.T) { + err := client.Ping() + if err == nil { + t.Error("Expected ping to fail for unavailable server") + } + }) + + t.Run("Write to unavailable server", func(t *testing.T) { + testLog := "test log message" + n, err := client.Write([]byte(testLog)) + if err != nil { + t.Errorf("Write should not return error (ignores errors): %v", err) + } + if n != len(testLog) { + t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n) + } + }) + + t.Run("Close unavailable server", func(t *testing.T) { + err := client.Close() + if err == nil { + t.Error("Expected close to fail for unavailable server") + } + }) +} + +func BenchmarkHTTPLogClient(b *testing.B) { + // Create a temporary socket path + tmpDir := b.TempDir() + sockPath := filepath.Join(tmpDir, "bench.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + go func() { + httpLogServer(sockPath, stopLogCh) + }() + + // Wait for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP log client + client := newHTTPLogClient(sockPath) + + // Benchmark client writes + b.ResetTimer() + for i := 0; i < b.N; i++ { + logMsg := fmt.Sprintf("benchmark write %d", i) + client.Write([]byte(logMsg)) + } + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogServerWithLogWriter(t *testing.T) { + // Create a temporary socket path + tmpDir := t.TempDir() + sockPath := 
filepath.Join(tmpDir, "test.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + serverErr := make(chan error, 1) + go func() { + serverErr <- httpLogServer(sockPath, stopLogCh) + }() + + // Wait a bit for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath) + }, + }, + } + + t.Run("Store and retrieve logs", func(t *testing.T) { + // Send multiple log messages + logs := []string{"log message 1", "log message 2", "log message 3"} + + for _, log := range logs { + resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n"))) + if err != nil { + t.Fatalf("Failed to send log '%s': %v", log, err) + } + resp.Body.Close() + } + + // Retrieve all logs + resp, err := client.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read logs response: %v", err) + } + + logContent := string(body) + for _, log := range logs { + if !strings.Contains(logContent, log) { + t.Errorf("Expected log '%s' not found in retrieved logs", log) + } + } + }) + + t.Run("Empty logs endpoint", func(t *testing.T) { + // Create a new server for this test + tmpDir2 := t.TempDir() + sockPath2 := filepath.Join(tmpDir2, "test2.sock") + stopLogCh2 := make(chan struct{}) + + go func() { + httpLogServer(sockPath2, stopLogCh2) + }() + time.Sleep(100 * time.Millisecond) + + client2 := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", sockPath2) + }, 
+ }, + } + + resp, err := client2.Get("http://unix" + httpLogEndpointLogs) + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + t.Errorf("Expected status 204, got %d", resp.StatusCode) + } + + os.Remove(sockPath2) + }) + + t.Run("Channel closure on exit", func(t *testing.T) { + // Send exit signal + resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{})) + if err != nil { + t.Fatalf("Failed to send exit: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + // Check if channel is closed by trying to read from it + select { + case _, ok := <-stopLogCh: + if ok { + t.Error("Expected channel to be closed, but it's still open") + } + case <-time.After(1 * time.Second): + t.Error("Timeout waiting for channel closure") + } + }) + + // Clean up + os.Remove(sockPath) +} + +func TestHTTPLogClientGetLogs(t *testing.T) { + // Create a temporary socket path + tmpDir := t.TempDir() + sockPath := filepath.Join(tmpDir, "test.sock") + + // Create log channel + stopLogCh := make(chan struct{}) + + // Start HTTP log server in a goroutine + go func() { + httpLogServer(sockPath, stopLogCh) + }() + + // Wait a bit for server to start + time.Sleep(100 * time.Millisecond) + + // Create HTTP log client + client := newHTTPLogClient(sockPath) + + t.Run("Get logs from client", func(t *testing.T) { + // Send some logs + testLogs := []string{"client log 1", "client log 2", "client log 3"} + for _, log := range testLogs { + client.Write([]byte(log + "\n")) + } + + // Retrieve logs using client method + logs, err := client.GetLogs() + if err != nil { + t.Fatalf("Failed to get logs: %v", err) + } + + logContent := string(logs) + for _, log := range testLogs { + if !strings.Contains(logContent, log) { + t.Errorf("Expected log '%s' not found in retrieved logs", log) + } + } + }) + 
+ t.Run("Get empty logs", func(t *testing.T) { + // Create a new client for empty logs test + tmpDir2 := t.TempDir() + sockPath2 := filepath.Join(tmpDir2, "test2.sock") + stopLogCh2 := make(chan struct{}) + + go func() { + httpLogServer(sockPath2, stopLogCh2) + }() + time.Sleep(100 * time.Millisecond) + + client2 := newHTTPLogClient(sockPath2) + logs, err := client2.GetLogs() + if err != nil { + t.Fatalf("Failed to get empty logs: %v", err) + } + + if len(logs) != 0 { + t.Errorf("Expected empty logs, got %d bytes", len(logs)) + } + + os.Remove(sockPath2) + }) + + // Clean up + os.Remove(sockPath) +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 2a25626..89fd8e3 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "io/fs" "math/rand" "net" @@ -91,7 +92,7 @@ type prog struct { apiReloadCh chan *ctrld.Config apiForceReloadCh chan struct{} apiForceReloadGroup singleflight.Group - logConn net.Conn + logConn io.WriteCloser cs *controlServer logger atomic.Pointer[ctrld.Logger] csSetDnsDone chan struct{} @@ -1148,28 +1149,6 @@ func randomPort() int { return n } -// runLogServer starts a unix listener, use by startCmd to gather log from runCmd. 
-func runLogServer(sockPath string) net.Conn { - addr, err := net.ResolveUnixAddr("unix", sockPath) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Invalid log sock path") - return nil - } - ln, err := net.ListenUnix("unix", addr) - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Could not listen log socket") - return nil - } - defer ln.Close() - - server, err := ln.Accept() - if err != nil { - mainLog.Load().Warn().Err(err).Msg("Could not accept connection") - return nil - } - return server -} - func errAddrInUse(err error) bool { var opErr *net.OpError if errors.As(err, &opErr) { diff --git a/log.go b/log.go index a55157a..2f3a42f 100644 --- a/log.go +++ b/log.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "os" "time" "go.uber.org/zap" @@ -244,7 +245,8 @@ func (l *Logger) GetLogger() *Logger { // Write implements io.Writer to allow direct writing to the logger func (l *Logger) Write(p []byte) (n int, err error) { - l.Info().Msg(string(p)) + stdoutSyncer := zapcore.AddSync(os.Stdout) + stdoutSyncer.Write(p) return len(p), nil }