mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-25 23:30:41 +01:00
Compare commits
4 Commits
v1.5.0
...
release-br
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c0585b2e8 | ||
|
|
112d1cb5a9 | ||
|
|
bd9bb90dd4 | ||
|
|
82fc628bf3 |
64
README.md
64
README.md
@@ -100,7 +100,7 @@ docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
|
||||
|
||||
# Usage
|
||||
The cli is self documenting, so free free to run `--help` on any sub-command to get specific usages.
|
||||
The cli is self documenting, so feel free to run `--help` on any sub-command to get specific usages.
|
||||
|
||||
## Arguments
|
||||
```
|
||||
@@ -266,5 +266,67 @@ The above will start a foreground process and:
|
||||
- Excluding `*.company.int` and `very-secure.local` matching queries, that are forwarded to `10.0.10.1:53`
|
||||
- Write a debug log to `/path/to/log.log`
|
||||
|
||||
## DNS Intercept Mode
|
||||
When running `ctrld` alongside VPN software, DNS conflicts can cause intermittent failures, bypassed filtering, or configuration loops. DNS Intercept Mode prevents these issues by transparently capturing all DNS traffic on the system and routing it through `ctrld`, without modifying network adapter DNS settings.
|
||||
|
||||
### When to Use
|
||||
Enable DNS Intercept Mode if you:
|
||||
- Use corporate VPN software (F5, Cisco AnyConnect, Palo Alto GlobalProtect, Zscaler)
|
||||
- Run overlay networks like Tailscale or WireGuard
|
||||
- Experience random DNS failures when VPN connects/disconnects
|
||||
- See gaps in your Control D analytics when VPN is active
|
||||
- Have endpoint security software that also manages DNS
|
||||
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --intercept-mode dns --cd RESOLVER_ID_HERE
|
||||
```
|
||||
|
||||
macOS
|
||||
```shell
|
||||
sudo ctrld start --intercept-mode dns --cd RESOLVER_ID_HERE
|
||||
```
|
||||
|
||||
`--intercept-mode dns` automatically detects VPN internal domains and routes them to the VPN's DNS server, while Control D handles everything else.
|
||||
|
||||
To disable intercept mode on a service that already has it enabled:
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --intercept-mode off
|
||||
```
|
||||
|
||||
macOS
|
||||
```shell
|
||||
sudo ctrld start --intercept-mode off
|
||||
```
|
||||
|
||||
This removes the intercept rules and reverts to standard interface-based DNS configuration.
|
||||
|
||||
### Platform Support
|
||||
| Platform | Supported | Mechanism |
|
||||
|----------|-----------|-----------|
|
||||
| Windows | ✅ | NRPT (Name Resolution Policy Table) |
|
||||
| macOS | ✅ | pf (packet filter) redirect |
|
||||
| Linux | ❌ | Not currently supported |
|
||||
|
||||
### Features
|
||||
- **VPN split routing** — VPN-specific domains are automatically detected and forwarded to the VPN's DNS server
|
||||
- **Captive portal recovery** — Wi-Fi login pages (hotels, airports, coffee shops) work automatically
|
||||
- **No network adapter changes** — DNS settings stay untouched, eliminating conflicts entirely
|
||||
- **Automatic port 53 conflict resolution** — if another process (e.g., `mDNSResponder` on macOS) is already using port 53, `ctrld` automatically listens on a different port. OS-level packet interception redirects all DNS traffic to `ctrld` transparently, so no manual configuration is needed. This only applies to intercept mode.
|
||||
|
||||
### Tested VPN Software
|
||||
- F5 BIG-IP APM
|
||||
- Cisco AnyConnect
|
||||
- Palo Alto GlobalProtect
|
||||
- Tailscale (including Exit Nodes)
|
||||
- Windscribe
|
||||
- WireGuard
|
||||
|
||||
For more details, see the [DNS Intercept Mode documentation](https://docs.controld.com/docs/dns-intercept).
|
||||
|
||||
## Contributing
|
||||
See [Contribution Guideline](./docs/contributing.md)
|
||||
|
||||
@@ -11,12 +11,14 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/docker/go-units"
|
||||
@@ -146,6 +148,88 @@ func initLogCmd() *cobra.Command {
|
||||
fmt.Println(logs.Data)
|
||||
},
|
||||
}
|
||||
var tailLines int
|
||||
logTailCmd := &cobra.Command{
|
||||
Use: "tail",
|
||||
Short: "Tail live runtime debug logs",
|
||||
Long: "Stream live runtime debug logs to the terminal, similar to tail -f. Press Ctrl+C to stop.",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
|
||||
p := &prog{router: router.New(&cfg, false)}
|
||||
s, _ := newService(p, svcConfig)
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("service not installed")
|
||||
return
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("service is not running")
|
||||
return
|
||||
}
|
||||
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir")
|
||||
}
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
tailPath := fmt.Sprintf("%s?lines=%d", tailLogsPath, tailLines)
|
||||
resp, err := cc.postStream(tailPath, nil)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("failed to connect for log tailing")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusMovedPermanently:
|
||||
warnRuntimeLoggingNotEnabled()
|
||||
return
|
||||
case http.StatusOK:
|
||||
default:
|
||||
mainLog.Load().Fatal().Msgf("unexpected response status: %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
// Set up signal handling for clean shutdown.
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
// Stream output to stdout.
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
os.Stdout.Write(buf[:n])
|
||||
}
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
mainLog.Load().Error().Err(readErr).Msg("error reading log stream")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
msg := fmt.Sprintf("\nexiting: %s\n", context.Cause(ctx).Error())
|
||||
os.Stdout.WriteString(msg)
|
||||
}
|
||||
case <-done:
|
||||
}
|
||||
|
||||
},
|
||||
}
|
||||
logTailCmd.Flags().IntVarP(&tailLines, "lines", "n", 10, "Number of historical lines to show on connect")
|
||||
|
||||
logCmd := &cobra.Command{
|
||||
Use: "log",
|
||||
Short: "Manage runtime debug logs",
|
||||
@@ -156,6 +240,7 @@ func initLogCmd() *cobra.Command {
|
||||
}
|
||||
logCmd.AddCommand(logSendCmd)
|
||||
logCmd.AddCommand(logViewCmd)
|
||||
logCmd.AddCommand(logTailCmd)
|
||||
rootCmd.AddCommand(logCmd)
|
||||
|
||||
return logCmd
|
||||
|
||||
@@ -32,6 +32,12 @@ func (c *controlClient) post(path string, data io.Reader) (*http.Response, error
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
// postStream sends a POST request with no timeout, suitable for long-lived streaming connections.
|
||||
func (c *controlClient) postStream(path string, data io.Reader) (*http.Response, error) {
|
||||
c.c.Timeout = 0
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
// deactivationRequest represents request for validating deactivation pin.
|
||||
type deactivationRequest struct {
|
||||
Pin int64 `json:"pin"`
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -29,6 +30,7 @@ const (
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
tailLogsPath = "/log/tail"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
@@ -344,6 +346,170 @@ func (p *prog) registerControlServerHandler() {
|
||||
}
|
||||
p.internalLogSent = time.Now()
|
||||
}))
|
||||
p.cs.register(tailLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Determine logging mode and validate before starting the stream.
|
||||
var lw *logWriter
|
||||
useInternalLog := p.needInternalLogging()
|
||||
if useInternalLog {
|
||||
p.mu.Lock()
|
||||
lw = p.internalLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
} else if p.cfg.Service.LogPath == "" {
|
||||
// No logging configured at all.
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse optional "lines" query param for initial context.
|
||||
numLines := 10
|
||||
if v := request.URL.Query().Get("lines"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n >= 0 {
|
||||
numLines = n
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if useInternalLog {
|
||||
// Internal logging mode: subscribe to the logWriter.
|
||||
|
||||
// Send last N lines as initial context.
|
||||
if numLines > 0 {
|
||||
if tail := lw.tailLastLines(numLines); len(tail) > 0 {
|
||||
w.Write(tail)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
ch, unsub := lw.Subscribe()
|
||||
defer unsub()
|
||||
for {
|
||||
select {
|
||||
case data, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-request.Context().Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// File-based logging mode: tail the log file.
|
||||
logFile := normalizeLogFilePath(p.cfg.Service.LogPath)
|
||||
f, err := os.Open(logFile)
|
||||
if err != nil {
|
||||
// Already committed 200, just return.
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Seek to show last N lines.
|
||||
if numLines > 0 {
|
||||
if tail := tailFileLastLines(f, numLines); len(tail) > 0 {
|
||||
w.Write(tail)
|
||||
flusher.Flush()
|
||||
}
|
||||
} else {
|
||||
// Seek to end.
|
||||
f.Seek(0, io.SeekEnd)
|
||||
}
|
||||
|
||||
// Poll for new data.
|
||||
buf := make([]byte, 4096)
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
n, err := f.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil && err != io.EOF {
|
||||
return
|
||||
}
|
||||
case <-request.Context().Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// tailFileLastLines reads the last n lines from a file and returns them.
|
||||
// The file position is left at the end of the file after this call.
|
||||
func tailFileLastLines(f *os.File, n int) []byte {
|
||||
stat, err := f.Stat()
|
||||
if err != nil || stat.Size() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read from the end in chunks to find the last n lines.
|
||||
const chunkSize = 4096
|
||||
fileSize := stat.Size()
|
||||
var lines []byte
|
||||
offset := fileSize
|
||||
count := 0
|
||||
|
||||
for offset > 0 && count <= n {
|
||||
readSize := int64(chunkSize)
|
||||
if readSize > offset {
|
||||
readSize = offset
|
||||
}
|
||||
offset -= readSize
|
||||
buf := make([]byte, readSize)
|
||||
nRead, err := f.ReadAt(buf, offset)
|
||||
if err != nil && err != io.EOF {
|
||||
break
|
||||
}
|
||||
buf = buf[:nRead]
|
||||
lines = append(buf, lines...)
|
||||
|
||||
// Count newlines in this chunk.
|
||||
for _, b := range buf {
|
||||
if b == '\n' {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Trim to last n lines.
|
||||
idx := 0
|
||||
nlCount := 0
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
if lines[i] == '\n' {
|
||||
nlCount++
|
||||
if nlCount == n+1 {
|
||||
idx = i + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
lines = lines[idx:]
|
||||
|
||||
// Seek to end of file for subsequent reads.
|
||||
f.Seek(0, io.SeekEnd)
|
||||
return lines
|
||||
}
|
||||
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
|
||||
339
cmd/cli/log_tail_test.go
Normal file
339
cmd/cli/log_tail_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// logWriter.tailLastLines tests
|
||||
// =============================================================================
|
||||
|
||||
func Test_logWriter_tailLastLines_Empty(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
if got := lw.tailLastLines(10); got != nil {
|
||||
t.Fatalf("expected nil for empty buffer, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_ZeroLines(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\n"))
|
||||
if got := lw.tailLastLines(0); got != nil {
|
||||
t.Fatalf("expected nil for n=0, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_NegativeLines(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\n"))
|
||||
if got := lw.tailLastLines(-1); got != nil {
|
||||
t.Fatalf("expected nil for n=-1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_FewerThanN(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\n"))
|
||||
got := string(lw.tailLastLines(10))
|
||||
want := "line1\nline2\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_ExactN(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\nline3\n"))
|
||||
got := string(lw.tailLastLines(3))
|
||||
want := "line1\nline2\nline3\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_MoreThanN(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\nline3\nline4\nline5\n"))
|
||||
got := string(lw.tailLastLines(2))
|
||||
want := "line4\nline5\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_NoTrailingNewline(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\nline3"))
|
||||
// Without trailing newline, "line3" is a partial line.
|
||||
// Asking for 1 line returns the last newline-terminated line plus the partial.
|
||||
got := string(lw.tailLastLines(1))
|
||||
want := "line2\nline3"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_SingleLineNoNewline(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("only line"))
|
||||
got := string(lw.tailLastLines(5))
|
||||
want := "only line"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_SingleLineWithNewline(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("only line\n"))
|
||||
got := string(lw.tailLastLines(1))
|
||||
want := "only line\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// logWriter.Subscribe tests
|
||||
// =============================================================================
|
||||
|
||||
func Test_logWriter_Subscribe_Basic(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
ch, unsub := lw.Subscribe()
|
||||
defer unsub()
|
||||
|
||||
msg := []byte("hello world\n")
|
||||
lw.Write(msg)
|
||||
|
||||
select {
|
||||
case got := <-ch:
|
||||
if string(got) != string(msg) {
|
||||
t.Fatalf("got %q, want %q", got, msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for subscriber data")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_MultipleSubscribers(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
ch1, unsub1 := lw.Subscribe()
|
||||
defer unsub1()
|
||||
ch2, unsub2 := lw.Subscribe()
|
||||
defer unsub2()
|
||||
|
||||
msg := []byte("broadcast\n")
|
||||
lw.Write(msg)
|
||||
|
||||
for i, ch := range []<-chan []byte{ch1, ch2} {
|
||||
select {
|
||||
case got := <-ch:
|
||||
if string(got) != string(msg) {
|
||||
t.Fatalf("subscriber %d: got %q, want %q", i, got, msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("subscriber %d: timed out", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_Unsubscribe(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
ch, unsub := lw.Subscribe()
|
||||
|
||||
// Verify subscribed.
|
||||
lw.Write([]byte("before unsub\n"))
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out before unsub")
|
||||
}
|
||||
|
||||
unsub()
|
||||
|
||||
// Channel should be closed after unsub.
|
||||
if _, ok := <-ch; ok {
|
||||
t.Fatal("channel should be closed after unsubscribe")
|
||||
}
|
||||
|
||||
// Verify subscriber list is empty.
|
||||
lw.mu.Lock()
|
||||
count := len(lw.subscribers)
|
||||
lw.mu.Unlock()
|
||||
if count != 0 {
|
||||
t.Fatalf("expected 0 subscribers after unsub, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_UnsubscribeIdempotent(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
_, unsub := lw.Subscribe()
|
||||
unsub()
|
||||
// Second unsub should not panic.
|
||||
unsub()
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_SlowSubscriberDropped(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
ch, unsub := lw.Subscribe()
|
||||
defer unsub()
|
||||
|
||||
// Fill the subscriber channel (buffer size is 256).
|
||||
for i := 0; i < 300; i++ {
|
||||
lw.Write([]byte("msg\n"))
|
||||
}
|
||||
|
||||
// Should have 256 buffered messages, rest dropped.
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
count++
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
if count != 256 {
|
||||
t.Fatalf("expected 256 buffered messages, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_ConcurrentWriteAndRead(t *testing.T) {
|
||||
lw := newLogWriterWithSize(64 * 1024)
|
||||
ch, unsub := lw.Subscribe()
|
||||
defer unsub()
|
||||
|
||||
const numWrites = 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < numWrites; i++ {
|
||||
lw.Write([]byte("concurrent write\n"))
|
||||
}
|
||||
}()
|
||||
|
||||
received := 0
|
||||
timeout := time.After(5 * time.Second)
|
||||
for received < numWrites {
|
||||
select {
|
||||
case <-ch:
|
||||
received++
|
||||
case <-timeout:
|
||||
t.Fatalf("timed out after receiving %d/%d messages", received, numWrites)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// tailFileLastLines tests
|
||||
// =============================================================================
|
||||
|
||||
func writeTempFile(t *testing.T, content string) *os.File {
|
||||
t.Helper()
|
||||
f, err := os.CreateTemp(t.TempDir(), "tail-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_Empty(t *testing.T) {
|
||||
f := writeTempFile(t, "")
|
||||
defer f.Close()
|
||||
if got := tailFileLastLines(f, 10); got != nil {
|
||||
t.Fatalf("expected nil for empty file, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_FewerThanN(t *testing.T) {
|
||||
f := writeTempFile(t, "line1\nline2\n")
|
||||
defer f.Close()
|
||||
got := string(tailFileLastLines(f, 10))
|
||||
want := "line1\nline2\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_ExactN(t *testing.T) {
|
||||
f := writeTempFile(t, "a\nb\nc\n")
|
||||
defer f.Close()
|
||||
got := string(tailFileLastLines(f, 3))
|
||||
want := "a\nb\nc\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_MoreThanN(t *testing.T) {
|
||||
f := writeTempFile(t, "line1\nline2\nline3\nline4\nline5\n")
|
||||
defer f.Close()
|
||||
got := string(tailFileLastLines(f, 2))
|
||||
want := "line4\nline5\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_NoTrailingNewline(t *testing.T) {
|
||||
f := writeTempFile(t, "line1\nline2\nline3")
|
||||
defer f.Close()
|
||||
// Without trailing newline, partial last line comes with the previous line.
|
||||
got := string(tailFileLastLines(f, 1))
|
||||
want := "line2\nline3"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_LargerThanChunk(t *testing.T) {
|
||||
// Build content larger than the 4096 chunk size to exercise multi-chunk reads.
|
||||
var sb strings.Builder
|
||||
for i := 0; i < 200; i++ {
|
||||
sb.WriteString(strings.Repeat("x", 50))
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
f := writeTempFile(t, sb.String())
|
||||
defer f.Close()
|
||||
got := string(tailFileLastLines(f, 3))
|
||||
lines := strings.Split(strings.TrimRight(got, "\n"), "\n")
|
||||
if len(lines) != 3 {
|
||||
t.Fatalf("expected 3 lines, got %d: %q", len(lines), got)
|
||||
}
|
||||
expectedLine := strings.Repeat("x", 50)
|
||||
for _, line := range lines {
|
||||
if line != expectedLine {
|
||||
t.Fatalf("unexpected line content: %q", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_SeeksToEnd(t *testing.T) {
|
||||
f := writeTempFile(t, "line1\nline2\nline3\n")
|
||||
defer f.Close()
|
||||
tailFileLastLines(f, 1)
|
||||
|
||||
// After tailFileLastLines, file position should be at the end.
|
||||
pos, err := f.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if pos != stat.Size() {
|
||||
t.Fatalf("expected file position at end (%d), got %d", stat.Size(), pos)
|
||||
}
|
||||
}
|
||||
@@ -38,11 +38,17 @@ type logReader struct {
|
||||
size int64
|
||||
}
|
||||
|
||||
// logSubscriber represents a subscriber to live log output.
|
||||
type logSubscriber struct {
|
||||
ch chan []byte
|
||||
}
|
||||
|
||||
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
|
||||
type logWriter struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
subscribers []*logSubscriber
|
||||
}
|
||||
|
||||
// newLogWriter creates an internal log writer.
|
||||
@@ -61,10 +67,70 @@ func newLogWriterWithSize(size int) *logWriter {
|
||||
return lw
|
||||
}
|
||||
|
||||
// Subscribe returns a channel that receives new log data as it's written,
|
||||
// and an unsubscribe function to clean up when done.
|
||||
func (lw *logWriter) Subscribe() (<-chan []byte, func()) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
sub := &logSubscriber{ch: make(chan []byte, 256)}
|
||||
lw.subscribers = append(lw.subscribers, sub)
|
||||
unsub := func() {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
for i, s := range lw.subscribers {
|
||||
if s == sub {
|
||||
lw.subscribers = append(lw.subscribers[:i], lw.subscribers[i+1:]...)
|
||||
close(sub.ch)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return sub.ch, unsub
|
||||
}
|
||||
|
||||
// tailLastLines returns the last n lines from the current buffer.
|
||||
func (lw *logWriter) tailLastLines(n int) []byte {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
data := lw.buf.Bytes()
|
||||
if n <= 0 || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Find the last n newlines from the end.
|
||||
count := 0
|
||||
pos := len(data)
|
||||
for pos > 0 {
|
||||
pos--
|
||||
if data[pos] == '\n' {
|
||||
count++
|
||||
if count == n+1 {
|
||||
pos++ // move past this newline
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]byte, len(data)-pos)
|
||||
copy(result, data[pos:])
|
||||
return result
|
||||
}
|
||||
|
||||
func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
|
||||
// Fan-out to subscribers (non-blocking).
|
||||
if len(lw.subscribers) > 0 {
|
||||
cp := make([]byte, len(p))
|
||||
copy(cp, p)
|
||||
for _, sub := range lw.subscribers {
|
||||
select {
|
||||
case sub.ch <- cp:
|
||||
default:
|
||||
// Drop if subscriber is slow to avoid blocking the logger.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If writing p causes overflows, discard old data.
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
|
||||
@@ -160,6 +160,7 @@ func hasLocalDnsServerRunning() bool {
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
defer windows.CloseHandle(h)
|
||||
p := windows.ProcessEntry32{Size: processEntrySize}
|
||||
for {
|
||||
e := windows.Process32Next(h, &p)
|
||||
|
||||
@@ -4,6 +4,7 @@ package ctrld
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
@@ -17,6 +18,31 @@ func currentNameserversFromResolvconf() []string {
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// localNameservers filters a list of nameserver strings, returning only those
|
||||
// that are not loopback or local machine IP addresses.
|
||||
func localNameservers(nss []string, regularIPs, loopbackIPs []netip.Addr) []string {
|
||||
var result []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, ns := range nss {
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback and local IPs
|
||||
isLocal := false
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
if ip.String() == v.String() {
|
||||
isLocal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isLocal && !seen[ip.String()] {
|
||||
seen[ip.String()] = true
|
||||
result = append(result, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file.
|
||||
// A nameserver is usable if it's not one of current machine's IP addresses
|
||||
// and loopback IP addresses.
|
||||
@@ -35,24 +61,7 @@ func dnsFromResolvConf() []string {
|
||||
}
|
||||
|
||||
nss := resolvconffile.NameServers()
|
||||
var localDNS []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, ns := range nss {
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback IPs
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
ipStr := v.String()
|
||||
if ip.String() == ipStr {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !seen[ip.String()] {
|
||||
seen[ip.String()] = true
|
||||
localDNS = append(localDNS, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
localDNS := localNameservers(nss, regularIPs, loopbackIPs)
|
||||
|
||||
// If we successfully read the file and found nameservers, return them
|
||||
if len(localDNS) > 0 {
|
||||
|
||||
105
nameservers_unix_test.go
Normal file
105
nameservers_unix_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
//go:build unix
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_localNameservers(t *testing.T) {
|
||||
loopbackIPs := []netip.Addr{
|
||||
netip.MustParseAddr("127.0.0.1"),
|
||||
netip.MustParseAddr("::1"),
|
||||
}
|
||||
regularIPs := []netip.Addr{
|
||||
netip.MustParseAddr("192.168.1.100"),
|
||||
netip.MustParseAddr("10.0.0.5"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nss []string
|
||||
regularIPs []netip.Addr
|
||||
loopbackIPs []netip.Addr
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "filters loopback IPv4",
|
||||
nss: []string{"127.0.0.1", "8.8.8.8"},
|
||||
regularIPs: nil,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"8.8.8.8"},
|
||||
},
|
||||
{
|
||||
name: "filters loopback IPv6",
|
||||
nss: []string{"::1", "1.1.1.1"},
|
||||
regularIPs: nil,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"1.1.1.1"},
|
||||
},
|
||||
{
|
||||
name: "filters local machine IPs",
|
||||
nss: []string{"192.168.1.100", "8.8.4.4"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: nil,
|
||||
want: []string{"8.8.4.4"},
|
||||
},
|
||||
{
|
||||
name: "filters both loopback and local IPs",
|
||||
nss: []string{"127.0.0.1", "192.168.1.100", "8.8.8.8"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"8.8.8.8"},
|
||||
},
|
||||
{
|
||||
name: "deduplicates results",
|
||||
nss: []string{"8.8.8.8", "8.8.8.8", "1.1.1.1"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"8.8.8.8", "1.1.1.1"},
|
||||
},
|
||||
{
|
||||
name: "all filtered returns nil",
|
||||
nss: []string{"127.0.0.1", "::1", "192.168.1.100"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "empty input returns nil",
|
||||
nss: nil,
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "skips unparseable entries",
|
||||
nss: []string{"not-an-ip", "8.8.8.8"},
|
||||
regularIPs: regularIPs,
|
||||
loopbackIPs: loopbackIPs,
|
||||
want: []string{"8.8.8.8"},
|
||||
},
|
||||
{
|
||||
name: "no local IPs filters nothing",
|
||||
nss: []string{"8.8.8.8", "1.1.1.1"},
|
||||
regularIPs: nil,
|
||||
loopbackIPs: nil,
|
||||
want: []string{"8.8.8.8", "1.1.1.1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := localNameservers(tt.nss, tt.regularIPs, tt.loopbackIPs)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("localNameservers() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("localNameservers()[%d] = %q, want %q", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user