Compare commits

..

4 Commits

Author SHA1 Message Date
Codescribe
5c0585b2e8 Add log tail command for live log streaming
This commit adds a new `ctrld log tail` subcommand that streams
runtime debug logs to the terminal in real-time, similar to `tail -f`.

Changes:
- log_writer.go: Add Subscribe/tailLastLines for fan-out to tail clients
- control_server.go: Add /log/tail endpoint with streaming response
  - Internal logging: subscribes to logWriter for live data
  - File-based logging: polls log file for new data (200ms interval)
  - Sends last N lines as initial context on connect
- commands.go: Add `log tail` cobra subcommand with --lines/-n flag
- control_client.go: Add postStream() with no timeout for long-lived connections

Usage:
  sudo ctrld log tail          # shows last 10 lines then follows
  sudo ctrld log tail -n 50    # shows last 50 lines then follows
  Ctrl+C to stop
2026-03-25 13:58:44 +07:00
Codescribe
112d1cb5a9 fix: close handle leak in hasLocalDnsServerRunning()
Add defer windows.CloseHandle(h) after CreateToolhelp32Snapshot to ensure
the process snapshot handle is properly released on all code paths (match
found, enumeration exhausted, or error).
2026-03-25 13:58:24 +07:00
Codescribe
bd9bb90dd4 Fix dnsFromResolvConf not filtering loopback IPs
The continue statement only broke out of the inner loop, so
loopback/local IPs (e.g. 127.0.0.1) were never filtered.
This caused ctrld to use itself as bootstrap DNS when already
installed as the system resolver — a self-referential loop.

Use the same isLocal flag pattern as getDNSFromScutil() and
getAllDHCPNameservers().
2026-03-25 13:57:46 +07:00
Codescribe
82fc628bf3 docs: add DNS Intercept Mode section to README 2026-03-25 13:57:35 +07:00
9 changed files with 861 additions and 22 deletions

View File

@@ -100,7 +100,7 @@ docker build -t controldns/ctrld . -f docker/Dockerfile
# Usage
The cli is self documenting, so free free to run `--help` on any sub-command to get specific usages.
The cli is self documenting, so feel free to run `--help` on any sub-command to get specific usages.
## Arguments
```
@@ -266,5 +266,67 @@ The above will start a foreground process and:
- Excluding `*.company.int` and `very-secure.local` matching queries, that are forwarded to `10.0.10.1:53`
- Write a debug log to `/path/to/log.log`
## DNS Intercept Mode
When running `ctrld` alongside VPN software, DNS conflicts can cause intermittent failures, bypassed filtering, or configuration loops. DNS Intercept Mode prevents these issues by transparently capturing all DNS traffic on the system and routing it through `ctrld`, without modifying network adapter DNS settings.
### When to Use
Enable DNS Intercept Mode if you:
- Use corporate VPN software (F5, Cisco AnyConnect, Palo Alto GlobalProtect, Zscaler)
- Run overlay networks like Tailscale or WireGuard
- Experience random DNS failures when VPN connects/disconnects
- See gaps in your Control D analytics when VPN is active
- Have endpoint security software that also manages DNS
### Command
Windows (Admin Shell)
```shell
ctrld.exe start --intercept-mode dns --cd RESOLVER_ID_HERE
```
macOS
```shell
sudo ctrld start --intercept-mode dns --cd RESOLVER_ID_HERE
```
`--intercept-mode dns` automatically detects VPN internal domains and routes them to the VPN's DNS server, while Control D handles everything else.
To disable intercept mode on a service that already has it enabled:
Windows (Admin Shell)
```shell
ctrld.exe start --intercept-mode off
```
macOS
```shell
sudo ctrld start --intercept-mode off
```
This removes the intercept rules and reverts to standard interface-based DNS configuration.
### Platform Support
| Platform | Supported | Mechanism |
|----------|-----------|-----------|
| Windows | ✅ | NRPT (Name Resolution Policy Table) |
| macOS | ✅ | pf (packet filter) redirect |
| Linux | ❌ | Not currently supported |
### Features
- **VPN split routing** — VPN-specific domains are automatically detected and forwarded to the VPN's DNS server
- **Captive portal recovery** — Wi-Fi login pages (hotels, airports, coffee shops) work automatically
- **No network adapter changes** — DNS settings stay untouched, eliminating conflicts entirely
- **Automatic port 53 conflict resolution** — if another process (e.g., `mDNSResponder` on macOS) is already using port 53, `ctrld` automatically listens on a different port. OS-level packet interception redirects all DNS traffic to `ctrld` transparently, so no manual configuration is needed. This only applies to intercept mode.
### Tested VPN Software
- F5 BIG-IP APM
- Cisco AnyConnect
- Palo Alto GlobalProtect
- Tailscale (including Exit Nodes)
- Windscribe
- WireGuard
For more details, see the [DNS Intercept Mode documentation](https://docs.controld.com/docs/dns-intercept).
## Contributing
See [Contribution Guideline](./docs/contributing.md)

View File

@@ -11,12 +11,14 @@ import (
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"syscall"
"time"
"github.com/docker/go-units"
@@ -146,6 +148,88 @@ func initLogCmd() *cobra.Command {
fmt.Println(logs.Data)
},
}
var tailLines int
logTailCmd := &cobra.Command{
Use: "tail",
Short: "Tail live runtime debug logs",
Long: "Stream live runtime debug logs to the terminal, similar to tail -f. Press Ctrl+C to stop.",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
Run: func(cmd *cobra.Command, args []string) {
p := &prog{router: router.New(&cfg, false)}
s, _ := newService(p, svcConfig)
status, err := s.Status()
if errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("service not installed")
return
}
if status == service.StatusStopped {
mainLog.Load().Warn().Msg("service is not running")
return
}
dir, err := socketDir()
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to find ctrld home dir")
}
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
tailPath := fmt.Sprintf("%s?lines=%d", tailLogsPath, tailLines)
resp, err := cc.postStream(tailPath, nil)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to connect for log tailing")
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusMovedPermanently:
warnRuntimeLoggingNotEnabled()
return
case http.StatusOK:
default:
mainLog.Load().Fatal().Msgf("unexpected response status: %d", resp.StatusCode)
return
}
// Set up signal handling for clean shutdown.
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
done := make(chan struct{})
go func() {
defer close(done)
// Stream output to stdout.
buf := make([]byte, 4096)
for {
n, readErr := resp.Body.Read(buf)
if n > 0 {
os.Stdout.Write(buf[:n])
}
if readErr != nil {
if readErr != io.EOF {
mainLog.Load().Error().Err(readErr).Msg("error reading log stream")
}
return
}
}
}()
select {
case <-ctx.Done():
if errors.Is(ctx.Err(), context.Canceled) {
msg := fmt.Sprintf("\nexiting: %s\n", context.Cause(ctx).Error())
os.Stdout.WriteString(msg)
}
case <-done:
}
},
}
logTailCmd.Flags().IntVarP(&tailLines, "lines", "n", 10, "Number of historical lines to show on connect")
logCmd := &cobra.Command{
Use: "log",
Short: "Manage runtime debug logs",
@@ -156,6 +240,7 @@ func initLogCmd() *cobra.Command {
}
logCmd.AddCommand(logSendCmd)
logCmd.AddCommand(logViewCmd)
logCmd.AddCommand(logTailCmd)
rootCmd.AddCommand(logCmd)
return logCmd

View File

@@ -32,6 +32,12 @@ func (c *controlClient) post(path string, data io.Reader) (*http.Response, error
return c.c.Post("http://unix"+path, contentTypeJson, data)
}
// postStream sends a POST request with no timeout, suitable for long-lived
// streaming connections (e.g. /log/tail).
//
// The request is issued on a copy of the underlying client with Timeout
// cleared, so the shared client used by the other control requests keeps
// its normal timeout. (Mutating c.c.Timeout directly would disable the
// timeout for every subsequent request on this client.)
func (c *controlClient) postStream(path string, data io.Reader) (*http.Response, error) {
	streamClient := *c.c
	streamClient.Timeout = 0
	return streamClient.Post("http://unix"+path, contentTypeJson, data)
}
// deactivationRequest represents request for validating deactivation pin.
type deactivationRequest struct {
Pin int64 `json:"pin"`

View File

@@ -10,6 +10,7 @@ import (
"os"
"reflect"
"sort"
"strconv"
"time"
"github.com/kardianos/service"
@@ -29,6 +30,7 @@ const (
ifacePath = "/iface"
viewLogsPath = "/log/view"
sendLogsPath = "/log/send"
tailLogsPath = "/log/tail"
)
type ifaceResponse struct {
@@ -344,6 +346,170 @@ func (p *prog) registerControlServerHandler() {
}
p.internalLogSent = time.Now()
}))
// /log/tail: streams runtime debug logs over a chunked HTTP response until
// the client disconnects. StatusMovedPermanently (301) is used here as an
// application-level "runtime logging not enabled" signal that the CLI
// translates into a warning — NOTE(review): a 4xx status would be more
// conventional; confirm no other client depends on 301.
p.cs.register(tailLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
	// Streaming requires the ability to flush partial responses.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	// Determine logging mode and validate before starting the stream.
	var lw *logWriter
	useInternalLog := p.needInternalLogging()
	if useInternalLog {
		// Snapshot the writer under the lock; it may be swapped elsewhere.
		p.mu.Lock()
		lw = p.internalLogWriter
		p.mu.Unlock()
		if lw == nil {
			w.WriteHeader(http.StatusMovedPermanently)
			return
		}
	} else if p.cfg.Service.LogPath == "" {
		// No logging configured at all.
		w.WriteHeader(http.StatusMovedPermanently)
		return
	}
	// Parse optional "lines" query param for initial context.
	// Invalid or negative values silently fall back to the default of 10.
	numLines := 10
	if v := request.URL.Query().Get("lines"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n >= 0 {
			numLines = n
		}
	}
	// Headers must be set before WriteHeader commits the 200 response.
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.Header().Set("Transfer-Encoding", "chunked")
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(http.StatusOK)
	if useInternalLog {
		// Internal logging mode: subscribe to the logWriter.
		// Send last N lines as initial context.
		if numLines > 0 {
			if tail := lw.tailLastLines(numLines); len(tail) > 0 {
				w.Write(tail)
				flusher.Flush()
			}
		}
		ch, unsub := lw.Subscribe()
		defer unsub()
		// Forward each published chunk until the writer closes the channel,
		// a write to the client fails, or the client disconnects.
		for {
			select {
			case data, ok := <-ch:
				if !ok {
					return
				}
				if _, err := w.Write(data); err != nil {
					return
				}
				flusher.Flush()
			case <-request.Context().Done():
				return
			}
		}
	} else {
		// File-based logging mode: tail the log file.
		logFile := normalizeLogFilePath(p.cfg.Service.LogPath)
		f, err := os.Open(logFile)
		if err != nil {
			// Already committed 200, just return.
			return
		}
		defer f.Close()
		// Seek to show last N lines.
		if numLines > 0 {
			// tailFileLastLines leaves the offset at EOF for the poll loop below.
			if tail := tailFileLastLines(f, numLines); len(tail) > 0 {
				w.Write(tail)
				flusher.Flush()
			}
		} else {
			// Seek to end.
			f.Seek(0, io.SeekEnd)
		}
		// Poll for new data.
		// NOTE(review): if the log file is rotated or truncated while a tail
		// is active, this loop keeps reading the old inode and goes silent —
		// confirm rotation cannot happen during a tail session.
		buf := make([]byte, 4096)
		ticker := time.NewTicker(200 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				n, err := f.Read(buf)
				if n > 0 {
					if _, werr := w.Write(buf[:n]); werr != nil {
						return
					}
					flusher.Flush()
				}
				// io.EOF just means "no new data yet"; keep polling.
				if err != nil && err != io.EOF {
					return
				}
			case <-request.Context().Done():
				return
			}
		}
	}
}))
}
// tailFileLastLines reads the last n lines from a file and returns them.
// The file position is left at the end of the file after this call.
func tailFileLastLines(f *os.File, n int) []byte {
stat, err := f.Stat()
if err != nil || stat.Size() == 0 {
return nil
}
// Read from the end in chunks to find the last n lines.
const chunkSize = 4096
fileSize := stat.Size()
var lines []byte
offset := fileSize
count := 0
for offset > 0 && count <= n {
readSize := int64(chunkSize)
if readSize > offset {
readSize = offset
}
offset -= readSize
buf := make([]byte, readSize)
nRead, err := f.ReadAt(buf, offset)
if err != nil && err != io.EOF {
break
}
buf = buf[:nRead]
lines = append(buf, lines...)
// Count newlines in this chunk.
for _, b := range buf {
if b == '\n' {
count++
}
}
}
// Trim to last n lines.
idx := 0
nlCount := 0
for i := len(lines) - 1; i >= 0; i-- {
if lines[i] == '\n' {
nlCount++
if nlCount == n+1 {
idx = i + 1
break
}
}
}
lines = lines[idx:]
// Seek to end of file for subsequent reads.
f.Seek(0, io.SeekEnd)
return lines
}
func jsonResponse(next http.Handler) http.Handler {

339
cmd/cli/log_tail_test.go Normal file
View File

@@ -0,0 +1,339 @@
package cli
import (
"io"
"os"
"strings"
"sync"
"testing"
"time"
)
// =============================================================================
// logWriter.tailLastLines tests
// =============================================================================

// An empty buffer yields nil rather than an empty slice.
func Test_logWriter_tailLastLines_Empty(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	if got := lw.tailLastLines(10); got != nil {
		t.Fatalf("expected nil for empty buffer, got %q", got)
	}
}

// Requesting zero lines is treated as "nothing", not "everything".
func Test_logWriter_tailLastLines_ZeroLines(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	lw.Write([]byte("line1\nline2\n"))
	if got := lw.tailLastLines(0); got != nil {
		t.Fatalf("expected nil for n=0, got %q", got)
	}
}

// Negative counts behave the same as zero: nil result.
func Test_logWriter_tailLastLines_NegativeLines(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	lw.Write([]byte("line1\nline2\n"))
	if got := lw.tailLastLines(-1); got != nil {
		t.Fatalf("expected nil for n=-1, got %q", got)
	}
}

// Asking for more lines than exist returns the whole buffer.
func Test_logWriter_tailLastLines_FewerThanN(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	lw.Write([]byte("line1\nline2\n"))
	got := string(lw.tailLastLines(10))
	want := "line1\nline2\n"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

// n equal to the number of buffered lines returns all of them.
func Test_logWriter_tailLastLines_ExactN(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	lw.Write([]byte("line1\nline2\nline3\n"))
	got := string(lw.tailLastLines(3))
	want := "line1\nline2\nline3\n"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

// Only the trailing n lines come back when the buffer holds more.
func Test_logWriter_tailLastLines_MoreThanN(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	lw.Write([]byte("line1\nline2\nline3\nline4\nline5\n"))
	got := string(lw.tailLastLines(2))
	want := "line4\nline5\n"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

func Test_logWriter_tailLastLines_NoTrailingNewline(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	lw.Write([]byte("line1\nline2\nline3"))
	// Without trailing newline, "line3" is a partial line.
	// Asking for 1 line returns the last newline-terminated line plus the partial.
	got := string(lw.tailLastLines(1))
	want := "line2\nline3"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

// A lone partial line (no newline at all) is returned as-is.
func Test_logWriter_tailLastLines_SingleLineNoNewline(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	lw.Write([]byte("only line"))
	got := string(lw.tailLastLines(5))
	want := "only line"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

// A single newline-terminated line round-trips including its newline.
func Test_logWriter_tailLastLines_SingleLineWithNewline(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	lw.Write([]byte("only line\n"))
	got := string(lw.tailLastLines(1))
	want := "only line\n"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}
// =============================================================================
// logWriter.Subscribe tests
// =============================================================================

// A subscriber receives data written after subscribing.
func Test_logWriter_Subscribe_Basic(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	ch, unsub := lw.Subscribe()
	defer unsub()
	msg := []byte("hello world\n")
	lw.Write(msg)
	select {
	case got := <-ch:
		if string(got) != string(msg) {
			t.Fatalf("got %q, want %q", got, msg)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for subscriber data")
	}
}

// Every registered subscriber receives each write (fan-out).
func Test_logWriter_Subscribe_MultipleSubscribers(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	ch1, unsub1 := lw.Subscribe()
	defer unsub1()
	ch2, unsub2 := lw.Subscribe()
	defer unsub2()
	msg := []byte("broadcast\n")
	lw.Write(msg)
	for i, ch := range []<-chan []byte{ch1, ch2} {
		select {
		case got := <-ch:
			if string(got) != string(msg) {
				t.Fatalf("subscriber %d: got %q, want %q", i, got, msg)
			}
		case <-time.After(time.Second):
			t.Fatalf("subscriber %d: timed out", i)
		}
	}
}

// Unsubscribing closes the channel and removes the subscriber entry.
func Test_logWriter_Subscribe_Unsubscribe(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	ch, unsub := lw.Subscribe()
	// Verify subscribed.
	lw.Write([]byte("before unsub\n"))
	select {
	case <-ch:
	case <-time.After(time.Second):
		t.Fatal("timed out before unsub")
	}
	unsub()
	// Channel should be closed after unsub.
	if _, ok := <-ch; ok {
		t.Fatal("channel should be closed after unsubscribe")
	}
	// Verify subscriber list is empty.
	lw.mu.Lock()
	count := len(lw.subscribers)
	lw.mu.Unlock()
	if count != 0 {
		t.Fatalf("expected 0 subscribers after unsub, got %d", count)
	}
}

// Calling the returned unsubscribe function twice must be safe.
func Test_logWriter_Subscribe_UnsubscribeIdempotent(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	_, unsub := lw.Subscribe()
	unsub()
	// Second unsub should not panic.
	unsub()
}

// Writes beyond the subscriber channel's buffer are dropped, never blocked on.
func Test_logWriter_Subscribe_SlowSubscriberDropped(t *testing.T) {
	lw := newLogWriterWithSize(4096)
	ch, unsub := lw.Subscribe()
	defer unsub()
	// Fill the subscriber channel (buffer size is 256).
	for i := 0; i < 300; i++ {
		lw.Write([]byte("msg\n"))
	}
	// Should have 256 buffered messages, rest dropped.
	count := 0
	for {
		select {
		case <-ch:
			count++
		default:
			goto done
		}
	}
done:
	if count != 256 {
		t.Fatalf("expected 256 buffered messages, got %d", count)
	}
}

// Concurrent writer and reader: 100 writes fit within the 256-slot channel
// buffer, so all messages must arrive without deadlock or drops.
func Test_logWriter_Subscribe_ConcurrentWriteAndRead(t *testing.T) {
	lw := newLogWriterWithSize(64 * 1024)
	ch, unsub := lw.Subscribe()
	defer unsub()
	const numWrites = 100
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < numWrites; i++ {
			lw.Write([]byte("concurrent write\n"))
		}
	}()
	received := 0
	timeout := time.After(5 * time.Second)
	for received < numWrites {
		select {
		case <-ch:
			received++
		case <-timeout:
			t.Fatalf("timed out after receiving %d/%d messages", received, numWrites)
		}
	}
	wg.Wait()
}
// =============================================================================
// tailFileLastLines tests
// =============================================================================

// writeTempFile creates a temp file pre-filled with content. The returned
// *os.File is open, with its offset at the end of the written data.
func writeTempFile(t *testing.T, content string) *os.File {
	t.Helper()
	f, err := os.CreateTemp(t.TempDir(), "tail-test-*")
	if err != nil {
		t.Fatal(err)
	}
	if _, err := f.WriteString(content); err != nil {
		t.Fatal(err)
	}
	return f
}

// An empty file yields nil.
func Test_tailFileLastLines_Empty(t *testing.T) {
	f := writeTempFile(t, "")
	defer f.Close()
	if got := tailFileLastLines(f, 10); got != nil {
		t.Fatalf("expected nil for empty file, got %q", got)
	}
}

// Asking for more lines than the file holds returns the whole file.
func Test_tailFileLastLines_FewerThanN(t *testing.T) {
	f := writeTempFile(t, "line1\nline2\n")
	defer f.Close()
	got := string(tailFileLastLines(f, 10))
	want := "line1\nline2\n"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

// n equal to the file's line count returns every line.
func Test_tailFileLastLines_ExactN(t *testing.T) {
	f := writeTempFile(t, "a\nb\nc\n")
	defer f.Close()
	got := string(tailFileLastLines(f, 3))
	want := "a\nb\nc\n"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

// Only the trailing n lines are returned from a larger file.
func Test_tailFileLastLines_MoreThanN(t *testing.T) {
	f := writeTempFile(t, "line1\nline2\nline3\nline4\nline5\n")
	defer f.Close()
	got := string(tailFileLastLines(f, 2))
	want := "line4\nline5\n"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

func Test_tailFileLastLines_NoTrailingNewline(t *testing.T) {
	f := writeTempFile(t, "line1\nline2\nline3")
	defer f.Close()
	// Without trailing newline, partial last line comes with the previous line.
	got := string(tailFileLastLines(f, 1))
	want := "line2\nline3"
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}

// Content larger than the internal 4096-byte chunk exercises the
// multi-chunk backwards-read path.
func Test_tailFileLastLines_LargerThanChunk(t *testing.T) {
	// Build content larger than the 4096 chunk size to exercise multi-chunk reads.
	var sb strings.Builder
	for i := 0; i < 200; i++ {
		sb.WriteString(strings.Repeat("x", 50))
		sb.WriteByte('\n')
	}
	f := writeTempFile(t, sb.String())
	defer f.Close()
	got := string(tailFileLastLines(f, 3))
	lines := strings.Split(strings.TrimRight(got, "\n"), "\n")
	if len(lines) != 3 {
		t.Fatalf("expected 3 lines, got %d: %q", len(lines), got)
	}
	expectedLine := strings.Repeat("x", 50)
	for _, line := range lines {
		if line != expectedLine {
			t.Fatalf("unexpected line content: %q", line)
		}
	}
}

// Documented contract: the file offset is left at EOF afterwards.
func Test_tailFileLastLines_SeeksToEnd(t *testing.T) {
	f := writeTempFile(t, "line1\nline2\nline3\n")
	defer f.Close()
	tailFileLastLines(f, 1)
	// After tailFileLastLines, file position should be at the end.
	pos, err := f.Seek(0, io.SeekCurrent)
	if err != nil {
		t.Fatal(err)
	}
	stat, err := f.Stat()
	if err != nil {
		t.Fatal(err)
	}
	if pos != stat.Size() {
		t.Fatalf("expected file position at end (%d), got %d", stat.Size(), pos)
	}
}

View File

@@ -38,11 +38,17 @@ type logReader struct {
size int64
}
// logSubscriber represents a subscriber to live log output.
type logSubscriber struct {
ch chan []byte
}
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
type logWriter struct {
mu sync.Mutex
buf bytes.Buffer
size int
mu sync.Mutex
buf bytes.Buffer
size int
subscribers []*logSubscriber
}
// newLogWriter creates an internal log writer.
@@ -61,10 +67,70 @@ func newLogWriterWithSize(size int) *logWriter {
return lw
}
// Subscribe returns a channel that receives new log data as it's written,
// and an unsubscribe function to clean up when done.
func (lw *logWriter) Subscribe() (<-chan []byte, func()) {
	sub := &logSubscriber{ch: make(chan []byte, 256)}
	lw.mu.Lock()
	lw.subscribers = append(lw.subscribers, sub)
	lw.mu.Unlock()
	// The returned function is idempotent: after the first call the
	// subscriber is no longer in the list, so later calls find nothing
	// to remove and the channel is closed exactly once.
	unsub := func() {
		lw.mu.Lock()
		defer lw.mu.Unlock()
		for i := range lw.subscribers {
			if lw.subscribers[i] != sub {
				continue
			}
			lw.subscribers = append(lw.subscribers[:i], lw.subscribers[i+1:]...)
			close(sub.ch)
			return
		}
	}
	return sub.ch, unsub
}
// tailLastLines returns the last n lines from the current buffer.
func (lw *logWriter) tailLastLines(n int) []byte {
	lw.mu.Lock()
	defer lw.mu.Unlock()
	data := lw.buf.Bytes()
	if n <= 0 || len(data) == 0 {
		return nil
	}
	// Scan backwards for the (n+1)-th newline; everything after it is the
	// last n (possibly partial) lines.
	start, seen := 0, 0
	for i := len(data) - 1; i >= 0; i-- {
		if data[i] != '\n' {
			continue
		}
		seen++
		if seen == n+1 {
			start = i + 1 // move past this newline
			break
		}
	}
	// Copy out so the caller's slice stays valid after the buffer mutates.
	out := make([]byte, len(data)-start)
	copy(out, data[start:])
	return out
}
func (lw *logWriter) Write(p []byte) (int, error) {
lw.mu.Lock()
defer lw.mu.Unlock()
// Fan-out to subscribers (non-blocking).
if len(lw.subscribers) > 0 {
cp := make([]byte, len(p))
copy(cp, p)
for _, sub := range lw.subscribers {
select {
case sub.ch <- cp:
default:
// Drop if subscriber is slow to avoid blocking the logger.
}
}
}
// If writing p causes overflows, discard old data.
if lw.buf.Len()+len(p) > lw.size {
buf := lw.buf.Bytes()

View File

@@ -160,6 +160,7 @@ func hasLocalDnsServerRunning() bool {
if e != nil {
return false
}
defer windows.CloseHandle(h)
p := windows.ProcessEntry32{Size: processEntrySize}
for {
e := windows.Process32Next(h, &p)

View File

@@ -4,6 +4,7 @@ package ctrld
import (
"net"
"net/netip"
"slices"
"time"
@@ -17,6 +18,31 @@ func currentNameserversFromResolvconf() []string {
return resolvconffile.NameServers()
}
// localNameservers filters a list of nameserver strings, returning only those
// that are not loopback or local machine IP addresses.
func localNameservers(nss []string, regularIPs, loopbackIPs []netip.Addr) []string {
var result []string
seen := make(map[string]bool)
for _, ns := range nss {
if ip := net.ParseIP(ns); ip != nil {
// skip loopback and local IPs
isLocal := false
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
if ip.String() == v.String() {
isLocal = true
break
}
}
if !isLocal && !seen[ip.String()] {
seen[ip.String()] = true
result = append(result, ip.String())
}
}
}
return result
}
// dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file.
// A nameserver is usable if it's not one of current machine's IP addresses
// and loopback IP addresses.
@@ -35,24 +61,7 @@ func dnsFromResolvConf() []string {
}
nss := resolvconffile.NameServers()
var localDNS []string
seen := make(map[string]bool)
for _, ns := range nss {
if ip := net.ParseIP(ns); ip != nil {
// skip loopback IPs
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
ipStr := v.String()
if ip.String() == ipStr {
continue
}
}
if !seen[ip.String()] {
seen[ip.String()] = true
localDNS = append(localDNS, ip.String())
}
}
}
localDNS := localNameservers(nss, regularIPs, loopbackIPs)
// If we successfully read the file and found nameservers, return them
if len(localDNS) > 0 {

105
nameservers_unix_test.go Normal file
View File

@@ -0,0 +1,105 @@
//go:build unix
package ctrld
import (
"net/netip"
"testing"
)
// Test_localNameservers table-drives the loopback/local-IP filtering,
// de-duplication, and unparseable-entry handling of localNameservers.
func Test_localNameservers(t *testing.T) {
	loopbackIPs := []netip.Addr{
		netip.MustParseAddr("127.0.0.1"),
		netip.MustParseAddr("::1"),
	}
	regularIPs := []netip.Addr{
		netip.MustParseAddr("192.168.1.100"),
		netip.MustParseAddr("10.0.0.5"),
	}
	tests := []struct {
		name        string
		nss         []string
		regularIPs  []netip.Addr
		loopbackIPs []netip.Addr
		want        []string
	}{
		{
			name:        "filters loopback IPv4",
			nss:         []string{"127.0.0.1", "8.8.8.8"},
			regularIPs:  nil,
			loopbackIPs: loopbackIPs,
			want:        []string{"8.8.8.8"},
		},
		{
			name:        "filters loopback IPv6",
			nss:         []string{"::1", "1.1.1.1"},
			regularIPs:  nil,
			loopbackIPs: loopbackIPs,
			want:        []string{"1.1.1.1"},
		},
		{
			name:        "filters local machine IPs",
			nss:         []string{"192.168.1.100", "8.8.4.4"},
			regularIPs:  regularIPs,
			loopbackIPs: nil,
			want:        []string{"8.8.4.4"},
		},
		{
			name:        "filters both loopback and local IPs",
			nss:         []string{"127.0.0.1", "192.168.1.100", "8.8.8.8"},
			regularIPs:  regularIPs,
			loopbackIPs: loopbackIPs,
			want:        []string{"8.8.8.8"},
		},
		{
			name:        "deduplicates results",
			nss:         []string{"8.8.8.8", "8.8.8.8", "1.1.1.1"},
			regularIPs:  regularIPs,
			loopbackIPs: loopbackIPs,
			want:        []string{"8.8.8.8", "1.1.1.1"},
		},
		{
			name:        "all filtered returns nil",
			nss:         []string{"127.0.0.1", "::1", "192.168.1.100"},
			regularIPs:  regularIPs,
			loopbackIPs: loopbackIPs,
			want:        nil,
		},
		{
			name:        "empty input returns nil",
			nss:         nil,
			regularIPs:  regularIPs,
			loopbackIPs: loopbackIPs,
			want:        nil,
		},
		{
			name:        "skips unparseable entries",
			nss:         []string{"not-an-ip", "8.8.8.8"},
			regularIPs:  regularIPs,
			loopbackIPs: loopbackIPs,
			want:        []string{"8.8.8.8"},
		},
		{
			name:        "no local IPs filters nothing",
			nss:         []string{"8.8.8.8", "1.1.1.1"},
			regularIPs:  nil,
			loopbackIPs: nil,
			want:        []string{"8.8.8.8", "1.1.1.1"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := localNameservers(tt.nss, tt.regularIPs, tt.loopbackIPs)
			// Length check first so an extra or missing entry fails loudly
			// before the element-wise comparison.
			if len(got) != len(tt.want) {
				t.Fatalf("localNameservers() = %v, want %v", got, tt.want)
			}
			for i := range got {
				if got[i] != tt.want[i] {
					t.Errorf("localNameservers()[%d] = %q, want %q", i, got[i], tt.want[i])
				}
			}
		})
	}
}