Compare commits

..

13 Commits

Author SHA1 Message Date
Cuong Manh Le
d0341497d1 Merge pull request #276 from Control-D-Inc/release-branch-v1.4.9
Release branch v1.4.9
2026-01-13 21:41:48 +07:00
Cuong Manh Le
27c5be43c2 fix(system): disable ghw warnings to reduce log noise
Disable warnings from ghw library when retrieving chassis information.
These warnings are undesirable but recoverable errors that emit unnecessary
log messages. Using WithDisableWarnings() suppresses them while maintaining
functionality.
2026-01-09 15:10:29 +07:00
Cuong Manh Le
3beffd0dc8 .github/workflows: temporary use actions/setup-go
Since WillAbides/setup-go-faster failed with macOS-latest.

See: https://github.com/WillAbides/setup-go-faster/issues/37
2025-12-18 17:10:43 +07:00
Cuong Manh Le
1f9c586444 docs: add documentation for runtime internal logging 2025-12-18 17:10:43 +07:00
Cuong Manh Le
a92e1ca024 Upgrade quic-go to v0.57.1 2025-12-18 17:10:43 +07:00
Cuong Manh Le
705df72110 fix: remove incorrect transport close on DoH3 error
Remove the transport Close() call from DoH3 error handling path.
The transport is shared and reused across requests, and closing it
on error would break subsequent requests. The transport lifecycle
is already properly managed by the http.Client and the finalizer
set in newDOH3Transport().
2025-12-18 17:10:43 +07:00
Cuong Manh Le
22122c45b2 Including system metadata when posting to utility API 2025-12-18 17:10:39 +07:00
Cuong Manh Le
57a9bb9fab Merge pull request #268 from Control-D-Inc/release-branch-v1.4.8
Release branch v1.4.8
2025-12-02 21:39:38 +07:00
Cuong Manh Le
78ea2d6361 .github/workflows: upgrade staticcheck-action to v1.4.0
While at it, also bump go version to 1.24
2025-11-12 15:22:01 +07:00
Cuong Manh Le
df3cf7ef62 Upgrade quic-go to v0.56.0 2025-11-12 15:15:16 +07:00
Cuong Manh Le
80e652b8d9 fix: ensure log and cache flags are processed during reload
During reload operations, log and cache flags were not being processed,
which prevented runtime internal logs from working correctly. To fix this,
processLogAndCacheFlags was refactored to accept explicit viper and config
parameters instead of relying on global state, enabling it to be called
during reload with the new configuration. This ensures that log and cache
settings are properly applied when the service reloads its configuration.
2025-11-12 15:15:05 +07:00
Cuong Manh Le
091c7edb19 Fix: Filter root domain from search domains on Linux
Remove empty and root domain (".") entries from search domains list
to prevent systemd-resolved errors. This addresses the issue where
systemd doesn't allow root domain in search domains configuration.

The filtering ensures only valid search domains are passed to
systemd-resolved, preventing DNS operation failures.
2025-11-12 15:14:40 +07:00
Cuong Manh Le
6c550b1d74 Upgrade quic-go to v0.55.0
While at it, also bump required go version to 1.24
2025-11-12 15:14:26 +07:00
178 changed files with 8615 additions and 7466 deletions

View File

@@ -11,6 +11,7 @@ A highly configurable DNS forwarding proxy with support for:
- Multiple upstreams with fallbacks
- Multiple network policy driven DNS query steering (via network cidr, MAC address or FQDN)
- Policy driven domain based "split horizon" DNS with wildcard support
- Integrations with common router vendors and firmware
- LAN client discovery via DHCP, mDNS, ARP, NDP, hosts file parsing
- Prometheus metrics exporter
@@ -25,17 +26,35 @@ All DNS protocols are supported, including:
- `DNS-over-QUIC`
# Use Cases
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy OSes, TVs, smart toasters).
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy routers, legacy OSes, TVs, smart toasters).
2. Create source IP based DNS routing policies with variable secure DNS upstreams. Subnet 1 (admin) uses upstream resolver A, while Subnet 2 (employee) uses upstream resolver B.
3. Create destination IP based DNS routing policies with variable secure DNS upstreams. Listener 1 uses upstream resolver C, while Listener 2 uses upstream resolver D.
4. Create domain level "split horizon" DNS routing policies to send internal domains (*.company.int) to a local DNS server, while everything else goes to another upstream.
5. Deploy on a router and create LAN client specific DNS routing policies from a web GUI (When using ControlD.com).
## OS Support
- Windows Desktop (386, amd64, arm)
- Windows (386, amd64, arm)
- Windows Server (386, amd64)
- MacOS (amd64, arm64)
- Linux (386, amd64, arm, mips)
- FreeBSD (386, amd64, arm)
- Common routers (See below)
### Supported Routers
You can run `ctrld` on any supported router. The list of supported routers and firmware includes:
- Asus Merlin
- DD-WRT
- Firewalla
- FreshTomato
- GL.iNet
- OpenWRT
- pfSense / OPNsense
- Synology
- Ubiquiti (UniFi, EdgeOS)
`ctrld` will attempt to interface with dnsmasq (or Windows Server) whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly.
# Install
There are several ways to download and install `ctrld`.
@@ -44,12 +63,12 @@ There are several ways to download and install `ctrld`.
The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform:
```shell
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl?version=2)"'
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"'
```
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative PowerShell:
```shell
(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1?version=2' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'"
(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'"
```
Or you can pull and run a Docker container from [Docker Hub](https://hub.docker.com/r/controldns/ctrld)
@@ -61,7 +80,7 @@ docker run -d --name=ctrld -p 127.0.0.1:53:53/tcp -p 127.0.0.1:53:53/udp control
Alternatively, if you know what you're doing you can download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section for the appropriate platform.
## Build
Lastly, you can build `ctrld` from source which requires `go1.23+`:
Lastly, you can build `ctrld` from source which requires `go1.21+`:
```shell
go build ./cmd/ctrld
@@ -111,7 +130,7 @@ Available Commands:
Flags:
-h, --help help for ctrld
-s, --silent do not write any log output
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug logging
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug level logging
--version version for ctrld
Use "ctrld [command] --help" for more information about a command.
@@ -142,7 +161,9 @@ You can then run a test query using a DNS client, for example, `dig`:
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file that was generated. To enforce a new config, restart the server.
## Service Mode
This mode will run the application as a background system service on any Windows, MacOS, Linux or FreeBSD distribution. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot.
This mode will run the application as a background system service on any Windows, MacOS, Linux, FreeBSD distribution or supported router. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot.
When Control D upstreams are used on a router type device, `ctrld` will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
### Command
@@ -179,7 +200,7 @@ Linux or Macos
`ctrld` can be configured in variety of different ways, which include: API, local config file or via cli launch args.
## API Based Auto Configuration
Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. This mode is used when the 1 liner installer command from the Control D onboarding guide is executed.
Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. In this mode, the application will automatically choose a non-conflicting IP and/or port and configure itself as the upstream to whatever process is running on port 53 (like dnsmasq or Windows DNS Server). This mode is used when the 1 liner installer command from the Control D onboarding guide is executed.
The following command will use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Endpoint.
@@ -196,7 +217,7 @@ sudo ctrld start --cd abcd1234
Once you run the above command, the following things will happen:
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
- All physical network interface will be updated to use the listener started by the service
- All physical network interface will be updated to use the listener started by the service or dnsmasq upstream will be switched to `ctrld`
- All DNS queries will be sent to the listener
## Manual Configuration

4
client_info_darwin.go Normal file
View File

@@ -0,0 +1,4 @@
package ctrld
// SelfDiscover reports whether ctrld should only do self discover.
func SelfDiscover() bool { return true }

6
client_info_others.go Normal file
View File

@@ -0,0 +1,6 @@
//go:build !windows && !darwin
package ctrld
// SelfDiscover reports whether ctrld should only do self discover.
func SelfDiscover() bool { return false }

18
client_info_windows.go Normal file
View File

@@ -0,0 +1,18 @@
package ctrld
import (
"golang.org/x/sys/windows"
)
// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine.
func isWindowsWorkStation() bool {
// From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa
const VER_NT_WORKSTATION = 0x0000001
osvi := windows.RtlGetVersion()
return osvi.ProductType == VER_NT_WORKSTATION
}
// SelfDiscover reports whether ctrld should only do self discover.
func SelfDiscover() bool {
return isWindowsWorkStation()
}

View File

@@ -8,3 +8,8 @@ import (
// addExtraSplitDnsRule adds split DNS rule if present.
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
// getActiveDirectoryDomain returns AD domain name of this computer.
func getActiveDirectoryDomain() (string, error) {
return "", nil
}

View File

@@ -10,17 +10,18 @@ import (
hh "github.com/microsoft/wmi/pkg/hardware/host"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/system"
)
// addExtraSplitDnsRule adds split DNS rule for domain if it's part of active directory.
func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
domain, err := getActiveDirectoryDomain()
domain, err := system.GetActiveDirectoryDomain()
if err != nil {
mainLog.Load().Debug().Msgf("Unable to get active directory domain: %v", err)
mainLog.Load().Debug().Msgf("unable to get active directory domain: %v", err)
return false
}
if domain == "" {
mainLog.Load().Debug().Msg("No active directory domain found")
mainLog.Load().Debug().Msg("no active directory domain found")
return false
}
// Network rules are lowercase during toml config marshaling,
@@ -40,11 +41,11 @@ func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
}
for _, rule := range lc.Policy.Rules {
if _, ok := rule[domain]; ok {
mainLog.Load().Debug().Msgf("Split-rule %q already existed for listener.%s", domain, n)
mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n)
return false
}
}
mainLog.Load().Debug().Msgf("Adding split-rule %q for listener.%s", domain, n)
mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n)
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
}
return true

View File

@@ -5,14 +5,16 @@ import (
"testing"
"time"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/testhelper"
"github.com/stretchr/testify/assert"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/system"
"github.com/Control-D-Inc/ctrld/testhelper"
)
func Test_getActiveDirectoryDomain(t *testing.T) {
start := time.Now()
domain, err := getActiveDirectoryDomain()
domain, err := system.GetActiveDirectoryDomain()
if err != nil {
t.Fatal(err)
}

File diff suppressed because it is too large Load Diff

1397
cmd/cli/commands.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,141 +0,0 @@
package cli
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"github.com/kardianos/service"
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
)
// ClientsCommand handles clients-related operations
type ClientsCommand struct {
controlClient *controlClient
}
// NewClientsCommand creates a new clients command handler
func NewClientsCommand() (*ClientsCommand, error) {
dir, err := socketDir()
if err != nil {
return nil, fmt.Errorf("failed to find ctrld home dir: %w", err)
}
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
return &ClientsCommand{
controlClient: cc,
}, nil
}
// ListClients lists all connected clients
func (cc *ClientsCommand) ListClients(cmd *cobra.Command, args []string) error {
// Check service status first
sc := NewServiceCommand()
s, _, err := sc.initializeServiceManager()
if err != nil {
return err
}
status, err := s.Status()
if errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("Service not installed")
return nil
}
if status == service.StatusStopped {
mainLog.Load().Warn().Msg("Service is not running")
return nil
}
resp, err := cc.controlClient.post(listClientsPath, nil)
if err != nil {
return fmt.Errorf("failed to get clients: %w", err)
}
defer resp.Body.Close()
var clients []*clientinfo.Client
if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil {
return fmt.Errorf("failed to decode clients result: %w", err)
}
map2Slice := func(m map[string]struct{}) []string {
s := make([]string, 0, len(m))
for k := range m {
if k == "" { // skip empty source from output.
continue
}
s = append(s, k)
}
sort.Strings(s)
return s
}
// If metrics is enabled, server set this for all clients, so we can check only the first one.
// Ideally, we may have a field in response to indicate that query count should be shown, but
// it would break earlier version of ctrld, which only look list of clients in response.
withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount
data := make([][]string, len(clients))
for i, c := range clients {
row := []string{
c.IP.String(),
c.Hostname,
c.Mac,
strings.Join(map2Slice(c.Source), ","),
}
if withQueryCount {
row = append(row, strconv.FormatInt(c.QueryCount, 10))
}
data[i] = row
}
table := tablewriter.NewWriter(os.Stdout)
headers := []string{"IP", "Hostname", "Mac", "Discovered"}
if withQueryCount {
headers = append(headers, "Queries")
}
table.SetHeader(headers)
table.SetAutoFormatHeaders(false)
table.AppendBulk(data)
table.Render()
return nil
}
// InitClientsCmd creates the clients command with proper logic
func InitClientsCmd(rootCmd *cobra.Command) *cobra.Command {
listClientsCmd := &cobra.Command{
Use: "list",
Short: "List clients that ctrld discovered",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: func(cmd *cobra.Command, args []string) error {
cc, err := NewClientsCommand()
if err != nil {
return err
}
return cc.ListClients(cmd, args)
},
}
clientsCmd := &cobra.Command{
Use: "clients",
Short: "Manage clients",
Args: cobra.OnlyValidArgs,
ValidArgs: []string{
listClientsCmd.Use,
},
}
clientsCmd.AddCommand(listClientsCmd)
rootCmd.AddCommand(clientsCmd)
return clientsCmd
}

View File

@@ -1,87 +0,0 @@
package cli
import (
"fmt"
"net"
"github.com/spf13/cobra"
)
// InterfacesCommand handles interfaces-related operations
type InterfacesCommand struct{}
// NewInterfacesCommand creates a new interfaces command handler
func NewInterfacesCommand() (*InterfacesCommand, error) {
return &InterfacesCommand{}, nil
}
// ListInterfaces lists all network interfaces
func (ic *InterfacesCommand) ListInterfaces(cmd *cobra.Command, args []string) error {
withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error {
fmt.Printf("Index : %d\n", i.Index)
fmt.Printf("Name : %s\n", i.Name)
var status string
if i.Flags&net.FlagUp != 0 {
status = "Up"
} else {
status = "Down"
}
fmt.Printf("Status: %s\n", status)
addrs, _ := i.Addrs()
for i, ipaddr := range addrs {
if i == 0 {
fmt.Printf("Addrs : %v\n", ipaddr)
continue
}
fmt.Printf(" %v\n", ipaddr)
}
nss, err := currentStaticDNS(i)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Failed to get DNS")
}
if len(nss) == 0 {
nss = currentDNS(i)
}
for i, dns := range nss {
if i == 0 {
fmt.Printf("DNS : %s\n", dns)
continue
}
fmt.Printf(" : %s\n", dns)
}
println()
return nil
})
return nil
}
// InitInterfacesCmd creates the interfaces command with proper logic
func InitInterfacesCmd(_ *cobra.Command) *cobra.Command {
listInterfacesCmd := &cobra.Command{
Use: "list",
Short: "List network interfaces",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: func(cmd *cobra.Command, args []string) error {
ic, err := NewInterfacesCommand()
if err != nil {
return err
}
return ic.ListInterfaces(cmd, args)
},
}
interfacesCmd := &cobra.Command{
Use: "interfaces",
Short: "Manage network interfaces",
Args: cobra.OnlyValidArgs,
ValidArgs: []string{
listInterfacesCmd.Use,
},
}
interfacesCmd.AddCommand(listInterfacesCmd)
return interfacesCmd
}

View File

@@ -1,175 +0,0 @@
package cli
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"path/filepath"
"github.com/docker/go-units"
"github.com/kardianos/service"
"github.com/spf13/cobra"
)
// LogCommand handles log-related operations
type LogCommand struct {
controlClient *controlClient
}
// NewLogCommand creates a new log command handler
func NewLogCommand() (*LogCommand, error) {
dir, err := socketDir()
if err != nil {
return nil, fmt.Errorf("failed to find ctrld home dir: %w", err)
}
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
return &LogCommand{
controlClient: cc,
}, nil
}
// warnRuntimeLoggingNotEnabled logs a warning about runtime logging not being enabled
func (lc *LogCommand) warnRuntimeLoggingNotEnabled() {
mainLog.Load().Warn().Msg("Runtime debug logging is not enabled")
mainLog.Load().Warn().Msg(`ctrld may be running without "--cd" flag or logging is already enabled`)
}
// SendLogs sends runtime debug logs to ControlD
func (lc *LogCommand) SendLogs(cmd *cobra.Command, args []string) error {
sc := NewServiceCommand()
s, _, err := sc.initializeServiceManager()
if err != nil {
return err
}
status, err := s.Status()
if errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("Service not installed")
return nil
}
if status == service.StatusStopped {
mainLog.Load().Warn().Msg("Service is not running")
return nil
}
resp, err := lc.controlClient.post(sendLogsPath, nil)
if err != nil {
return fmt.Errorf("failed to send logs: %w", err)
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusServiceUnavailable:
mainLog.Load().Warn().Msg("Runtime logs could only be sent once per minute")
return nil
case http.StatusMovedPermanently:
lc.warnRuntimeLoggingNotEnabled()
return nil
}
var logs logSentResponse
if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
return fmt.Errorf("failed to decode sent logs result: %w", err)
}
if logs.Error != "" {
return fmt.Errorf("failed to send logs: %s", logs.Error)
}
mainLog.Load().Notice().Msgf("Sent %s of runtime logs", units.BytesSize(float64(logs.Size)))
return nil
}
// ViewLogs views current runtime debug logs
func (lc *LogCommand) ViewLogs(cmd *cobra.Command, args []string) error {
sc := NewServiceCommand()
s, _, err := sc.initializeServiceManager()
if err != nil {
return err
}
status, err := s.Status()
if errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("Service not installed")
return nil
}
if status == service.StatusStopped {
mainLog.Load().Warn().Msg("Service is not running")
return nil
}
resp, err := lc.controlClient.post(viewLogsPath, nil)
if err != nil {
return fmt.Errorf("failed to get logs: %w", err)
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusMovedPermanently:
lc.warnRuntimeLoggingNotEnabled()
return nil
case http.StatusBadRequest:
mainLog.Load().Warn().Msg("Runtime debug logs are not available")
buf, err := io.ReadAll(resp.Body)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("Failed to read response body")
}
mainLog.Load().Warn().Msgf("ctrld process response:\n\n%s\n", string(buf))
return nil
case http.StatusOK:
}
var logs logViewResponse
if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
return fmt.Errorf("failed to decode view logs result: %w", err)
}
fmt.Print(logs.Data)
return nil
}
// InitLogCmd creates the log command with proper logic
func InitLogCmd(rootCmd *cobra.Command) *cobra.Command {
lc, err := NewLogCommand()
if err != nil {
panic(fmt.Sprintf("failed to create log command: %v", err))
}
logSendCmd := &cobra.Command{
Use: "send",
Short: "Send runtime debug logs to ControlD",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: lc.SendLogs,
}
logViewCmd := &cobra.Command{
Use: "view",
Short: "View current runtime debug logs",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: lc.ViewLogs,
}
logCmd := &cobra.Command{
Use: "log",
Short: "Manage runtime debug logs",
Args: cobra.OnlyValidArgs,
ValidArgs: []string{
logSendCmd.Use,
logViewCmd.Use,
},
}
logCmd.AddCommand(logSendCmd)
logCmd.AddCommand(logViewCmd)
rootCmd.AddCommand(logCmd)
return logCmd
}

View File

@@ -1,59 +0,0 @@
package cli
import (
"github.com/spf13/cobra"
"github.com/Control-D-Inc/ctrld"
)
// RunCommand handles run-related operations
type RunCommand struct {
// Add any dependencies here if needed in the future
}
// NewRunCommand creates a new run command handler
func NewRunCommand() *RunCommand {
return &RunCommand{}
}
// Run implements the logic for the run command
func (rc *RunCommand) Run(cmd *cobra.Command, args []string) {
RunCobraCommand(cmd)
}
// InitRunCmd creates the run command with proper logic
func InitRunCmd(rootCmd *cobra.Command) *cobra.Command {
rc := NewRunCommand()
runCmd := &cobra.Command{
Use: "run",
Short: "Run the DNS proxy server",
Args: cobra.NoArgs,
Run: rc.Run,
}
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = runCmd.Flags().MarkHidden("dev")
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
_ = runCmd.Flags().MarkHidden("homedir")
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
_ = runCmd.Flags().MarkHidden("iface")
runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
runCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true}
rootCmd.AddCommand(runCmd)
return runCmd
}

View File

@@ -1,256 +0,0 @@
package cli
import (
"fmt"
"os"
"runtime"
"github.com/kardianos/service"
"github.com/spf13/cobra"
)
// filterEmptyStrings removes empty strings from a slice
// This is used to clean up command line arguments and configuration values
func filterEmptyStrings(slice []string) []string {
var result []string
for _, s := range slice {
if s != "" {
result = append(result, s)
}
}
return result
}
// ServiceCommand handles service-related operations
// This encapsulates all service management functionality for the CLI
type ServiceCommand struct {
serviceManager *ServiceManager
}
// initializeServiceManager creates a service manager with default configuration
// This sets up the basic service infrastructure needed for all service operations
func (sc *ServiceCommand) initializeServiceManager() (service.Service, *prog, error) {
svcConfig := sc.createServiceConfig()
return sc.initializeServiceManagerWithServiceConfig(svcConfig)
}
// initializeServiceManagerWithServiceConfig creates a service manager with the given configuration
// This allows for custom service configuration while maintaining the same initialization pattern
func (sc *ServiceCommand) initializeServiceManagerWithServiceConfig(svcConfig *service.Config) (service.Service, *prog, error) {
p := &prog{}
s, err := sc.newService(p, svcConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create service: %w", err)
}
sc.serviceManager = &ServiceManager{prog: p, svc: s}
return s, p, nil
}
// newService creates a new service instance using the provided program and configuration.
// This abstracts the service creation process for different operating systems
func (sc *ServiceCommand) newService(p *prog, svcConfig *service.Config) (service.Service, error) {
s, err := newService(p, svcConfig)
if err != nil {
return nil, fmt.Errorf("failed to create service: %w", err)
}
return s, nil
}
// NewServiceCommand creates a new service command handler
// This provides a clean factory method for creating service command instances
func NewServiceCommand() *ServiceCommand {
return &ServiceCommand{}
}
// createServiceConfig creates a properly initialized service configuration
// This ensures consistent service naming and description across all platforms
func (sc *ServiceCommand) createServiceConfig() *service.Config {
return &service.Config{
Name: ctrldServiceName,
DisplayName: "Control-D Helper Service",
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
Option: service.KeyValue{},
}
}
// InitServiceCmd creates the service command with proper logic and aliases
// This sets up all service-related subcommands with appropriate permissions and flags
func InitServiceCmd(rootCmd *cobra.Command) *cobra.Command {
// Create service command handlers
sc := NewServiceCommand()
startCmd, startCmdAlias := createStartCommands(sc)
rootCmd.AddCommand(startCmdAlias)
// Stop command
stopCmd := &cobra.Command{
Use: "stop",
Short: "Stop the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: sc.Stop,
}
stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`)
_ = stopCmd.Flags().MarkHidden("pin")
// Restart command
restartCmd := &cobra.Command{
Use: "restart",
Short: "Restart the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: sc.Restart,
}
// Status command
statusCmd := &cobra.Command{
Use: "status",
Short: "Show status of the ctrld service",
Args: cobra.NoArgs,
RunE: sc.Status,
}
if runtime.GOOS == "darwin" {
// On darwin, running status command without privileges may return wrong information.
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
}
}
// Reload command
reloadCmd := &cobra.Command{
Use: "reload",
Short: "Reload the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: sc.Reload,
}
// Uninstall command
uninstallCmd := &cobra.Command{
Use: "uninstall",
Short: "Stop and uninstall the ctrld service",
Long: `Stop and uninstall the ctrld service.
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: sc.Uninstall,
}
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`)
_ = uninstallCmd.Flags().MarkHidden("pin")
uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`)
// Interfaces command - use the existing InitInterfacesCmd function
interfacesCmd := InitInterfacesCmd(rootCmd)
stopCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
Use: "stop",
Short: "Quick stop service and remove DNS from interface",
RunE: func(cmd *cobra.Command, args []string) error {
if !cmd.Flags().Changed("iface") {
os.Args = append(os.Args, "--iface="+ifaceStartStop)
}
iface = ifaceStartStop
return stopCmd.RunE(cmd, args)
},
}
stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
rootCmd.AddCommand(stopCmdAlias)
// Create aliases for other service commands
restartCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
Use: "restart",
Short: "Restart the ctrld service",
RunE: func(cmd *cobra.Command, args []string) error {
return restartCmd.RunE(cmd, args)
},
}
rootCmd.AddCommand(restartCmdAlias)
reloadCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
Use: "reload",
Short: "Reload the ctrld service",
RunE: func(cmd *cobra.Command, args []string) error {
return reloadCmd.RunE(cmd, args)
},
}
rootCmd.AddCommand(reloadCmdAlias)
statusCmdAlias := &cobra.Command{
Use: "status",
Short: "Show status of the ctrld service",
Args: cobra.NoArgs,
RunE: statusCmd.RunE,
}
rootCmd.AddCommand(statusCmdAlias)
uninstallCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
Use: "uninstall",
Short: "Stop and uninstall the ctrld service",
Long: `Stop and uninstall the ctrld service.
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
RunE: func(cmd *cobra.Command, args []string) error {
if !cmd.Flags().Changed("iface") {
os.Args = append(os.Args, "--iface="+ifaceStartStop)
}
iface = ifaceStartStop
return uninstallCmd.RunE(cmd, args)
},
}
uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags())
rootCmd.AddCommand(uninstallCmdAlias)
// Create service command
serviceCmd := &cobra.Command{
Use: "service",
Short: "Manage ctrld service",
Args: cobra.OnlyValidArgs,
}
serviceCmd.ValidArgs = make([]string, 7)
serviceCmd.ValidArgs[0] = startCmd.Use
serviceCmd.ValidArgs[1] = stopCmd.Use
serviceCmd.ValidArgs[2] = restartCmd.Use
serviceCmd.ValidArgs[3] = reloadCmd.Use
serviceCmd.ValidArgs[4] = statusCmd.Use
serviceCmd.ValidArgs[5] = uninstallCmd.Use
serviceCmd.ValidArgs[6] = interfacesCmd.Use
serviceCmd.AddCommand(startCmd)
serviceCmd.AddCommand(stopCmd)
serviceCmd.AddCommand(restartCmd)
serviceCmd.AddCommand(reloadCmd)
serviceCmd.AddCommand(statusCmd)
serviceCmd.AddCommand(uninstallCmd)
serviceCmd.AddCommand(interfacesCmd)
rootCmd.AddCommand(serviceCmd)
return serviceCmd
}

View File

@@ -1,41 +0,0 @@
package cli
import (
"fmt"
"time"
"github.com/kardianos/service"
)
// dialSocketControlServerTimeout is the default timeout to wait when ping control server.
const dialSocketControlServerTimeout = 30 * time.Second
// ServiceManager handles service operations
type ServiceManager struct {
prog *prog
svc service.Service
}
// NewServiceManager creates a new service manager
func NewServiceManager() (*ServiceManager, error) {
p := &prog{}
// Create a proper service configuration
svcConfig := &service.Config{
Name: ctrldServiceName,
DisplayName: "Control-D Helper Service",
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
Option: service.KeyValue{},
}
s, err := newService(p, svcConfig)
if err != nil {
return nil, fmt.Errorf("failed to create service: %w", err)
}
return &ServiceManager{prog: p, svc: s}, nil
}
// Status returns the current service status
func (sm *ServiceManager) Status() (service.Status, error) {
return sm.svc.Status()
}

View File

@@ -1,67 +0,0 @@
package cli
import (
"errors"
"io"
"net/http"
"path/filepath"
"github.com/kardianos/service"
"github.com/spf13/cobra"
)
// Reload implements the logic from cmdReload.Run
func (sc *ServiceCommand) Reload(cmd *cobra.Command, args []string) error {
logger := mainLog.Load()
logger.Debug().Msg("Service reload command started")
s, _, err := sc.initializeServiceManager()
if err != nil {
logger.Error().Err(err).Msg("Failed to initialize service manager")
return err
}
status, err := s.Status()
if errors.Is(err, service.ErrNotInstalled) {
logger.Warn().Msg("Service not installed")
return nil
}
if status == service.StatusStopped {
logger.Warn().Msg("Service is not running")
return nil
}
dir, err := socketDir()
if err != nil {
logger.Fatal().Err(err).Msg("Failed to find ctrld home dir")
}
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
resp, err := cc.post(reloadPath, nil)
if err != nil {
logger.Fatal().Err(err).Msg("Failed to send reload signal to ctrld")
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
logger.Notice().Msg("Service reloaded")
case http.StatusCreated:
logger.Warn().Msg("Service was reloaded, but new config requires service restart.")
logger.Warn().Msg("Restarting service")
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
logger.Warn().Msg("Service not installed")
return nil
}
return sc.Restart(cmd, args)
default:
buf, err := io.ReadAll(resp.Body)
if err != nil {
logger.Fatal().Err(err).Msg("Could not read response from control server")
}
logger.Error().Err(err).Msgf("Failed to reload ctrld: %s", string(buf))
}
logger.Debug().Msg("Service reload command completed")
return nil
}

View File

@@ -1,111 +0,0 @@
package cli
import (
"context"
"errors"
"time"
"github.com/kardianos/service"
"github.com/spf13/cobra"
)
// Restart implements the logic from cmdRestart.Run
func (sc *ServiceCommand) Restart(cmd *cobra.Command, args []string) error {
logger := mainLog.Load()
logger.Debug().Msg("Service restart command started")
readConfig(false)
v.Unmarshal(&cfg)
cdUID = curCdUID()
cdMode := cdUID != ""
s, p, err := sc.initializeServiceManager()
if err != nil {
logger.Error().Err(err).Msg("Failed to initialize service manager")
return err
}
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
logger.Warn().Msg("Service not installed")
return nil
}
p.cfg = &cfg
if iface == "" {
iface = "auto"
}
p.preRun()
if ir := runningIface(s); ir != nil {
p.runningIface = ir.Name
p.requiredMultiNICsConfig = ir.All
}
initInteractiveLogging()
var validateConfigErr error
if cdMode {
logger.Debug().Msg("Validating ControlD remote config")
validateConfigErr = doValidateCdRemoteConfig(cdUID, false)
if validateConfigErr != nil {
logger.Warn().Err(validateConfigErr).Msg("ControlD remote config validation failed")
}
}
if ir := runningIface(s); ir != nil {
iface = ir.Name
}
doRestart := func() bool {
logger.Debug().Msg("Starting service restart sequence")
tasks := []task{
{s.Stop, true, "Stop"},
{func() error {
// restore static DNS settings or DHCP
p.resetDNS(false, true)
return nil
}, false, "Cleanup"},
{func() error {
time.Sleep(time.Second * 1)
return nil
}, false, "Waiting for service to stop"},
}
if !doTasks(tasks) {
logger.Error().Msg("Service stop tasks failed")
return false
}
tasks = []task{
{s.Start, true, "Start"},
}
success := doTasks(tasks)
if success {
logger.Debug().Msg("Service restart sequence completed successfully")
} else {
logger.Error().Msg("Service restart sequence failed")
}
return success
}
if doRestart() {
if dir, err := socketDir(); err == nil {
timeout := dialSocketControlServerTimeout
if validateConfigErr != nil {
timeout = 5 * time.Second
}
if cc := newSocketControlClientWithTimeout(context.TODO(), s, dir, timeout); cc != nil {
_, _ = cc.post(ifacePath, nil)
logger.Debug().Msg("Control server ping successful")
} else {
logger.Warn().Err(err).Msg("Service was restarted, but ctrld process may not be ready yet")
}
} else {
logger.Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
}
logger.Notice().Msg("Service restarted")
} else {
logger.Error().Msg("Service restart failed")
}
logger.Debug().Msg("Service restart command completed")
return nil
}

View File

@@ -1,387 +0,0 @@
package cli
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"time"
"github.com/kardianos/service"
"github.com/spf13/cobra"
"github.com/Control-D-Inc/ctrld"
)
// Start implements the logic from cmdStart.Run
func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
logger := mainLog.Load()
logger.Debug().Msg("Service start command started")
checkStrFlagEmpty(cmd, cdUidFlagName)
checkStrFlagEmpty(cmd, cdOrgFlagName)
validateCdAndNextDNSFlags()
svcConfig := sc.createServiceConfig()
osArgs := os.Args[2:]
osArgs = filterEmptyStrings(osArgs)
if os.Args[1] == "service" {
osArgs = os.Args[3:]
}
setDependencies(svcConfig)
svcConfig.Arguments = append([]string{"run"}, osArgs...)
// Initialize service manager with proper configuration
s, p, err := sc.initializeServiceManagerWithServiceConfig(svcConfig)
if err != nil {
logger.Error().Err(err).Msg("Failed to initialize service manager")
return err
}
p.cfg = &cfg
p.preRun()
status, err := s.Status()
isCtrldRunning := status == service.StatusRunning
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
// Get current running iface, if any.
var currentIface *ifaceResponse
// If pin code was set, do not allow running start command.
if isCtrldRunning {
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
logger.Error().Msg("Deactivation pin check failed")
os.Exit(deactivationPinInvalidExitCode)
}
currentIface = runningIface(s)
logger.Debug().Msgf("Current interface on start: %v", currentIface)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
reportSetDnsOk := func(sockDir string) {
if cc := newSocketControlClient(ctx, s, sockDir); cc != nil {
if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK {
if iface == "auto" {
iface = defaultIfaceName()
}
res := &ifaceResponse{}
if err := json.NewDecoder(resp.Body).Decode(res); err != nil {
logger.Warn().Err(err).Msg("Failed to get iface info")
return
}
if res.OK {
name := res.Name
if iff, err := net.InterfaceByName(name); err == nil {
_, _ = patchNetIfaceName(iff)
name = iff.Name
}
logger := logger.With().Str("iface", name)
logger.Debug().Msg("Setting DNS successfully")
if res.All {
// Log that DNS is set for other interfaces.
withEachPhysicalInterfaces(
name,
"set DNS",
func(i *net.Interface) error { return nil },
)
}
}
}
}
}
// No config path, generating config in HOME directory.
noConfigStart := isNoConfigStart(cmd)
writeDefaultConfig := !noConfigStart && configBase64 == ""
logServerStarted := make(chan struct{})
stopLogCh := make(chan struct{})
ud, err := userHomeDir()
sockDir := ud
var logServerSocketPath string
if err != nil {
logger.Warn().Err(err).Msg("Failed to get user home directory")
logger.Warn().Msg("Log server did not start")
close(logServerStarted)
} else {
setWorkingDirectory(svcConfig, ud)
if configPath == "" && writeDefaultConfig {
defaultConfigFile = filepath.Join(ud, defaultConfigFile)
}
svcConfig.Arguments = append(svcConfig.Arguments, "--homedir="+ud)
if d, err := socketDir(); err == nil {
sockDir = d
}
logServerSocketPath = filepath.Join(sockDir, ctrldLogUnixSock)
_ = os.Remove(logServerSocketPath)
go func() {
defer os.Remove(logServerSocketPath)
close(logServerStarted)
// Start HTTP log server
if err := httpLogServer(logServerSocketPath, stopLogCh); err != nil && err != http.ErrServerClosed {
logger.Warn().Err(err).Msg("Failed to serve HTTP log server")
return
}
}()
}
<-logServerStarted
if !startOnly {
startOnly = len(osArgs) == 0
}
// If user run "ctrld start" and ctrld is already installed, starting existing service.
if startOnly && isCtrldInstalled {
tryReadingConfigWithNotice(false, true)
if err := v.Unmarshal(&cfg); err != nil {
logger.Fatal().Msgf("Failed to unmarshal config: %v", err)
}
// if already running, dont restart
if isCtrldRunning {
logger.Notice().Msg("Service is already running")
return nil
}
initInteractiveLogging()
tasks := []task{
{func() error {
// Save current DNS so we can restore later.
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
return err
}
return nil
})
return nil
}, false, "Save current DNS"},
{func() error {
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
}, false, "Configure service failure actions"},
{s.Start, true, "Start"},
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
}
logger.Notice().Msg("Starting existing ctrld service")
if doTasks(tasks) {
logger.Notice().Msg("Service started")
sockDir, err := socketDir()
if err != nil {
logger.Warn().Err(err).Msg("Failed to get socket directory")
os.Exit(1)
}
reportSetDnsOk(sockDir)
} else {
logger.Error().Err(err).Msg("Failed to start existing ctrld service")
os.Exit(1)
}
return nil
}
if cdUID != "" {
_ = doValidateCdRemoteConfig(cdUID, true)
} else if uid := cdUIDFromProvToken(); uid != "" {
cdUID = uid
logger.Debug().Msg("Using uid from provision token")
removeOrgFlagsFromArgs(svcConfig)
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
svcConfig.Arguments = append(svcConfig.Arguments, "--cd="+cdUID)
}
if cdUID != "" {
validateCdUpstreamProtocol()
}
if configPath != "" {
v.SetConfigFile(configPath)
}
tryReadingConfigWithNotice(writeDefaultConfig, true)
if err := v.Unmarshal(&cfg); err != nil {
logger.Fatal().Msgf("Failed to unmarshal config: %v", err)
}
initInteractiveLogging()
if nextdns != "" {
removeNextDNSFromArgs(svcConfig)
}
// Explicitly passing config, so on system where home directory could not be obtained,
// or sub-process env is different with the parent, we still behave correctly and use
// the expected config file.
if configPath == "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--config="+defaultConfigFile)
}
tasks := []task{
{s.Stop, false, "Stop"},
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
//resetDnsTask(p, s, isCtrldInstalled, currentIface),
{func() error {
// Save current DNS so we can restore later.
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
return err
}
return nil
})
return nil
}, false, "Save current DNS"},
{s.Install, false, "Install"},
{func() error {
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
}, false, "Configure Windows service failure actions"},
{s.Start, true, "Start"},
// Note that startCmd do not actually write ControlD config, but the config file was
// generated after s.Start, so we notice users here for consistent with nextdns mode.
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
}
logger.Notice().Msg("Starting service")
if doTasks(tasks) {
// add a small delay to ensure the service is started and did not crash
time.Sleep(1 * time.Second)
ok, status, err := selfCheckStatus(ctx, s, sockDir)
switch {
case ok && status == service.StatusRunning:
logger.Notice().Msg("Service started")
default:
marker := append(bytes.Repeat([]byte("="), 32), '\n')
// If ctrld service is not running, emitting log obtained from ctrld process.
if status != service.StatusRunning || ctx.Err() != nil {
logger.Error().Msg("Ctrld service may not have started due to an error or misconfiguration, service log:")
_, _ = logger.Write(marker)
// Wait for log collection to complete
<-stopLogCh
// Retrieve logs from HTTP server if available
if logServerSocketPath != "" {
hlc := newHTTPLogClient(logServerSocketPath)
logs, err := hlc.GetLogs()
if err != nil {
logger.Warn().Err(err).Msg("Failed to get logs from HTTP log server")
}
if len(logs) == 0 {
logger.Write([]byte("<no log output is obtained from ctrld process>\n"))
} else {
logger.Write(logs)
logger.Write([]byte("\n"))
}
} else {
logger.Write([]byte("<no log output from HTTP log server>\n"))
}
}
// Report any error if occurred.
if err != nil {
_, _ = logger.Write(marker)
msg := fmt.Sprintf("An error occurred while performing test query: %s\n", err)
logger.Write([]byte(msg))
}
// If ctrld service is running but selfCheckStatus failed, it could be related
// to user's system firewall configuration, notice users about it.
if status == service.StatusRunning && err == nil {
_, _ = logger.Write(marker)
logger.Write([]byte("ctrld service was running, but a DNS query could not be sent to its listener\n"))
logger.Write([]byte("Please check your system firewall if it is configured to block/intercept/redirect DNS queries\n"))
}
_, _ = logger.Write(marker)
uninstall(p, s)
os.Exit(1)
}
reportSetDnsOk(sockDir)
}
logger.Debug().Msg("Service start command completed")
return nil
}
// createStartCommands creates the start command and its alias
func createStartCommands(sc *ServiceCommand) (*cobra.Command, *cobra.Command) {
// Start command
startCmd := &cobra.Command{
Use: "start",
Short: "Install and start the ctrld service",
Long: `Install and start the ctrld service
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
Args: func(cmd *cobra.Command, args []string) error {
args = filterEmptyStrings(args)
if len(args) > 0 {
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
}
return nil
},
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: sc.Start,
}
// Keep these flags in sync with runCmd above, except for "-d"/"--nextdns".
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
_ = startCmd.Flags().MarkHidden("dev")
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`)
startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service")
_ = startCmd.Flags().MarkHidden("start_only")
startCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
// Start command alias
startCmdAlias := &cobra.Command{
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
Use: "start",
Short: "Quick start service and configure DNS on interface",
Long: `Quick start service and configure DNS on interface
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
Args: func(cmd *cobra.Command, args []string) error {
args = filterEmptyStrings(args)
if len(args) > 0 {
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
if len(os.Args) == 2 {
startOnly = true
}
if !cmd.Flags().Changed("iface") {
os.Args = append(os.Args, "--iface="+ifaceStartStop)
}
iface = ifaceStartStop
return startCmd.RunE(cmd, args)
},
}
startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`)
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
return startCmd, startCmdAlias
}

View File

@@ -1,41 +0,0 @@
package cli
import (
"os"
"github.com/kardianos/service"
"github.com/spf13/cobra"
)
// Status implements the logic from cmdStatus.Run
func (sc *ServiceCommand) Status(cmd *cobra.Command, args []string) error {
logger := mainLog.Load()
logger.Debug().Msg("Service status command started")
s, _, err := sc.initializeServiceManager()
if err != nil {
logger.Error().Err(err).Msg("Failed to initialize service manager")
return err
}
status, err := s.Status()
if err != nil {
logger.Error().Msg(err.Error())
os.Exit(1)
}
switch status {
case service.StatusUnknown:
logger.Notice().Msg("Unknown status")
os.Exit(2)
case service.StatusRunning:
logger.Notice().Msg("Service is running")
os.Exit(0)
case service.StatusStopped:
logger.Notice().Msg("Service is stopped")
os.Exit(1)
}
logger.Debug().Msg("Service status command completed")
return nil
}

View File

@@ -1,61 +0,0 @@
package cli
import (
"errors"
"os"
"github.com/kardianos/service"
"github.com/spf13/cobra"
)
// Stop implements the logic from cmdStop.Run
func (sc *ServiceCommand) Stop(cmd *cobra.Command, args []string) error {
logger := mainLog.Load()
logger.Debug().Msg("Service stop command started")
readConfig(false)
v.Unmarshal(&cfg)
s, p, err := sc.initializeServiceManager()
if err != nil {
logger.Error().Err(err).Msg("Failed to initialize service manager")
return err
}
p.cfg = &cfg
if iface == "" {
iface = "auto"
}
p.preRun()
if ir := runningIface(s); ir != nil {
p.runningIface = ir.Name
p.requiredMultiNICsConfig = ir.All
}
initInteractiveLogging()
status, err := s.Status()
if errors.Is(err, service.ErrNotInstalled) {
logger.Warn().Msg("Service not installed")
return nil
}
if status == service.StatusStopped {
logger.Warn().Msg("Service is already stopped")
return nil
}
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
logger.Error().Msg("Deactivation pin check failed")
os.Exit(deactivationPinInvalidExitCode)
}
logger.Debug().Msg("Stopping service")
if doTasks([]task{{s.Stop, true, "Stop"}}) {
logger.Notice().Msg("Service stopped")
} else {
logger.Error().Msg("Service stop failed")
}
logger.Debug().Msg("Service stop command completed")
return nil
}

View File

@@ -1,106 +0,0 @@
package cli
import (
"net"
"os"
"path/filepath"
"github.com/spf13/cobra"
"github.com/Control-D-Inc/ctrld"
)
// Uninstall implements the logic from cmdUninstall.Run
func (sc *ServiceCommand) Uninstall(cmd *cobra.Command, args []string) error {
logger := mainLog.Load()
logger.Debug().Msg("Service uninstall command started")
readConfig(false)
v.Unmarshal(&cfg)
s, p, err := sc.initializeServiceManager()
if err != nil {
logger.Error().Err(err).Msg("Failed to initialize service manager")
return err
}
p.cfg = &cfg
if iface == "" {
iface = "auto"
}
p.preRun()
if ir := runningIface(s); ir != nil {
p.runningIface = ir.Name
p.requiredMultiNICsConfig = ir.All
}
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
logger.Error().Msg("Deactivation pin check failed")
os.Exit(deactivationPinInvalidExitCode)
}
logger.Debug().Msg("Starting service uninstall")
uninstall(p, s)
if cleanup {
logger.Debug().Msg("Performing cleanup operations")
var files []string
// Config file.
files = append(files, v.ConfigFileUsed())
// Log file and backup log file.
// For safety, only process if log file path is absolute.
if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) {
files = append(files, logFile)
oldLogFile := logFile + oldLogSuffix
if _, err := os.Stat(oldLogFile); err == nil {
files = append(files, oldLogFile)
}
}
// Socket files.
if dir, _ := socketDir(); dir != "" {
files = append(files, filepath.Join(dir, ctrldControlUnixSock))
files = append(files, filepath.Join(dir, ctrldLogUnixSock))
}
// Static DNS settings files.
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
file := ctrld.SavedStaticDnsSettingsFilePath(i)
files = append(files, file)
return nil
})
bin, err := os.Executable()
if err != nil {
logger.Warn().Err(err).Msg("Failed to get executable path")
}
if bin != "" && supportedSelfDelete {
files = append(files, bin)
}
// Backup file after upgrading.
oldBin := bin + oldBinSuffix
if _, err := os.Stat(oldBin); err == nil {
files = append(files, oldBin)
}
for _, file := range files {
if file == "" {
continue
}
if err := os.Remove(file); err == nil {
logger.Notice().Str("file", file).Msg("File removed during cleanup")
} else {
logger.Debug().Err(err).Str("file", file).Msg("Failed to remove file during cleanup")
}
}
// Self-delete the ctrld binary if supported
if err := selfDeleteExe(); err != nil {
logger.Warn().Err(err).Msg("Failed to delete ctrld binary")
} else {
if !supportedSelfDelete {
logger.Debug().Msgf("File removed: %s", bin)
}
}
logger.Debug().Msg("Cleanup operations completed")
}
logger.Debug().Msg("Service uninstall command completed")
return nil
}

View File

@@ -1,197 +0,0 @@
package cli
import (
"bytes"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestBasicCommandStructure tests the actual root command structure
func TestBasicCommandStructure(t *testing.T) {
// Test the actual root command that's returned from initCLI()
rootCmd := initCLI()
// Test that root command has basic properties
assert.Equal(t, "ctrld", rootCmd.Use)
assert.NotEmpty(t, rootCmd.Short, "Root command should have a short description")
// Test that root command has subcommands
commands := rootCmd.Commands()
assert.NotNil(t, commands, "Root command should have subcommands")
assert.Greater(t, len(commands), 0, "Root command should have at least one subcommand")
// Test that expected commands exist
expectedCommands := []string{"run", "service", "clients", "upgrade", "log"}
for _, cmdName := range expectedCommands {
found := false
for _, cmd := range commands {
if cmd.Name() == cmdName {
found = true
break
}
}
assert.True(t, found, "Expected command %s not found in root command", cmdName)
}
}
// TestServiceCommandCreation tests service command creation
func TestServiceCommandCreation(t *testing.T) {
sc := NewServiceCommand()
require.NotNil(t, sc, "ServiceCommand should be created")
// Test service config creation
config := sc.createServiceConfig()
require.NotNil(t, config, "Service config should be created")
assert.Equal(t, ctrldServiceName, config.Name)
assert.Equal(t, "Control-D Helper Service", config.DisplayName)
assert.Equal(t, "A highly configurable, multi-protocol DNS forwarding proxy", config.Description)
}
// TestServiceCommandSubCommands tests service command sub commands
func TestServiceCommandSubCommands(t *testing.T) {
rootCmd := &cobra.Command{
Use: "ctrld",
Short: "DNS forwarding proxy",
}
serviceCmd := InitServiceCmd(rootCmd)
require.NotNil(t, serviceCmd, "Service command should be created")
// Test that service command has subcommands
subcommands := serviceCmd.Commands()
assert.Greater(t, len(subcommands), 0, "Service command should have subcommands")
// Test specific subcommands exist
expectedCommands := []string{"start", "stop", "restart", "reload", "status", "uninstall", "interfaces"}
for _, cmdName := range expectedCommands {
found := false
for _, cmd := range subcommands {
if cmd.Name() == cmdName {
found = true
break
}
}
assert.True(t, found, "Expected service subcommand %s not found", cmdName)
}
}
// TestCommandHelp tests basic help functionality
func TestCommandHelp(t *testing.T) {
// Initialize the CLI to set up the root command
rootCmd := initCLI()
// Test help command execution
var buf bytes.Buffer
rootCmd.SetOut(&buf)
rootCmd.SetErr(&buf)
rootCmd.SetArgs([]string{"--help"})
err := rootCmd.Execute()
assert.NoError(t, err, "Help command should execute without error")
assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description")
}
// TestCommandVersion tests version command
func TestCommandVersion(t *testing.T) {
// Initialize the CLI to set up the root command
rootCmd := initCLI()
var buf bytes.Buffer
rootCmd.SetOut(&buf)
rootCmd.SetErr(&buf)
// Test version command
rootCmd.SetArgs([]string{"--version"})
err := rootCmd.Execute()
assert.NoError(t, err, "Version command should execute without error")
assert.Contains(t, buf.String(), "version", "Version output should contain version information")
}
// TestCommandErrorHandling tests error handling
func TestCommandErrorHandling(t *testing.T) {
// Initialize the CLI to set up the root command
rootCmd := initCLI()
// Test invalid flag instead of invalid command
rootCmd.SetArgs([]string{"--invalid-flag"})
err := rootCmd.Execute()
assert.Error(t, err, "Invalid flag should return error")
}
// TestCommandFlags tests flag functionality
func TestCommandFlags(t *testing.T) {
// Initialize the CLI to set up the root command
rootCmd := initCLI()
// Test that root command has expected flags
verboseFlag := rootCmd.PersistentFlags().Lookup("verbose")
assert.NotNil(t, verboseFlag, "Verbose flag should exist")
assert.Equal(t, "v", verboseFlag.Shorthand)
silentFlag := rootCmd.PersistentFlags().Lookup("silent")
assert.NotNil(t, silentFlag, "Silent flag should exist")
assert.Equal(t, "s", silentFlag.Shorthand)
}
// TestCommandExecution tests basic command execution
func TestCommandExecution(t *testing.T) {
// Initialize the CLI to set up the root command
rootCmd := initCLI()
// Test that root command can be executed (help command)
var buf bytes.Buffer
rootCmd.SetOut(&buf)
rootCmd.SetErr(&buf)
rootCmd.SetArgs([]string{"--help"})
err := rootCmd.Execute()
assert.NoError(t, err, "Root command should execute without error")
assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description")
}
// TestCommandArgs tests argument handling
func TestCommandArgs(t *testing.T) {
// Initialize the CLI to set up the root command
rootCmd := initCLI()
// Test that root command can handle arguments properly
// Test with no args (should succeed)
err := rootCmd.Execute()
assert.NoError(t, err, "Root command with no args should execute")
// Test with help flag (should succeed)
rootCmd.SetArgs([]string{"--help"})
err = rootCmd.Execute()
assert.NoError(t, err, "Root command with help flag should execute")
}
// TestCommandSubcommands tests subcommand functionality
func TestCommandSubcommands(t *testing.T) {
// Initialize the CLI to set up the root command
rootCmd := initCLI()
// Test that root command has subcommands
commands := rootCmd.Commands()
assert.Greater(t, len(commands), 0, "Root command should have subcommands")
// Test that specific subcommands exist and can be executed
expectedSubcommands := []string{"run", "service", "clients", "upgrade", "log"}
for _, subCmdName := range expectedSubcommands {
// Find the subcommand
var subCmd *cobra.Command
for _, cmd := range commands {
if cmd.Name() == subCmdName {
subCmd = cmd
break
}
}
assert.NotNil(t, subCmd, "Subcommand %s should exist", subCmdName)
// Test that subcommand has help
assert.NotEmpty(t, subCmd.Short, "Subcommand %s should have a short description", subCmdName)
}
}

View File

@@ -1,192 +0,0 @@
package cli
import (
"context"
"errors"
"net/http"
"os"
"os/exec"
"strings"
"time"
"github.com/kardianos/service"
"github.com/minio/selfupdate"
"github.com/spf13/cobra"
)
const (
upgradeChannelDev = "dev"
upgradeChannelProd = "prod"
upgradeChannelDefault = "default"
)
// UpgradeCommand handles upgrade-related operations
type UpgradeCommand struct {
}
// NewUpgradeCommand creates a new upgrade command handler
func NewUpgradeCommand() (*UpgradeCommand, error) {
return &UpgradeCommand{}, nil
}
// Upgrade performs the upgrade operation
func (uc *UpgradeCommand) Upgrade(cmd *cobra.Command, args []string) error {
upgradeChannel := map[string]string{
upgradeChannelDefault: "https://dl.controld.dev",
upgradeChannelDev: "https://dl.controld.dev",
upgradeChannelProd: "https://dl.controld.com",
}
if isStableVersion(curVersion()) {
upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd]
}
bin, err := os.Executable()
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("Failed to get current ctrld binary path")
}
readConfig(false)
v.Unmarshal(&cfg)
svcCmd := NewServiceCommand()
s, p, err := svcCmd.initializeServiceManager()
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return nil
}
if iface == "" {
iface = "auto"
}
p.preRun()
if ir := runningIface(s); ir != nil {
p.runningIface = ir.Name
p.requiredMultiNICsConfig = ir.All
}
svcInstalled := true
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
svcInstalled = false
}
oldBin := bin + oldBinSuffix
baseUrl := upgradeChannel[upgradeChannelDefault]
if len(args) > 0 {
channel := args[0]
switch channel {
case upgradeChannelProd, upgradeChannelDev: // ok
default:
mainLog.Load().Fatal().Msgf("Upgrade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev)
}
baseUrl = upgradeChannel[channel]
}
dlUrl := upgradeUrl(baseUrl)
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl)
resp, err := getWithRetry(dlUrl, downloadServerIp)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("Failed to download binary")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
mainLog.Load().Fatal().Msgf("Could not download binary: %s", http.StatusText(resp.StatusCode))
}
mainLog.Load().Debug().Msg("Updating current binary")
if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil {
if rerr := selfupdate.RollbackError(err); rerr != nil {
mainLog.Load().Error().Err(rerr).Msg("Could not rollback old binary")
}
mainLog.Load().Fatal().Err(err).Msg("Failed to update current binary")
}
doRestart := func() bool {
if !svcInstalled {
return true
}
tasks := []task{
{s.Stop, true, "Stop"},
{func() error {
// restore static DNS settings or DHCP
p.resetDNS(false, true)
return nil
}, false, "Cleanup"},
{func() error {
time.Sleep(time.Second * 1)
return nil
}, false, "Waiting for service to stop"},
}
doTasks(tasks)
tasks = []task{
{s.Start, true, "Start"},
}
if doTasks(tasks) {
if dir, err := socketDir(); err == nil {
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
_, _ = cc.post(ifacePath, nil)
return true
}
}
}
return false
}
if svcInstalled {
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
}
if doRestart() {
_ = os.Remove(oldBin)
_ = os.Chmod(bin, 0755)
ver := "unknown version"
out, err := exec.Command(bin, "--version").CombinedOutput()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version")
}
if after, found := strings.CutPrefix(string(out), "ctrld version "); found {
ver = after
}
mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver)
return nil
}
mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin)
if err := os.Remove(bin); err != nil {
mainLog.Load().Fatal().Err(err).Msg("Failed to remove new binary")
}
if err := os.Rename(oldBin, bin); err != nil {
mainLog.Load().Fatal().Err(err).Msg("Failed to restore old binary")
}
if doRestart() {
mainLog.Load().Notice().Msg("Restored previous binary successfully")
return nil
}
return nil
}
// InitUpgradeCmd creates the upgrade command with proper logic
func InitUpgradeCmd(rootCmd *cobra.Command) *cobra.Command {
upgradeCmd := &cobra.Command{
Use: "upgrade",
Short: "Upgrading ctrld to latest version",
ValidArgs: []string{upgradeChannelDev, upgradeChannelProd},
Args: cobra.MaximumNArgs(1),
PreRun: func(cmd *cobra.Command, args []string) {
checkHasElevatedPrivilege()
},
RunE: func(cmd *cobra.Command, args []string) error {
uc, err := NewUpgradeCommand()
if err != nil {
return err
}
return uc.Upgrade(cmd, args)
},
}
rootCmd.AddCommand(upgradeCmd)
return upgradeCmd
}

51
cmd/cli/conn.go Normal file
View File

@@ -0,0 +1,51 @@
package cli
import (
"net"
"time"
)
// logConn wraps a net.Conn, override the Write behavior.
// runCmd uses this wrapper, so as long as startCmd finished,
// ctrld log won't be flushed with un-necessary write errors.
type logConn struct {
conn net.Conn
}
func (lc *logConn) Read(b []byte) (n int, err error) {
return lc.conn.Read(b)
}
func (lc *logConn) Close() error {
return lc.conn.Close()
}
func (lc *logConn) LocalAddr() net.Addr {
return lc.conn.LocalAddr()
}
func (lc *logConn) RemoteAddr() net.Addr {
return lc.conn.RemoteAddr()
}
func (lc *logConn) SetDeadline(t time.Time) error {
return lc.conn.SetDeadline(t)
}
func (lc *logConn) SetReadDeadline(t time.Time) error {
return lc.conn.SetReadDeadline(t)
}
func (lc *logConn) SetWriteDeadline(t time.Time) error {
return lc.conn.SetWriteDeadline(t)
}
func (lc *logConn) Write(b []byte) (int, error) {
// Write performs writes with underlying net.Conn, ignore any errors happen.
// "ctrld run" command use this wrapper to report errors to "ctrld start".
// If no error occurred, "ctrld start" may finish before "ctrld run" attempt
// to close the connection, so ignore errors conservatively here, prevent
// un-necessary error "write to closed connection" flushed to ctrld log.
_, _ = lc.conn.Write(b)
return len(b), nil
}

View File

@@ -8,12 +8,10 @@ import (
"time"
)
// controlClient represents an HTTP client for communicating with the control server
type controlClient struct {
c *http.Client
}
// newControlClient creates a new control client with Unix socket transport
func newControlClient(addr string) *controlClient {
return &controlClient{c: &http.Client{
Transport: &http.Transport{

View File

@@ -37,14 +37,12 @@ type ifaceResponse struct {
OK bool `json:"ok"`
}
// controlServer represents an HTTP server for handling control requests
type controlServer struct {
server *http.Server
mux *http.ServeMux
addr string
}
// newControlServer creates a new control server instance
func newControlServer(addr string) (*controlServer, error) {
mux := http.NewServeMux()
s := &controlServer{
@@ -81,34 +79,34 @@ func (s *controlServer) register(pattern string, handler http.Handler) {
func (p *prog) registerControlServerHandler() {
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
p.Debug().Msg("Handling list clients request")
mainLog.Load().Debug().Msg("handling list clients request")
clients := p.ciTable.ListClients()
p.Debug().Int("client_count", len(clients)).Msg("Retrieved clients list")
mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
sort.Slice(clients, func(i, j int) bool {
return clients[i].IP.Less(clients[j].IP)
})
p.Debug().Msg("Sorted clients by IP address")
mainLog.Load().Debug().Msg("sorted clients by IP address")
if p.metricsQueryStats.Load() {
p.Debug().Msg("Metrics query stats enabled, collecting query counts")
mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts")
for idx, client := range clients {
p.Debug().
mainLog.Load().Debug().
Int("index", idx).
Str("ip", client.IP.String()).
Str("mac", client.Mac).
Str("hostname", client.Hostname).
Msg("Processing client metrics")
Msg("processing client metrics")
client.IncludeQueryCount = true
dm := &dto.Metric{}
if statsClientQueriesCount.MetricVec == nil {
p.Debug().
mainLog.Load().Debug().
Str("client_ip", client.IP.String()).
Msg("Skipping metrics collection: MetricVec is nil")
Msg("skipping metrics collection: MetricVec is nil")
continue
}
@@ -118,44 +116,44 @@ func (p *prog) registerControlServerHandler() {
client.Hostname,
)
if err != nil {
p.Debug().
mainLog.Load().Debug().
Err(err).
Str("client_ip", client.IP.String()).
Str("mac", client.Mac).
Str("hostname", client.Hostname).
Msg("Failed to get metrics for client")
Msg("failed to get metrics for client")
continue
}
if err := m.Write(dm); err == nil && dm.Counter != nil {
client.QueryCount = int64(dm.Counter.GetValue())
p.Debug().
mainLog.Load().Debug().
Str("client_ip", client.IP.String()).
Int64("query_count", client.QueryCount).
Msg("Successfully collected query count")
Msg("successfully collected query count")
} else if err != nil {
p.Debug().
mainLog.Load().Debug().
Err(err).
Str("client_ip", client.IP.String()).
Msg("Failed to write metric")
Msg("failed to write metric")
}
}
} else {
p.Debug().Msg("Metrics query stats disabled, skipping query counts")
mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts")
}
if err := json.NewEncoder(w).Encode(&clients); err != nil {
p.Error().
mainLog.Load().Error().
Err(err).
Int("client_count", len(clients)).
Msg("Failed to encode clients response")
Msg("failed to encode clients response")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
p.Debug().
mainLog.Load().Debug().
Int("client_count", len(clients)).
Msg("Successfully sent clients list response")
Msg("successfully sent clients list response")
}))
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
select {
@@ -177,14 +175,14 @@ func (p *prog) registerControlServerHandler() {
oldSvc := p.cfg.Service
p.mu.Unlock()
if err := p.sendReloadSignal(); err != nil {
p.Error().Err(err).Msg("Could not send reload signal")
mainLog.Load().Err(err).Msg("could not send reload signal")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
select {
case <-p.reloadDoneCh:
case <-time.After(5 * time.Second):
http.Error(w, "Timeout waiting for ctrld reload", http.StatusInternalServerError)
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
return
}
@@ -218,16 +216,20 @@ func (p *prog) registerControlServerHandler() {
return
}
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
// Re-fetch pin code from API.
if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev); rc != nil {
rcReq := &controld.ResolverConfigRequest{
RawUID: cdUID,
Version: rootCmd.Version,
Metadata: ctrld.SystemMetadata(context.Background()),
}
if rc, err := controld.FetchResolverConfig(rcReq, cdDev); rc != nil {
if rc.DeactivationPin != nil {
cdDeactivationPin.Store(*rc.DeactivationPin)
} else {
cdDeactivationPin.Store(defaultDeactivationPin)
}
} else {
p.Warn().Err(err).Msg("Could not re-fetch deactivation pin code")
mainLog.Load().Warn().Err(err).Msg("could not re-fetch deactivation pin code")
}
// If pin code not set, allowing deactivation.
@@ -239,7 +241,7 @@ func (p *prog) registerControlServerHandler() {
var req deactivationRequest
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusPreconditionFailed)
p.Error().Err(err).Msg("Invalid deactivation request")
mainLog.Load().Err(err).Msg("invalid deactivation request")
return
}
@@ -283,7 +285,7 @@ func (p *prog) registerControlServerHandler() {
}
}))
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
lr, err := p.logReaderRaw()
lr, err := p.logReader()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@@ -309,7 +311,7 @@ func (p *prog) registerControlServerHandler() {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
r, err := p.logReaderNoColor()
r, err := p.logReader()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@@ -322,15 +324,14 @@ func (p *prog) registerControlServerHandler() {
UID: cdUID,
Data: r.r,
}
p.Debug().Msg("Sending log file to ControlD server")
mainLog.Load().Debug().Msg("sending log file to ControlD server")
resp := logSentResponse{Size: r.size}
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil {
p.Error().Msgf("Could not send log file to ControlD server: %v", err)
if err := controld.SendLogs(req, cdDev); err != nil {
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
resp.Error = err.Error()
w.WriteHeader(http.StatusInternalServerError)
} else {
p.Debug().Msg("Sending log file successfully")
mainLog.Load().Debug().Msg("sending log file successfully")
w.WriteHeader(http.StatusOK)
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
@@ -340,7 +341,6 @@ func (p *prog) registerControlServerHandler() {
}))
}
// jsonResponse wraps an HTTP handler to set JSON content type
func jsonResponse(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")

4
cmd/cli/dns.go Normal file
View File

@@ -0,0 +1,4 @@
package cli
//lint:ignore U1000 use in os_linux.go
type getDNS func(iface string) []string

File diff suppressed because it is too large Load Diff

View File

@@ -77,8 +77,7 @@ func Test_prog_upstreamFor(t *testing.T) {
cfg := testhelper.SampleConfig(t)
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
p := &prog{cfg: cfg}
p.logger.Store(mainLog.Load())
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
p.um = newUpstreamMonitor(p.cfg)
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
for _, nc := range p.cfg.Network {
@@ -143,94 +142,9 @@ func Test_prog_upstreamFor(t *testing.T) {
}
}
func Test_prog_upstreamForWithCustomMatching(t *testing.T) {
cfg := testhelper.SampleConfig(t)
prog := &prog{cfg: cfg}
prog.logger.Store(mainLog.Load())
for _, nc := range prog.cfg.Network {
for _, cidr := range nc.Cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
t.Fatal(err)
}
nc.IPNets = append(nc.IPNets, ipNet)
}
}
// Create a custom policy with domain-first matching order
customPolicy := &ctrld.ListenerPolicyConfig{
Name: "Custom Policy",
Networks: []ctrld.Rule{
{"network.0": []string{"upstream.1", "upstream.0"}},
},
Macs: []ctrld.Rule{
{"14:45:A0:67:83:0A": []string{"upstream.2"}},
},
Rules: []ctrld.Rule{
{"*.ru": []string{"upstream.1"}},
},
Matching: &ctrld.MatchingConfig{
Order: []string{"domain", "mac", "network"},
},
}
customListener := &ctrld.ListenerConfig{
Policy: customPolicy,
}
tests := []struct {
name string
ip string
mac string
domain string
upstreams []string
matched bool
}{
{
name: "Domain rule should match first with custom order",
ip: "192.168.0.1:0",
mac: "14:45:A0:67:83:0A",
domain: "example.ru",
upstreams: []string{"upstream.1"},
matched: true,
},
{
name: "MAC rule should match when no domain rule",
ip: "192.168.0.1:0",
mac: "14:45:A0:67:83:0A",
domain: "example.com",
upstreams: []string{"upstream.2"},
matched: true,
},
{
name: "Network rule should match when no domain or MAC rule",
ip: "192.168.0.1:0",
mac: "00:11:22:33:44:55",
domain: "example.com",
upstreams: []string{"upstream.1", "upstream.0"},
matched: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
addr, err := net.ResolveUDPAddr("udp", tc.ip)
require.NoError(t, err)
require.NotNil(t, addr)
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
ufr := prog.upstreamFor(ctx, "0", customListener, addr, tc.mac, tc.domain)
assert.Equal(t, tc.matched, ufr.matched)
assert.Equal(t, tc.upstreams, ufr.upstreams)
})
}
}
func TestCache(t *testing.T) {
cfg := testhelper.SampleConfig(t)
prog := &prog{cfg: cfg}
prog.logger.Store(mainLog.Load())
for _, nc := range prog.cfg.Network {
for _, cidr := range nc.Cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
@@ -550,254 +464,3 @@ func Test_isWanClient(t *testing.T) {
})
}
}
func Test_shouldStartRecovery(t *testing.T) {
tests := []struct {
name string
reason RecoveryReason
hasExistingRecovery bool
expectedResult bool
description string
}{
{
name: "network change with existing recovery",
reason: RecoveryReasonNetworkChange,
hasExistingRecovery: true,
expectedResult: true,
description: "should cancel existing recovery and start new one for network change",
},
{
name: "network change without existing recovery",
reason: RecoveryReasonNetworkChange,
hasExistingRecovery: false,
expectedResult: true,
description: "should start new recovery for network change",
},
{
name: "regular failure with existing recovery",
reason: RecoveryReasonRegularFailure,
hasExistingRecovery: true,
expectedResult: false,
description: "should skip duplicate recovery for regular failure",
},
{
name: "regular failure without existing recovery",
reason: RecoveryReasonRegularFailure,
hasExistingRecovery: false,
expectedResult: true,
description: "should start new recovery for regular failure",
},
{
name: "OS failure with existing recovery",
reason: RecoveryReasonOSFailure,
hasExistingRecovery: true,
expectedResult: false,
description: "should skip duplicate recovery for OS failure",
},
{
name: "OS failure without existing recovery",
reason: RecoveryReasonOSFailure,
hasExistingRecovery: false,
expectedResult: true,
description: "should start new recovery for OS failure",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
p := newTestProg(t)
// Setup existing recovery if needed
if tc.hasExistingRecovery {
p.recoveryCancelMu.Lock()
p.recoveryCancel = func() {} // Mock cancel function
p.recoveryCancelMu.Unlock()
}
result := p.shouldStartRecovery(tc.reason)
assert.Equal(t, tc.expectedResult, result, tc.description)
})
}
}
func Test_createRecoveryContext(t *testing.T) {
p := newTestProg(t)
ctx, cleanup := p.createRecoveryContext()
// Verify context is created
assert.NotNil(t, ctx)
assert.NotNil(t, cleanup)
// Verify recoveryCancel is set
p.recoveryCancelMu.Lock()
assert.NotNil(t, p.recoveryCancel)
p.recoveryCancelMu.Unlock()
// Test cleanup function
cleanup()
// Verify recoveryCancel is cleared
p.recoveryCancelMu.Lock()
assert.Nil(t, p.recoveryCancel)
p.recoveryCancelMu.Unlock()
}
func Test_prepareForRecovery(t *testing.T) {
tests := []struct {
name string
reason RecoveryReason
wantErr bool
}{
{
name: "regular failure",
reason: RecoveryReasonRegularFailure,
wantErr: false,
},
{
name: "network change",
reason: RecoveryReasonNetworkChange,
wantErr: false,
},
{
name: "OS failure",
reason: RecoveryReasonOSFailure,
wantErr: false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
p := newTestProg(t)
err := p.prepareForRecovery(tc.reason)
if tc.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// Verify recoveryRunning is set to true
assert.True(t, p.recoveryRunning.Load())
})
}
}
func Test_completeRecovery(t *testing.T) {
tests := []struct {
name string
reason RecoveryReason
recovered string
wantErr bool
}{
{
name: "regular failure recovery",
reason: RecoveryReasonRegularFailure,
recovered: "upstream1",
wantErr: false,
},
{
name: "network change recovery",
reason: RecoveryReasonNetworkChange,
recovered: "upstream2",
wantErr: false,
},
{
name: "OS failure recovery",
reason: RecoveryReasonOSFailure,
recovered: "upstream3",
wantErr: false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
p := newTestProg(t)
err := p.completeRecovery(tc.reason, tc.recovered)
if tc.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// Verify recoveryRunning is set to false
assert.False(t, p.recoveryRunning.Load())
})
}
}
func Test_reinitializeOSResolver(t *testing.T) {
p := newTestProg(t)
err := p.reinitializeOSResolver("Test message")
// This function should not return an error under normal circumstances
// The actual behavior depends on the OS resolver implementation
assert.NoError(t, err)
}
func Test_handleRecovery_Integration(t *testing.T) {
tests := []struct {
name string
reason RecoveryReason
wantErr bool
}{
{
name: "network change recovery",
reason: RecoveryReasonNetworkChange,
wantErr: false,
},
{
name: "regular failure recovery",
reason: RecoveryReasonRegularFailure,
wantErr: false,
},
{
name: "OS failure recovery",
reason: RecoveryReasonOSFailure,
wantErr: false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
p := newTestProg(t)
// This is an integration test that exercises the full recovery flow
// In a real test environment, you would mock the dependencies
// For now, we're just testing that the method doesn't panic
// and that the recovery logic flows correctly
assert.NotPanics(t, func() {
// Test only the preparation phase to avoid actual upstream checking
if !p.shouldStartRecovery(tc.reason) {
return
}
_, cleanup := p.createRecoveryContext()
defer cleanup()
if err := p.prepareForRecovery(tc.reason); err != nil {
return
}
// Skip the actual upstream recovery check for this test
// as it requires properly configured upstreams
})
})
}
}
// newTestProg creates a properly initialized *prog for testing.
func newTestProg(t *testing.T) *prog {
p := &prog{cfg: testhelper.SampleConfig(t)}
p.logger.Store(mainLog.Load())
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
return p
}

View File

@@ -4,15 +4,11 @@ import "regexp"
// validHostname reports whether hostname is a valid hostname.
// A valid hostname contains 3 -> 64 characters and conform to RFC1123.
// This function validates hostnames to ensure they meet DNS naming standards
// and prevents invalid hostnames from being used in DNS configurations
func validHostname(hostname string) bool {
hostnameLen := len(hostname)
if hostnameLen < 3 || hostnameLen > 64 {
return false
}
// RFC1123 regex pattern ensures hostnames follow DNS naming conventions
// This prevents issues with DNS resolution and system compatibility
validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
return validHostnameRfc1123.MatchString(hostname)
}

View File

@@ -1,172 +0,0 @@
package cli
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"sync"
)
// HTTP log server endpoint constants
const (
httpLogEndpointPing = "/ping"
httpLogEndpointLogs = "/logs"
httpLogEndpointExit = "/exit"
)
// httpLogClient sends logs to an HTTP server via POST requests.
// This replaces the logConn functionality with HTTP-based communication.
type httpLogClient struct {
baseURL string
client *http.Client
}
// newHTTPLogClient creates a new HTTP log client
func newHTTPLogClient(sockPath string) *httpLogClient {
return &httpLogClient{
baseURL: "http://unix",
client: &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
},
}
}
// Write sends log data to the HTTP server via POST request
func (hlc *httpLogClient) Write(b []byte) (int, error) {
// Send log data via HTTP POST to /logs endpoint
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointLogs, "text/plain", bytes.NewReader(b))
if err != nil {
// Ignore errors to prevent log pollution, just like the original logConn
return len(b), nil
}
resp.Body.Close()
return len(b), nil
}
// Ping tests if the HTTP log server is available
func (hlc *httpLogClient) Ping() error {
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointPing)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// Close sends exit signal to the HTTP server
func (hlc *httpLogClient) Close() error {
// Send exit signal via HTTP POST with empty body
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
if err != nil {
return err
}
resp.Body.Close()
return nil
}
// GetLogs retrieves all collected logs from the HTTP server
func (hlc *httpLogClient) GetLogs() ([]byte, error) {
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointLogs)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNoContent {
return []byte{}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// httpLogServer starts an HTTP server listening on unix socket to collect logs from runCmd.
func httpLogServer(sockPath string, stopLogCh chan struct{}) error {
addr, err := net.ResolveUnixAddr("unix", sockPath)
if err != nil {
return fmt.Errorf("invalid log sock path: %w", err)
}
ln, err := net.ListenUnix("unix", addr)
if err != nil {
return fmt.Errorf("could not listen log socket: %w", err)
}
defer ln.Close()
// Create a log writer to store all logs
logWriter := newLogWriter()
// Use a sync.Once to ensure channel is only closed once
var channelClosed sync.Once
mux := http.NewServeMux()
mux.HandleFunc(httpLogEndpointPing, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc(httpLogEndpointLogs, func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
// POST /logs - Store log data
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
// Store log data in log writer
logWriter.Write(body)
w.WriteHeader(http.StatusOK)
case http.MethodGet:
// GET /logs - Retrieve all logs
// Get all logs from the log writer
logWriter.mu.Lock()
logs := logWriter.buf.Bytes()
logWriter.mu.Unlock()
if len(logs) == 0 {
w.WriteHeader(http.StatusNoContent)
return
}
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write(logs)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
})
mux.HandleFunc(httpLogEndpointExit, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Close the stop channel to signal completion (only once)
channelClosed.Do(func() {
close(stopLogCh)
})
w.WriteHeader(http.StatusOK)
})
server := &http.Server{Handler: mux}
return server.Serve(ln)
}

View File

@@ -1,747 +0,0 @@
package cli
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"
"golang.org/x/net/nettest"
)
func unixDomainSocketPath(t *testing.T) string {
t.Helper()
sockPath, err := nettest.LocalPath()
if err != nil {
t.Fatalf("Failed to create temporary directory: %v", err)
}
return sockPath
}
func TestHTTPLogServer(t *testing.T) {
sockPath := unixDomainSocketPath(t)
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait a bit for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
t.Run("Ping endpoint", func(t *testing.T) {
resp, err := client.Get("http://unix" + httpLogEndpointPing)
if err != nil {
t.Fatalf("Failed to ping server: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
})
t.Run("Ping endpoint wrong method", func(t *testing.T) {
resp, err := client.Post("http://unix"+httpLogEndpointPing, "text/plain", bytes.NewReader([]byte("test")))
if err != nil {
t.Fatalf("Failed to send POST to ping: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405, got %d", resp.StatusCode)
}
})
t.Run("Log endpoint", func(t *testing.T) {
testLog := "test log message"
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(testLog)))
if err != nil {
t.Fatalf("Failed to send log: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check if log was stored by retrieving it
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer logsResp.Body.Close()
if logsResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
}
body, err := io.ReadAll(logsResp.Body)
if err != nil {
t.Fatalf("Failed to read logs: %v", err)
}
if !strings.Contains(string(body), testLog) {
t.Errorf("Expected log '%s' not found in stored logs", testLog)
}
})
t.Run("Log endpoint wrong method", func(t *testing.T) {
// Test unsupported method (PUT) on /logs endpoint
req, err := http.NewRequest("PUT", "http://unix"+httpLogEndpointLogs, bytes.NewReader([]byte("test")))
if err != nil {
t.Fatalf("Failed to create PUT request: %v", err)
}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Failed to send PUT to logs: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405, got %d", resp.StatusCode)
}
})
t.Run("Exit endpoint", func(t *testing.T) {
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
if err != nil {
t.Fatalf("Failed to send exit: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check if channel is closed by trying to read from it
select {
case _, ok := <-stopLogCh:
if ok {
t.Error("Expected channel to be closed, but it's still open")
}
case <-time.After(1 * time.Second):
t.Error("Timeout waiting for channel closure")
}
})
t.Run("Exit endpoint wrong method", func(t *testing.T) {
resp, err := client.Get("http://unix" + httpLogEndpointExit)
if err != nil {
t.Fatalf("Failed to send GET to exit: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected status 405, got %d", resp.StatusCode)
}
})
t.Run("Multiple log messages", func(t *testing.T) {
logs := []string{"log1", "log2", "log3"}
for _, log := range logs {
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
if err != nil {
t.Fatalf("Failed to send log '%s': %v", log, err)
}
resp.Body.Close()
}
// Check if all logs were stored by retrieving them
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer logsResp.Body.Close()
if logsResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
}
body, err := io.ReadAll(logsResp.Body)
if err != nil {
t.Fatalf("Failed to read logs: %v", err)
}
logContent := string(body)
for i, expectedLog := range logs {
if !strings.Contains(logContent, expectedLog) {
t.Errorf("Log %d: expected '%s' not found in stored logs", i, expectedLog)
}
}
})
t.Run("Large log message", func(t *testing.T) {
largeLog := strings.Repeat("a", 1024*10) // 10KB log message
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(largeLog)))
if err != nil {
t.Fatalf("Failed to send large log: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check if large log was stored by retrieving it
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer logsResp.Body.Close()
if logsResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
}
body, err := io.ReadAll(logsResp.Body)
if err != nil {
t.Fatalf("Failed to read logs: %v", err)
}
if !strings.Contains(string(body), largeLog) {
t.Error("Large log message was not stored correctly")
}
})
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogServerInvalidSocketPath(t *testing.T) {
// Test with invalid socket path
invalidPath := "/invalid/path/that/does/not/exist.sock"
stopLogCh := make(chan struct{})
err := httpLogServer(invalidPath, stopLogCh)
if err == nil {
t.Error("Expected error for invalid socket path")
}
if !strings.Contains(err.Error(), "could not listen log socket") {
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
}
}
func TestHTTPLogServerSocketInUse(t *testing.T) {
// Create a temporary socket path
sockPath := unixDomainSocketPath(t)
defer os.Remove(sockPath)
// Create the first server
stopLogCh1 := make(chan struct{})
serverErr1 := make(chan error, 1)
go func() {
serverErr1 <- httpLogServer(sockPath, stopLogCh1)
}()
// Wait for first server to start
time.Sleep(100 * time.Millisecond)
// Try to create a second server on the same socket
stopLogCh2 := make(chan struct{})
err := httpLogServer(sockPath, stopLogCh2)
if err == nil {
t.Error("Expected error when socket is already in use")
}
if !strings.Contains(err.Error(), "could not listen log socket") {
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
}
}
func TestHTTPLogServerConcurrentRequests(t *testing.T) {
// Create a temporary socket path
sockPath := unixDomainSocketPath(t)
defer os.Remove(sockPath)
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
// Send concurrent requests
numRequests := 10
done := make(chan bool, numRequests)
for i := 0; i < numRequests; i++ {
go func(i int) {
defer func() { done <- true }()
logMsg := fmt.Sprintf("concurrent log %d", i)
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
if err != nil {
t.Errorf("Failed to send concurrent log %d: %v", i, err)
return
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for request %d, got %d", i, resp.StatusCode)
}
}(i)
}
// Wait for all requests to complete
for i := 0; i < numRequests; i++ {
select {
case <-done:
// Request completed
case <-time.After(5 * time.Second):
t.Errorf("Timeout waiting for concurrent request %d", i)
}
}
// Check if all logs were stored by retrieving them
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer logsResp.Body.Close()
if logsResp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
}
body, err := io.ReadAll(logsResp.Body)
if err != nil {
t.Fatalf("Failed to read logs: %v", err)
}
logContent := string(body)
// Verify all logs were stored
for i := 0; i < numRequests; i++ {
expectedLog := fmt.Sprintf("concurrent log %d", i)
if !strings.Contains(logContent, expectedLog) {
t.Errorf("Log '%s' was not stored", expectedLog)
}
}
}
func TestHTTPLogServerErrorHandling(t *testing.T) {
// Create a temporary socket path
sockPath := unixDomainSocketPath(t)
defer os.Remove(sockPath)
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
t.Run("Invalid request body", func(t *testing.T) {
// Test with malformed request - this will fail at HTTP level, not server level
// The server will return 400 Bad Request for invalid body
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", strings.NewReader(""))
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()
// Empty body should still be processed successfully
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
})
}
func BenchmarkHTTPLogServer(b *testing.B) {
// Create a temporary socket path
tmpDir := b.TempDir()
sockPath := filepath.Join(tmpDir, "bench.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
go func() {
httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
// Benchmark log sending
b.ResetTimer()
for i := 0; i < b.N; i++ {
logMsg := fmt.Sprintf("benchmark log %d", i)
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
if err != nil {
b.Fatalf("Failed to send log: %v", err)
}
resp.Body.Close()
}
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogClient(t *testing.T) {
// Create a temporary socket path
sockPath := unixDomainSocketPath(t)
defer os.Remove(sockPath)
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP log client
client := newHTTPLogClient(sockPath)
t.Run("Ping server", func(t *testing.T) {
err := client.Ping()
if err != nil {
t.Errorf("Ping failed: %v", err)
}
})
t.Run("Write logs", func(t *testing.T) {
testLog := "test log message from client"
n, err := client.Write([]byte(testLog))
if err != nil {
t.Errorf("Write failed: %v", err)
}
if n != len(testLog) {
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
}
// Check if log was stored by retrieving it
logs, err := client.GetLogs()
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
if !strings.Contains(string(logs), testLog) {
t.Errorf("Expected log '%s' not found in stored logs", testLog)
}
})
t.Run("Close client", func(t *testing.T) {
err := client.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
// Check if channel is closed (signaling completion)
select {
case _, ok := <-stopLogCh:
if ok {
t.Error("Expected channel to be closed, but it's still open")
}
case <-time.After(1 * time.Second):
t.Error("Timeout waiting for channel closure")
}
})
}
func TestHTTPLogClientServerUnavailable(t *testing.T) {
// Create client with non-existent socket
sockPath := "/non/existent/socket.sock"
client := newHTTPLogClient(sockPath)
t.Run("Ping unavailable server", func(t *testing.T) {
err := client.Ping()
if err == nil {
t.Error("Expected ping to fail for unavailable server")
}
})
t.Run("Write to unavailable server", func(t *testing.T) {
testLog := "test log message"
n, err := client.Write([]byte(testLog))
if err != nil {
t.Errorf("Write should not return error (ignores errors): %v", err)
}
if n != len(testLog) {
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
}
})
t.Run("Close unavailable server", func(t *testing.T) {
err := client.Close()
if err == nil {
t.Error("Expected close to fail for unavailable server")
}
})
}
func BenchmarkHTTPLogClient(b *testing.B) {
// Create a temporary socket path
tmpDir := b.TempDir()
sockPath := filepath.Join(tmpDir, "bench.sock")
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
go func() {
httpLogServer(sockPath, stopLogCh)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP log client
client := newHTTPLogClient(sockPath)
// Benchmark client writes
b.ResetTimer()
for i := 0; i < b.N; i++ {
logMsg := fmt.Sprintf("benchmark write %d", i)
client.Write([]byte(logMsg))
}
// Clean up
os.Remove(sockPath)
}
func TestHTTPLogServerWithLogWriter(t *testing.T) {
// Create a temporary socket path
sockPath := unixDomainSocketPath(t)
defer os.Remove(sockPath)
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
serverErr := make(chan error, 1)
go func() {
serverErr <- httpLogServer(sockPath, stopLogCh)
}()
// Wait a bit for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP client
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath)
},
},
}
t.Run("Store and retrieve logs", func(t *testing.T) {
// Send multiple log messages
logs := []string{"log message 1", "log message 2", "log message 3"}
for _, log := range logs {
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
if err != nil {
t.Fatalf("Failed to send log '%s': %v", log, err)
}
resp.Body.Close()
}
// Retrieve all logs
resp, err := client.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read logs response: %v", err)
}
logContent := string(body)
for _, log := range logs {
if !strings.Contains(logContent, log) {
t.Errorf("Expected log '%s' not found in retrieved logs", log)
}
}
})
t.Run("Empty logs endpoint", func(t *testing.T) {
// Create a new server for this test
sockPath2 := unixDomainSocketPath(t)
stopLogCh2 := make(chan struct{})
go func() {
httpLogServer(sockPath2, stopLogCh2)
}()
time.Sleep(100 * time.Millisecond)
client2 := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", sockPath2)
},
},
}
resp, err := client2.Get("http://unix" + httpLogEndpointLogs)
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
t.Errorf("Expected status 204, got %d", resp.StatusCode)
}
os.Remove(sockPath2)
})
t.Run("Channel closure on exit", func(t *testing.T) {
// Send exit signal
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
if err != nil {
t.Fatalf("Failed to send exit: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
// Check if channel is closed by trying to read from it
select {
case _, ok := <-stopLogCh:
if ok {
t.Error("Expected channel to be closed, but it's still open")
}
case <-time.After(1 * time.Second):
t.Error("Timeout waiting for channel closure")
}
})
}
func TestHTTPLogClientGetLogs(t *testing.T) {
// Create a temporary socket path
sockPath := unixDomainSocketPath(t)
defer os.Remove(sockPath)
// Create log channel
stopLogCh := make(chan struct{})
// Start HTTP log server in a goroutine
go func() {
httpLogServer(sockPath, stopLogCh)
}()
// Wait a bit for server to start
time.Sleep(100 * time.Millisecond)
// Create HTTP log client
client := newHTTPLogClient(sockPath)
t.Run("Get logs from client", func(t *testing.T) {
// Send some logs
testLogs := []string{"client log 1", "client log 2", "client log 3"}
for _, log := range testLogs {
client.Write([]byte(log + "\n"))
}
// Retrieve logs using client method
logs, err := client.GetLogs()
if err != nil {
t.Fatalf("Failed to get logs: %v", err)
}
logContent := string(logs)
for _, log := range testLogs {
if !strings.Contains(logContent, log) {
t.Errorf("Expected log '%s' not found in retrieved logs", log)
}
}
})
t.Run("Get empty logs", func(t *testing.T) {
// Create a new client for empty logs test
sockPath2 := unixDomainSocketPath(t)
stopLogCh2 := make(chan struct{})
go func() {
httpLogServer(sockPath2, stopLogCh2)
}()
time.Sleep(100 * time.Millisecond)
client2 := newHTTPLogClient(sockPath2)
logs, err := client2.GetLogs()
if err != nil {
t.Fatalf("Failed to get empty logs: %v", err)
}
if len(logs) != 0 {
t.Errorf("Expected empty logs, got %d bytes", len(logs))
}
os.Remove(sockPath2)
})
}

View File

@@ -9,7 +9,6 @@ import (
// AppCallback provides hooks for injecting certain functionalities
// from mobile platforms to main ctrld cli.
// This allows mobile applications to customize behavior without modifying core CLI code
type AppCallback struct {
HostName func() string
LanIp func() string
@@ -18,7 +17,6 @@ type AppCallback struct {
}
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
// This provides a clean interface for mobile apps to configure ctrld behavior
type AppConfig struct {
CdUID string
ProvisionID string
@@ -29,29 +27,18 @@ type AppConfig struct {
LogPath string
}
// Network and HTTP configuration constants
const (
// defaultHTTPTimeout provides reasonable timeout for HTTP operations
// This prevents hanging requests while allowing sufficient time for network delays
defaultHTTPTimeout = 30 * time.Second
// defaultMaxRetries provides retry attempts for failed HTTP requests
// This improves reliability in unstable network conditions
defaultMaxRetries = 3
// downloadServerIp is the fallback IP for download operations
// This ensures downloads work even when DNS resolution fails
downloadServerIp = "23.171.240.151"
defaultMaxRetries = 3
downloadServerIp = "23.171.240.151"
)
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
// This ensures reliable HTTP operations by preferring IPv4 and handling timeouts gracefully
func httpClientWithFallback(timeout time.Duration) *http.Client {
return &http.Client{
Timeout: timeout,
Transport: &http.Transport{
// Prefer IPv4 over IPv6
// This improves compatibility with networks that have IPv6 issues
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
@@ -62,7 +49,6 @@ func httpClientWithFallback(timeout time.Duration) *http.Client {
}
// doWithRetry performs an HTTP request with retries
// This improves reliability by automatically retrying failed requests with exponential backoff
func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) {
var lastErr error
client := httpClientWithFallback(defaultHTTPTimeout)
@@ -74,8 +60,7 @@ func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response,
}
for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 {
// Linear backoff reduces server load and improves success rate
time.Sleep(time.Second * time.Duration(attempt+1))
time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff
}
resp, err := client.Do(req)
@@ -83,8 +68,8 @@ func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response,
return resp, nil
}
if ipReq != nil {
mainLog.Load().Warn().Err(err).Msgf("Dial to %q failed", req.Host)
mainLog.Load().Warn().Msgf("Fallback to direct ip to download prod version: %q", ip)
mainLog.Load().Warn().Err(err).Msgf("dial to %q failed", req.Host)
mainLog.Load().Warn().Msgf("fallback to direct IP to download prod version: %q", ip)
resp, err = client.Do(ipReq)
if err == nil {
return resp, nil
@@ -101,7 +86,6 @@ func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response,
}
// Helper for making GET requests with retries
// This provides a simplified interface for common GET operations with built-in retry logic
func getWithRetry(url string, ip string) (*http.Response, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {

View File

@@ -6,105 +6,39 @@ import (
"fmt"
"io"
"os"
"regexp"
"strings"
"sync"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/rs/zerolog"
"github.com/Control-D-Inc/ctrld"
)
// Log writer constants for buffer management and log formatting
const (
// logWriterSize is the default buffer size for log writers
// This provides sufficient space for runtime logs without excessive memory usage
logWriterSize = 1024 * 1024 * 5 // 5 MB
// logWriterSmallSize is used for memory-constrained environments
// This reduces memory footprint while still maintaining log functionality
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
// logWriterInitialSize is the initial buffer allocation
// This provides immediate space for early log entries
logWriterInitialSize = 32 * 1024 // 32 KB
// logWriterSentInterval controls how often logs are sent to external systems
// This balances real-time logging with system performance
logWriterSentInterval = time.Minute
// logWriterInitEndMarker marks the end of initialization logs
// This helps separate startup logs from runtime logs
logWriterSize = 1024 * 1024 * 5 // 5 MB
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
logWriterInitialSize = 32 * 1024 // 32 KB
logWriterSentInterval = time.Minute
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
// logWriterLogEndMarker marks the end of log sections
// This provides clear boundaries for log parsing and analysis
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
)
// Custom level encoders that handle NOTICE level
// Since NOTICE and WARN share the same numeric value (1), we handle them specially
// in the encoder to display NOTICE messages with the "NOTICE" prefix.
// Note: WARN messages will also display as "NOTICE" because they share the same level value.
// This is the intended behavior for visual distinction.
// noticeLevelEncoder provides custom level encoding for NOTICE level
// This ensures NOTICE messages are clearly distinguished from other log levels
func noticeLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
switch l {
case ctrld.NoticeLevel:
enc.AppendString("NOTICE")
default:
zapcore.CapitalLevelEncoder(l, enc)
}
}
// noticeColorLevelEncoder provides colored level encoding for NOTICE level
// This uses cyan color to make NOTICE messages visually distinct in terminal output
func noticeColorLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
switch l {
case ctrld.NoticeLevel:
enc.AppendString("\x1b[36mNOTICE\x1b[0m") // Cyan color for NOTICE
default:
zapcore.CapitalColorLevelEncoder(l, enc)
}
}
// logViewResponse represents the response structure for log viewing requests
// This provides a consistent JSON format for log data retrieval
type logViewResponse struct {
Data string `json:"data"`
}
// logSentResponse represents the response structure for log sending operations
// This includes size information and error details for debugging
type logSentResponse struct {
Size int64 `json:"size"`
Error string `json:"error"`
}
// logReader provides read access to log data with size information.
//
// This struct encapsulates log reading functionality for external consumers,
// providing both the log content and metadata about the log size. It supports
// reading from both internal log buffers (when no external logging is configured)
// and external log files (when logging to file is enabled).
//
// Fields:
// - r: An io.ReadCloser that provides access to the log content
// - size: The total size of the log data in bytes
//
// The logReader is used by the control server to serve log content to clients
// and by various CLI commands that need to display or process log data.
type logReader struct {
r io.ReadCloser
size int64
}
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
// This provides in-memory log storage for debugging and monitoring purposes
type logWriter struct {
mu sync.Mutex
buf bytes.Buffer
@@ -112,37 +46,30 @@ type logWriter struct {
}
// newLogWriter creates an internal log writer.
// This provides the default log writer with standard buffer size
func newLogWriter() *logWriter {
return newLogWriterWithSize(logWriterSize)
}
// newSmallLogWriter creates an internal log writer with small buffer size.
// This is used in memory-constrained environments or for temporary logging
func newSmallLogWriter() *logWriter {
return newLogWriterWithSize(logWriterSmallSize)
}
// newLogWriterWithSize creates an internal log writer with a given buffer size.
// This allows customization of log buffer size based on specific requirements
func newLogWriterWithSize(size int) *logWriter {
lw := &logWriter{size: size}
return lw
}
// Write implements io.Writer interface for logWriter
// This manages buffer overflow by discarding old data while preserving important markers
func (lw *logWriter) Write(p []byte) (int, error) {
lw.mu.Lock()
defer lw.mu.Unlock()
// If writing p causes overflows, discard old data.
// This prevents unbounded memory growth while maintaining recent logs
if lw.buf.Len()+len(p) > lw.size {
buf := lw.buf.Bytes()
haveEndMarker := false
// If there's init end marker already, preserve the data til the marker.
// This ensures initialization logs are always available for debugging
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
buf = buf[:idx+len(logWriterInitEndMarker)]
haveEndMarker = true
@@ -168,20 +95,20 @@ func (lw *logWriter) Write(p []byte) (int, error) {
// initLogging initializes global logging setup.
func (p *prog) initLogging(backup bool) {
logCores := initLoggingWithBackup(backup)
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
logWriters := initLoggingWithBackup(backup)
// Initializing internal logging after global logging.
p.initInternalLogging(logCores)
p.logger.Store(mainLog.Load())
p.initInternalLogging(logWriters)
}
// initInternalLogging performs internal logging if there's no log enabled.
func (p *prog) initInternalLogging(externalCores []zapcore.Core) {
func (p *prog) initInternalLogging(writers []io.Writer) {
if !p.needInternalLogging() {
return
}
p.initInternalLogWriterOnce.Do(func() {
p.Notice().Msg("Internal logging enabled")
mainLog.Load().Notice().Msg("internal logging enabled")
p.internalLogWriter = newLogWriter()
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
p.internalWarnLogWriter = newSmallLogWriter()
@@ -190,26 +117,28 @@ func (p *prog) initInternalLogging(externalCores []zapcore.Core) {
lw := p.internalLogWriter
wlw := p.internalWarnLogWriter
p.mu.Unlock()
// Create zap cores for different writers
var cores []zapcore.Core
cores = append(cores, externalCores...)
// Add core for internal log writer.
// Run the internal logging at debug level, so we could
// If ctrld was run without explicit verbose level,
// run the internal logging at debug level, so we could
// have enough information for troubleshooting.
internalCore := newHumanReadableZapCore(lw, zapcore.DebugLevel)
cores = append(cores, internalCore)
// Add core for internal warn log writer
warnCore := newHumanReadableZapCore(wlw, zapcore.WarnLevel)
cores = append(cores, warnCore)
// Create a multi-core logger
multiCore := zapcore.NewTee(cores...)
logger := zap.New(multiCore)
mainLog.Store(&ctrld.Logger{Logger: logger})
if verbose == 0 {
for i := range writers {
w := &zerolog.FilteredLevelWriter{
Writer: zerolog.LevelWriterAdapter{Writer: writers[i]},
Level: zerolog.NoticeLevel,
}
writers[i] = w
}
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}
writers = append(writers, lw)
writers = append(writers, &zerolog.FilteredLevelWriter{
Writer: zerolog.LevelWriterAdapter{Writer: wlw},
Level: zerolog.WarnLevel,
})
multi := zerolog.MultiLevelWriter(writers...)
l := mainLog.Load().Output(multi).With().Logger()
mainLog.Store(&l)
ctrld.ProxyLogger.Store(&l)
}
// needInternalLogging reports whether prog needs to run internal logging.
@@ -225,69 +154,7 @@ func (p *prog) needInternalLogging() bool {
return true
}
// logReaderNoColor returns a logReader with ANSI color codes stripped from the log content.
//
// This method is useful when log content needs to be processed by tools that don't
// handle ANSI escape sequences properly, or when storing logs in plain text format.
// It internally calls logReader(true) to strip color codes.
//
// Returns:
// - *logReader: A logReader instance with color codes removed, or nil if no logs available
// - error: Any error encountered during log reading (e.g., empty logs, file access issues)
//
// Use cases:
// - Log processing pipelines that require plain text
// - Storing logs in databases or text files
// - Displaying logs in environments that don't support color
func (p *prog) logReaderNoColor() (*logReader, error) {
return p.logReader(true)
}
// logReaderRaw returns a logReader with ANSI color codes preserved in the log content.
//
// This method maintains the original formatting of log entries including color codes,
// which is useful for displaying logs in terminals that support ANSI colors or when
// the original visual formatting needs to be preserved. It internally calls logReader(false).
//
// Returns:
// - *logReader: A logReader instance with color codes preserved, or nil if no logs available
// - error: Any error encountered during log reading (e.g., empty logs, file access issues)
//
// Use cases:
// - Terminal-based log viewers that support color
// - Interactive debugging sessions
// - Preserving original log formatting for display
func (p *prog) logReaderRaw() (*logReader, error) {
return p.logReader(false)
}
// logReader creates a logReader instance for accessing log content with optional color stripping.
//
// This is the core method that handles log reading from different sources based on the
// current logging configuration. It supports both internal logging (when no external
// logging is configured) and external file logging (when logging to file is enabled).
//
// Behavior:
// - Internal logging: Reads from internal log buffers (normal logs + warning logs)
// and combines them with appropriate markers for separation
// - External logging: Reads directly from the configured log file
// - Empty logs: Returns appropriate error messages when no log content is available
//
// Parameters:
// - stripColor: If true, removes ANSI color codes from log content; if false, preserves them
//
// Returns:
// - *logReader: A logReader instance providing access to log content and size metadata
// - error: Any error encountered during log reading, including:
// - "nil internal log writer" - Internal logging not properly initialized
// - "nil internal warn log writer" - Warning log writer not properly initialized
// - "internal log is empty" - No content in internal log buffers
// - "log file is empty" - External log file exists but contains no data
// - File system errors when accessing external log files
//
// The method handles thread-safe access to internal log buffers and provides
// comprehensive error handling for various edge cases.
func (p *prog) logReader(stripColor bool) (*logReader, error) {
func (p *prog) logReader() (*logReader, error) {
if p.needInternalLogging() {
p.mu.Lock()
lw := p.internalLogWriter
@@ -299,15 +166,14 @@ func (p *prog) logReader(stripColor bool) (*logReader, error) {
if wlw == nil {
return nil, errors.New("nil internal warn log writer")
}
// Normal log content.
lw.mu.Lock()
lwReader := newLogReader(&lw.buf, stripColor)
lwReader := bytes.NewReader(lw.buf.Bytes())
lwSize := lw.buf.Len()
lw.mu.Unlock()
// Warn log content.
wlw.mu.Lock()
wlwReader := newLogReader(&wlw.buf, stripColor)
wlwReader := bytes.NewReader(wlw.buf.Bytes())
wlwSize := wlw.buf.Len()
wlw.mu.Unlock()
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
@@ -336,72 +202,3 @@ func (p *prog) logReader(stripColor bool) (*logReader, error) {
}
return lr, nil
}
// newHumanReadableZapCore creates a zap core optimized for human-readable log output.
//
// Features:
// - Uses development encoder configuration for enhanced readability
// - Console encoding with colored log levels for easy visual scanning
// - Millisecond precision timestamps in human-friendly format
// - Structured field output with clear key-value pairs
// - Ideal for development, debugging, and interactive terminal sessions
//
// Parameters:
// - w: The output writer (e.g., os.Stdout, file, buffer)
// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error)
//
// Returns a zapcore.Core configured for human consumption.
func newHumanReadableZapCore(w io.Writer, level zapcore.Level) zapcore.Core {
encoderConfig := zap.NewDevelopmentEncoderConfig()
encoderConfig.TimeKey = "time"
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli)
encoderConfig.EncodeLevel = noticeColorLevelEncoder
encoder := zapcore.NewConsoleEncoder(encoderConfig)
return zapcore.NewCore(encoder, zapcore.AddSync(w), level)
}
// newMachineFriendlyZapCore creates a zap core optimized for machine processing and log aggregation.
//
// Features:
// - Uses production encoder configuration for consistent, parseable output
// - Console encoding with non-colored log levels for log parsing tools
// - Millisecond precision timestamps in ISO-like format
// - Structured field output optimized for log aggregation systems
// - Ideal for production environments, log shipping, and automated analysis
//
// Parameters:
// - w: The output writer (e.g., os.Stdout, file, buffer)
// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error)
//
// Returns a zapcore.Core configured for machine consumption and log aggregation.
func newMachineFriendlyZapCore(w io.Writer, level zapcore.Level) zapcore.Core {
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = "time"
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli)
encoderConfig.EncodeLevel = noticeLevelEncoder
encoder := zapcore.NewConsoleEncoder(encoderConfig)
return zapcore.NewCore(encoder, zapcore.AddSync(w), level)
}
// ansiRegex is a regular expression to match ANSI color codes.
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
// newLogReader creates a reader for log buffer content with optional ANSI color stripping.
//
// This function provides flexible log content access by allowing consumers to choose
// between raw log data (with ANSI color codes) or stripped content (without color codes).
// The color stripping is useful when logs need to be processed by tools that don't
// handle ANSI escape sequences properly, or when storing logs in plain text format.
//
// Parameters:
// - buf: The log buffer containing the log data to read
// - stripColor: If true, strips ANSI color codes from the log content;
// if false, returns raw log content with color codes preserved
//
// Returns an io.Reader that provides access to the processed log content.
func newLogReader(buf *bytes.Buffer, stripColor bool) io.Reader {
if stripColor {
return strings.NewReader(ansiRegex.ReplaceAllString(buf.String(), ""))
}
return strings.NewReader(buf.String())
}

View File

@@ -1,16 +1,9 @@
package cli
import (
"bytes"
"io"
"strings"
"sync"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/Control-D-Inc/ctrld"
)
func Test_logWriter_Write(t *testing.T) {
@@ -90,328 +83,3 @@ func Test_logWriter_MarkerInitEnd(t *testing.T) {
t.Fatalf("unexpected log content: %s", lw.buf.String())
}
}
// TestNoticeLevel tests that the custom NOTICE level works correctly
func TestNoticeLevel(t *testing.T) {
// Create a buffer to capture log output
var buf bytes.Buffer
// Create encoder config with custom NOTICE level support
encoderConfig := zap.NewDevelopmentEncoderConfig()
encoderConfig.TimeKey = "time"
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout("15:04:05.000")
encoderConfig.EncodeLevel = noticeLevelEncoder
// Test with NOTICE level
encoder := zapcore.NewConsoleEncoder(encoderConfig)
core := zapcore.NewCore(encoder, zapcore.AddSync(&buf), ctrld.NoticeLevel)
logger := zap.New(core)
ctrldLogger := &ctrld.Logger{Logger: logger}
// Log messages at different levels
ctrldLogger.Debug().Msg("This is a DEBUG message")
ctrldLogger.Info().Msg("This is an INFO message")
ctrldLogger.Notice().Msg("This is a NOTICE message")
ctrldLogger.Warn().Msg("This is a WARN message")
ctrldLogger.Error().Msg("This is an ERROR message")
output := buf.String()
// Verify that DEBUG and INFO messages are NOT logged (filtered out)
if strings.Contains(output, "DEBUG") {
t.Error("DEBUG message should not be logged when level is NOTICE")
}
if strings.Contains(output, "INFO") {
t.Error("INFO message should not be logged when level is NOTICE")
}
// Verify that NOTICE, WARN, and ERROR messages ARE logged
if !strings.Contains(output, "NOTICE") {
t.Error("NOTICE message should be logged when level is NOTICE")
}
if !strings.Contains(output, "WARN") {
t.Error("WARN message should be logged when level is NOTICE")
}
if !strings.Contains(output, "ERROR") {
t.Error("ERROR message should be logged when level is NOTICE")
}
// Verify the NOTICE message content
if !strings.Contains(output, "This is a NOTICE message") {
t.Error("NOTICE message content should be present")
}
t.Logf("Log output with NOTICE level:\n%s", output)
}
func TestNewLogReader(t *testing.T) {
tests := []struct {
name string
bufContent string
stripColor bool
expected string
description string
}{
{
name: "empty_buffer_no_color_strip",
bufContent: "",
stripColor: false,
expected: "",
description: "Empty buffer should return empty reader",
},
{
name: "empty_buffer_with_color_strip",
bufContent: "",
stripColor: true,
expected: "",
description: "Empty buffer with color strip should return empty reader",
},
{
name: "plain_text_no_color_strip",
bufContent: "This is plain text without any color codes",
stripColor: false,
expected: "This is plain text without any color codes",
description: "Plain text should be returned as-is when not stripping colors",
},
{
name: "plain_text_with_color_strip",
bufContent: "This is plain text without any color codes",
stripColor: true,
expected: "This is plain text without any color codes",
description: "Plain text should be returned as-is when stripping colors",
},
{
name: "text_with_ansi_codes_no_strip",
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again",
stripColor: false,
expected: "Normal text \x1b[31mred text\x1b[0m normal again",
description: "ANSI color codes should be preserved when not stripping",
},
{
name: "text_with_ansi_codes_with_strip",
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again",
stripColor: true,
expected: "Normal text red text normal again",
description: "ANSI color codes should be removed when stripping colors",
},
{
name: "multiple_ansi_codes_no_strip",
bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
stripColor: false,
expected: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
description: "Multiple ANSI codes should be preserved when not stripping",
},
{
name: "multiple_ansi_codes_with_strip",
bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
stripColor: true,
expected: "Bold Green Blue text",
description: "Multiple ANSI codes should be removed when stripping colors",
},
{
name: "complex_ansi_sequences_no_strip",
bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
stripColor: false,
expected: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
description: "Complex ANSI sequences should be preserved when not stripping",
},
{
name: "complex_ansi_sequences_with_strip",
bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
stripColor: true,
expected: "Bold red on green Orange",
description: "Complex ANSI sequences should be removed when stripping colors",
},
{
name: "ansi_codes_with_newlines_no_strip",
bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
stripColor: false,
expected: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
description: "ANSI codes with newlines should be preserved when not stripping",
},
{
name: "ansi_codes_with_newlines_with_strip",
bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
stripColor: true,
expected: "Line 1\nRed line\nLine 3",
description: "ANSI codes with newlines should be removed when stripping colors",
},
{
name: "malformed_ansi_codes_no_strip",
bufContent: "Text \x1b[invalidm \x1b[0m normal",
stripColor: false,
expected: "Text \x1b[invalidm \x1b[0m normal",
description: "Malformed ANSI codes should be preserved when not stripping",
},
{
name: "malformed_ansi_codes_with_strip",
bufContent: "Text \x1b[invalidm \x1b[0m normal",
stripColor: true,
expected: "Text \x1b[invalidm normal",
description: "Non-matching ANSI sequences should be preserved when stripping colors",
},
{
name: "large_buffer_no_strip",
bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
stripColor: false,
expected: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
description: "Large buffer should handle ANSI codes correctly when not stripping",
},
{
name: "large_buffer_with_strip",
bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
stripColor: true,
expected: strings.Repeat("A", 10000) + strings.Repeat("B", 1000),
description: "Large buffer should remove ANSI codes correctly when stripping",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a buffer with the test content
buf := &bytes.Buffer{}
buf.WriteString(tt.bufContent)
// Create the log reader
reader := newLogReader(buf, tt.stripColor)
// Read all content from the reader
content, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("Failed to read from log reader: %v", err)
}
// Verify the content matches expected
actual := string(content)
if actual != tt.expected {
t.Errorf("Expected content: %q, got: %q", tt.expected, actual)
t.Logf("Description: %s", tt.description)
}
})
}
}
func TestNewLogReader_ReaderBehavior(t *testing.T) {
// Test that the returned reader behaves correctly
buf := &bytes.Buffer{}
buf.WriteString("Test content with \x1b[31mred\x1b[0m text")
// Test with color stripping
reader := newLogReader(buf, true)
// Test reading in chunks
chunk1 := make([]byte, 10)
n1, err := reader.Read(chunk1)
if err != nil && err != io.EOF {
t.Fatalf("Unexpected error reading first chunk: %v", err)
}
if n1 != 10 {
t.Errorf("Expected to read 10 bytes, got %d", n1)
}
// Test reading remaining content
remaining, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("Failed to read remaining content: %v", err)
}
// Verify total content
totalContent := string(chunk1[:n1]) + string(remaining)
expected := "Test content with red text"
if totalContent != expected {
t.Errorf("Expected total content: %q, got: %q", expected, totalContent)
}
}
func TestNewLogReader_ConcurrentAccess(t *testing.T) {
// Test concurrent access to the same buffer
buf := &bytes.Buffer{}
buf.WriteString("Concurrent test with \x1b[32mgreen\x1b[0m text")
var wg sync.WaitGroup
numGoroutines := 10
results := make(chan string, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
reader := newLogReader(buf, true)
content, err := io.ReadAll(reader)
if err != nil {
t.Errorf("Failed to read content: %v", err)
return
}
results <- string(content)
}()
}
wg.Wait()
close(results)
// Verify all goroutines got the same result
expected := "Concurrent test with green text"
for result := range results {
if result != expected {
t.Errorf("Expected: %q, got: %q", expected, result)
}
}
}
func TestNewLogReader_ANSIRegexEdgeCases(t *testing.T) {
// Test edge cases for ANSI regex matching
tests := []struct {
name string
input string
expected string
}{
{
name: "empty_escape_sequence",
input: "Text \x1b[m normal",
expected: "Text normal",
},
{
name: "multiple_semicolons",
input: "Text \x1b[1;2;3;4m normal",
expected: "Text normal",
},
{
name: "numeric_only",
input: "Text \x1b[123m normal",
expected: "Text normal",
},
{
name: "mixed_numeric_semicolon",
input: "Text \x1b[1;23;456m normal",
expected: "Text normal",
},
{
name: "no_closing_bracket",
input: "Text \x1b[31 normal",
expected: "Text \x1b[31 normal",
},
{
name: "no_opening_bracket",
input: "Text 31m normal",
expected: "Text 31m normal",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
buf.WriteString(tt.input)
reader := newLogReader(buf, true)
content, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("Failed to read content: %v", err)
}
actual := string(content)
if actual != tt.expected {
t.Errorf("Expected: %q, got: %q", tt.expected, actual)
}
})
}
}

View File

@@ -84,7 +84,7 @@ func (p *prog) detectLoop(msg *dns.Msg) {
//
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
func (p *prog) checkDnsLoop() {
p.Debug().Msg("Start checking DNS loop")
mainLog.Load().Debug().Msg("start checking DNS loop")
upstream := make(map[string]*ctrld.UpstreamConfig)
p.loopMu.Lock()
for n, uc := range p.cfg.Upstream {
@@ -93,7 +93,7 @@ func (p *prog) checkDnsLoop() {
}
// Do not send test query to external upstream.
if !canBeLocalUpstream(uc.Domain) {
p.Debug().Msgf("Skipping external: upstream.%s", n)
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
continue
}
uid := uc.UID()
@@ -102,7 +102,6 @@ func (p *prog) checkDnsLoop() {
}
p.loopMu.Unlock()
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
for uid := range p.loop {
msg := loopTestMsg(uid)
uc := upstream[uid]
@@ -110,16 +109,16 @@ func (p *prog) checkDnsLoop() {
if uc == nil {
continue
}
resolver, err := ctrld.NewResolver(loggerCtx, uc)
resolver, err := ctrld.NewResolver(uc)
if err != nil {
p.Warn().Err(err).Msgf("Could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
continue
}
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
p.Warn().Err(err).Msgf("Could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
}
}
p.Debug().Msg("End checking DNS loop")
mainLog.Load().Debug().Msg("end checking DNS loop")
}
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
@@ -138,7 +137,7 @@ func (p *prog) checkDnsLoopTicker(ctx context.Context) {
}
}
// loopTestMsg creates a DNS test message for loop detection
// loopTestMsg generates DNS message for checking loop.
func loopTestMsg(uid string) *dns.Msg {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)

View File

@@ -5,16 +5,14 @@ import (
"os"
"path/filepath"
"sync/atomic"
"time"
"github.com/kardianos/service"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/rs/zerolog"
"github.com/Control-D-Inc/ctrld"
)
// Global variables for CLI configuration and state management
// These are used across multiple commands and need to persist throughout the application lifecycle
var (
configPath string
configBase64 string
@@ -43,14 +41,11 @@ var (
startOnly bool
rfc1918 bool
mainLog atomic.Pointer[ctrld.Logger]
consoleWriter zapcore.Core
consoleWriterLevel zapcore.Level
noConfigStart bool
mainLog atomic.Pointer[zerolog.Logger]
consoleWriter zerolog.ConsoleWriter
noConfigStart bool
)
// Flag name constants for consistent reference across the codebase
// Using constants prevents typos and makes refactoring easier
const (
cdUidFlagName = "cd"
cdOrgFlagName = "cd-org"
@@ -58,26 +53,20 @@ const (
nextdnsFlagName = "nextdns"
)
// init initializes the default logger before any CLI commands are executed
// This ensures logging is available even during early initialization phases
func init() {
l := zap.NewNop()
mainLog.Store(&ctrld.Logger{Logger: l})
l := zerolog.New(io.Discard)
mainLog.Store(&l)
}
// Main is the entry point for the CLI application
// It initializes configuration, sets up the CLI structure, and executes the root command
func Main() {
ctrld.InitConfig(v, "ctrld")
rootCmd := initCLI()
initCLI()
if err := rootCmd.Execute(); err != nil {
mainLog.Load().Error().Msg(err.Error())
os.Exit(1)
}
}
// normalizeLogFilePath converts relative log file paths to absolute paths
// This ensures log files are created in predictable locations regardless of working directory
func normalizeLogFilePath(logFilePath string) string {
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
return logFilePath
@@ -93,36 +82,40 @@ func normalizeLogFilePath(logFilePath string) string {
}
// initConsoleLogging initializes console logging, then storing to mainLog.
// This sets up human-readable logging output for interactive use
func initConsoleLogging() {
consoleWriterLevel = ctrld.NoticeLevel
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
w.TimeFormat = time.StampMilli
})
multi := zerolog.MultiLevelWriter(consoleWriter)
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
mainLog.Store(&l)
switch {
case silent:
// For silent mode, use a no-op logger to suppress all output
l := zap.NewNop()
mainLog.Store(&ctrld.Logger{Logger: l})
zerolog.SetGlobalLevel(zerolog.NoLevel)
case verbose == 1:
// Info level provides basic operational information
consoleWriterLevel = zapcore.InfoLevel
ctrld.ProxyLogger.Store(&l)
zerolog.SetGlobalLevel(zerolog.InfoLevel)
case verbose > 1:
// Debug level provides detailed diagnostic information
consoleWriterLevel = zapcore.DebugLevel
ctrld.ProxyLogger.Store(&l)
zerolog.SetGlobalLevel(zerolog.DebugLevel)
default:
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
}
consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel)
l := zap.New(consoleWriter)
mainLog.Store(&ctrld.Logger{Logger: l})
}
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
// to be used for all interactive commands.
//
// Current log file config will also be ignored.
// This prevents log file conflicts during interactive command execution
func initInteractiveLogging() {
old := cfg.Service.LogPath
cfg.Service.LogPath = ""
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
initLoggingWithBackup(false)
cfg.Service.LogPath = old
l := zerolog.New(io.Discard)
ctrld.ProxyLogger.Store(&l)
}
// initLoggingWithBackup initializes log setup base on current config.
@@ -131,101 +124,68 @@ func initInteractiveLogging() {
// This is only used in runCmd for special handling in case of logging config
// change in cd mode. Without special reason, the caller should use initLogging
// wrapper instead of calling this function directly.
func initLoggingWithBackup(doBackup bool) []zapcore.Core {
func initLoggingWithBackup(doBackup bool) []io.Writer {
var writers []io.Writer
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
// Create parent directory if necessary.
// This ensures log files can be created even if the directory doesn't exist
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
mainLog.Load().Error().Msgf("Failed to create log path: %v", err)
mainLog.Load().Error().Msgf("failed to create log path: %v", err)
os.Exit(1)
}
// Default open log file in append mode.
// This preserves existing log entries across restarts
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
if doBackup {
// Backup old log file with .1 suffix.
// This prevents log file corruption during rotation
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
mainLog.Load().Error().Msgf("Could not backup old log file: %v", err)
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
} else {
// Backup was created, set flags for truncating old log file.
// This ensures a clean start for the new log file
flags = os.O_CREATE | os.O_RDWR
}
}
logFile, err := openLogFile(logFilePath, flags)
if err != nil {
mainLog.Load().Error().Msgf("Failed to create log file: %v", err)
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
os.Exit(1)
}
writers = append(writers, logFile)
}
writers = append(writers, consoleWriter)
multi := zerolog.MultiLevelWriter(writers...)
l := mainLog.Load().Output(multi).With().Logger()
mainLog.Store(&l)
// TODO: find a better way.
ctrld.ProxyLogger.Store(&l)
// Create zap cores for different writers
// Multiple cores allow logging to both console and file simultaneously
var cores []zapcore.Core
cores = append(cores, consoleWriter)
// Determine log level based on verbosity and configuration
// This provides flexible logging control for different use cases
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
logLevel := cfg.Service.LogLevel
switch {
case silent:
// For silent mode, use a no-op logger to suppress all output
l := zap.NewNop()
mainLog.Store(&ctrld.Logger{Logger: l})
return cores
zerolog.SetGlobalLevel(zerolog.NoLevel)
return writers
case verbose == 1:
logLevel = "info"
case verbose > 1:
logLevel = "debug"
}
// Parse log level string to zapcore.Level
// This provides human-readable log level configuration
var level zapcore.Level
switch logLevel {
case "debug":
level = zapcore.DebugLevel
case "info":
level = zapcore.InfoLevel
case "notice":
level = ctrld.NoticeLevel
case "warn":
level = zapcore.WarnLevel
case "error":
level = zapcore.ErrorLevel
default:
level = zapcore.InfoLevel // default level
if logLevel == "" {
return writers
}
consoleWriter.Enabled(level)
// Add cores for all writers
// This enables multi-destination logging (console + file)
for _, writer := range writers {
core := newMachineFriendlyZapCore(writer, level)
cores = append(cores, core)
level, err := zerolog.ParseLevel(logLevel)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not set log level")
return writers
}
// Create a multi-core logger
// This allows simultaneous logging to multiple destinations
multiCore := zapcore.NewTee(cores...)
logger := zap.New(multiCore)
mainLog.Store(&ctrld.Logger{Logger: logger})
return cores
zerolog.SetGlobalLevel(level)
return writers
}
// initCache initializes DNS cache configuration
// This improves performance by caching frequently requested DNS responses
func initCache() {
if !cfg.Service.CacheEnable {
return
}
if cfg.Service.CacheSize == 0 {
// Default cache size provides good balance between memory usage and performance
cfg.Service.CacheSize = 4096
}
}

View File

@@ -5,28 +5,13 @@ import (
"strings"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/Control-D-Inc/ctrld"
"github.com/rs/zerolog"
)
var logOutput strings.Builder
func TestMain(m *testing.M) {
// Create a custom writer that writes to logOutput
writer := zapcore.AddSync(&logOutput)
// Create zap encoder
encoderConfig := zap.NewDevelopmentEncoderConfig()
encoder := zapcore.NewConsoleEncoder(encoderConfig)
// Create core that writes to our string builder
core := zapcore.NewCore(encoder, writer, zap.DebugLevel)
// Create logger
l := zap.New(core)
mainLog.Store(&ctrld.Logger{Logger: l})
l := zerolog.New(&logOutput)
mainLog.Store(&l)
os.Exit(m.Run())
}

View File

@@ -15,7 +15,6 @@ import (
)
// metricsServer represents a server to expose Prometheus metrics via HTTP.
// This provides monitoring and observability for the DNS proxy service
type metricsServer struct {
server *http.Server
mux *http.ServeMux
@@ -25,7 +24,6 @@ type metricsServer struct {
}
// newMetricsServer returns new metrics server.
// This initializes the HTTP server for exposing Prometheus metrics
func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) {
mux := http.NewServeMux()
ms := &metricsServer{
@@ -39,13 +37,11 @@ func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, er
}
// register adds handlers for given pattern.
// This provides a clean interface for adding HTTP endpoints to the metrics server
func (ms *metricsServer) register(pattern string, handler http.Handler) {
ms.mux.Handle(pattern, handler)
}
// registerMetricsServerHandler adds handlers for metrics server.
// This sets up both Prometheus format and JSON format endpoints for metrics
func (ms *metricsServer) registerMetricsServerHandler() {
ms.register("/metrics", promhttp.HandlerFor(
ms.reg,
@@ -78,7 +74,6 @@ func (ms *metricsServer) registerMetricsServerHandler() {
}
// start runs the metricsServer.
// This starts the HTTP server for metrics exposure
func (ms *metricsServer) start() error {
listener, err := net.Listen("tcp", ms.addr)
if err != nil {
@@ -90,7 +85,6 @@ func (ms *metricsServer) start() error {
}
// stop shutdowns the metricsServer within 2 seconds timeout.
// This ensures graceful shutdown of the metrics server
func (ms *metricsServer) stop() error {
if !ms.started {
return nil
@@ -101,7 +95,6 @@ func (ms *metricsServer) stop() error {
}
// runMetricsServer initializes metrics stats and runs the metrics server if enabled.
// This sets up the complete metrics infrastructure including Prometheus collectors
func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
if !p.metricsEnabled() {
return
@@ -122,7 +115,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
addr := p.cfg.Service.MetricsListener
ms, err := newMetricsServer(addr, reg)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Could not create new metrics server")
mainLog.Load().Warn().Err(err).Msg("could not create new metrics server")
return
}
// Only start listener address if defined.
@@ -137,9 +130,9 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
statsVersion.WithLabelValues(commit, runtime.Version(), curVersion()).Inc()
reg.MustRegister(statsTimeStart)
statsTimeStart.Set(float64(time.Now().Unix()))
mainLog.Load().Debug().Msgf("Starting metrics server on: %s", addr)
mainLog.Load().Debug().Msgf("starting metrics server on: %s", addr)
if err := ms.start(); err != nil {
mainLog.Load().Warn().Err(err).Msg("Could not start metrics server")
mainLog.Load().Warn().Err(err).Msg("could not start metrics server")
return
}
}
@@ -151,7 +144,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
}
if err := ms.stop(); err != nil {
mainLog.Load().Warn().Err(err).Msg("Could not stop metrics server")
mainLog.Load().Warn().Err(err).Msg("could not stop metrics server")
return
}
}

View File

@@ -49,3 +49,28 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
_, ok := validIfacesMap[iface.Name]
return ok
}
// validInterfacesMap returns a set of all valid hardware ports.
func validInterfacesMap() map[string]struct{} {
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
if err != nil {
return nil
}
return parseListAllHardwarePorts(bytes.NewReader(b))
}
// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports"
// and returns map presents all hardware ports.
func parseListAllHardwarePorts(r io.Reader) map[string]struct{} {
m := make(map[string]struct{})
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
after, ok := strings.CutPrefix(line, "Device: ")
if !ok {
continue
}
m[after] = struct{}{}
}
return m
}

View File

@@ -2,16 +2,51 @@ package cli
import (
"net"
"net/netip"
"os"
"strings"
"tailscale.com/net/netmon"
)
// patchNetIfaceName patches network interface names on Linux
// This is a no-op on Linux as interface names don't need special handling
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
// validInterface reports whether the *net.Interface is a valid one.
// Only non-virtual interfaces are considered valid.
// This prevents DNS configuration on virtual interfaces like docker, veth, etc.
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
_, ok := validIfacesMap[iface.Name]
return ok
}
// validInterfacesMap returns a set containing non virtual interfaces.
func validInterfacesMap() map[string]struct{} {
m := make(map[string]struct{})
vis := virtualInterfaces()
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
if _, existed := vis[i.Name]; existed {
return
}
m[i.Name] = struct{}{}
})
// Fallback to default route interface if found nothing.
if len(m) == 0 {
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
return m
}
m[defaultRoute.InterfaceName] = struct{}{}
}
return m
}
// virtualInterfaces returns a map of virtual interfaces on current machine.
func virtualInterfaces() map[string]struct{} {
s := make(map[string]struct{})
entries, _ := os.ReadDir("/sys/devices/virtual/net")
for _, entry := range entries {
if entry.IsDir() {
s[strings.TrimSpace(entry.Name())] = struct{}{}
}
}
return s
}

View File

@@ -4,10 +4,19 @@ package cli
import (
"net"
"tailscale.com/net/netmon"
)
// patchNetIfaceName patches network interface names on non-Linux/Darwin platforms
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
// validInterface checks if an interface is valid on non-Linux/Darwin platforms
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
// validInterfacesMap returns a set containing only default route interfaces.
func validInterfacesMap() map[string]struct{} {
defaultRoute, err := netmon.DefaultRoute()
if err != nil {
return nil
}
return map[string]struct{}{defaultRoute.InterfaceName: {}}
}

View File

@@ -1,7 +1,16 @@
package cli
import (
"io"
"log"
"net"
"os"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
)
func patchNetIfaceName(iface *net.Interface) (bool, error) {
@@ -14,3 +23,71 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
_, ok := validIfacesMap[iface.Name]
return ok
}
// validInterfacesMap returns a set of all physical interfaces.
func validInterfacesMap() map[string]struct{} {
m := make(map[string]struct{})
for _, ifaceName := range validInterfaces() {
m[ifaceName] = struct{}{}
}
return m
}
// validInterfaces returns a list of all physical interfaces.
func validInterfaces() []string {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("MSFT_NetAdapter")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
if instances != nil {
defer instances.Close()
}
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
return nil
}
var adapters []string
for _, i := range instances {
adapter, err := netadapter.NewNetworkAdapter(i)
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
continue
}
name, err := adapter.GetPropertyName()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
continue
}
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
//
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
// if this is a physical adapter or FALSE if this is not a physical adapter."
physical, err := adapter.GetPropertyConnectorPresent()
if err != nil {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
continue
}
if !physical {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
continue
}
// Check if it's a hardware interface. Checking only for connector present is not enough
// because some interfaces are not physical but have a connector.
hardware, err := adapter.GetPropertyHardwareInterface()
if err != nil {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
continue
}
if !hardware {
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
continue
}
adapters = append(adapters, name)
}
return adapters
}

View File

@@ -3,23 +3,18 @@ package cli
import (
"bufio"
"bytes"
"context"
"maps"
"slices"
"strings"
"testing"
"time"
"github.com/Control-D-Inc/ctrld"
)
func Test_validInterfaces(t *testing.T) {
verbose = 3
initConsoleLogging()
start := time.Now()
im := ctrld.ValidInterfaces(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
ifaces := validInterfaces()
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
ifaces := slices.Collect(maps.Keys(im))
start = time.Now()
ifacesPowershell := validInterfacesPowershell()

View File

@@ -5,8 +5,6 @@ import (
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"github.com/Control-D-Inc/ctrld"
)
func (p *prog) watchLinkState(ctx context.Context) {
@@ -14,7 +12,7 @@ func (p *prog) watchLinkState(ctx context.Context) {
done := make(chan struct{})
defer close(done)
if err := netlink.LinkSubscribe(ch, done); err != nil {
p.Warn().Err(err).Msg("Could not subscribe link")
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
return
}
for {
@@ -26,9 +24,9 @@ func (p *prog) watchLinkState(ctx context.Context) {
continue
}
if lu.Change&unix.IFF_UP != 0 {
p.Debug().Msgf("Link state changed, re-bootstrapping")
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
for _, uc := range p.cfg.Upstream {
uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load()))
uc.ReBootstrap()
}
}
}

View File

@@ -23,67 +23,66 @@ systemd-resolved=false
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
// hasNetworkManager reports whether NetworkManager executable found.
// hasNetworkManager checks if NetworkManager is available on the system
func hasNetworkManager() bool {
exe, _ := exec.LookPath("NetworkManager")
return exe != ""
}
func (p *prog) setupNetworkManager() error {
func setupNetworkManager() error {
if !hasNetworkManager() {
return nil
}
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
p.Debug().Msg("NetworkManager already setup, nothing to do")
mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do")
return nil
}
err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644))
if os.IsNotExist(err) {
p.Debug().Msg("NetworkManager is not available")
mainLog.Load().Debug().Msg("NetworkManager is not available")
return nil
}
if err != nil {
p.Debug().Err(err).Msg("Could not write NetworkManager ctrld config file")
mainLog.Load().Debug().Err(err).Msg("could not write NetworkManager ctrld config file")
return err
}
p.reloadNetworkManager()
p.Debug().Msg("Setup NetworkManager done")
reloadNetworkManager()
mainLog.Load().Debug().Msg("setup NetworkManager done")
return nil
}
func (p *prog) restoreNetworkManager() error {
func restoreNetworkManager() error {
if !hasNetworkManager() {
return nil
}
err := os.Remove(networkManagerCtrldConfFile)
if os.IsNotExist(err) {
p.Debug().Msg("NetworkManager is not available")
mainLog.Load().Debug().Msg("NetworkManager is not available")
return nil
}
if err != nil {
p.Debug().Err(err).Msg("Could not remove NetworkManager ctrld config file")
mainLog.Load().Debug().Err(err).Msg("could not remove NetworkManager ctrld config file")
return err
}
p.reloadNetworkManager()
p.Debug().Msg("Restore NetworkManager done")
reloadNetworkManager()
mainLog.Load().Debug().Msg("restore NetworkManager done")
return nil
}
func (p *prog) reloadNetworkManager() {
func reloadNetworkManager() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
conn, err := dbus.NewSystemConnectionContext(ctx)
if err != nil {
p.Error().Err(err).Msg("Could not create new system connection")
mainLog.Load().Error().Err(err).Msg("could not create new system connection")
return
}
defer conn.Close()
waitCh := make(chan string)
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
p.Debug().Err(err).Msg("Could not reload NetworkManager")
mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager")
return
}
<-waitCh

View File

@@ -2,14 +2,14 @@
package cli
func (p *prog) setupNetworkManager() error {
p.reloadNetworkManager()
func setupNetworkManager() error {
reloadNetworkManager()
return nil
}
func (p *prog) restoreNetworkManager() error {
p.reloadNetworkManager()
func restoreNetworkManager() error {
reloadNetworkManager()
return nil
}
func (p *prog) reloadNetworkManager() {}
func reloadNetworkManager() {}

View File

@@ -8,12 +8,11 @@ import (
const nextdnsURL = "https://dns.nextdns.io"
// generateNextDNSConfig generates NextDNS configuration for the given UID
func generateNextDNSConfig(uid string) {
if uid == "" {
return
}
mainLog.Load().Info().Msg("Generating ctrld config for NextDNS resolver")
mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver")
cfg = ctrld.Config{
Listener: map[string]*ctrld.ListenerConfig{
"0": {

View File

@@ -8,31 +8,26 @@ import (
"os/exec"
"strings"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
// allocateIP allocates an IP address on the specified interface
// allocate loopback ip
// sudo ifconfig lo0 alias 127.0.0.2 up
func allocateIP(ip string) error {
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up")
if err := cmd.Run(); err != nil {
mainLog.Load().Error().Err(err).Msg("AllocateIP failed")
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
return err
}
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
return nil
}
// deAllocateIP deallocates an IP address from the specified interface
func deAllocateIP(ip string) error {
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
cmd := exec.Command("ifconfig", "lo0", "-alias", ip)
if err := cmd.Run(); err != nil {
mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed")
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
return err
}
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
return nil
}
@@ -52,8 +47,6 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
// TODO(cuonglm): use system API
func setDNS(iface *net.Interface, nameservers []string) error {
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
// Note that networksetup won't modify search domains settings,
// This assignment is just a placeholder to silent linter.
_ = searchDomains
@@ -63,8 +56,6 @@ func setDNS(iface *net.Interface, nameservers []string) error {
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
return fmt.Errorf("%v: %w", string(out), err)
}
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
return nil
}
@@ -82,30 +73,25 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
// TODO(cuonglm): use system API
func resetDNS(iface *net.Interface) error {
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
cmd := "networksetup"
args := []string{"-setdnsservers", iface.Name, "empty"}
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
return fmt.Errorf("%v: %w", string(out), err)
}
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully")
return nil
}
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
if ns := ctrld.SavedStaticNameservers(iface); len(ns) > 0 {
if ns := savedStaticNameservers(iface); len(ns) > 0 {
err = setDNS(iface, ns)
}
return err
}
// currentDNS returns the current DNS servers for the specified interface
func currentDNS(_ *net.Interface) []string {
return ctrld.CurrentNameserversFromResolvconf()
return resolvconffile.NameServers()
}
// currentStaticDNS returns the current static DNS settings of given interface.

View File

@@ -9,32 +9,27 @@ import (
"tailscale.com/health"
"tailscale.com/util/dnsname"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dns"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
// allocateIP allocates an IP address on the specified interface
// allocate loopback ip
// sudo ifconfig lo0 127.0.0.53 alias
func allocateIP(ip string) error {
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
cmd := exec.Command("ifconfig", "lo0", ip, "alias")
if err := cmd.Run(); err != nil {
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
return err
}
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
return nil
}
// deAllocateIP deallocates an IP address from the specified interface
func deAllocateIP(ip string) error {
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
cmd := exec.Command("ifconfig", "lo0", ip, "-alias")
if err := cmd.Run(); err != nil {
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
return err
}
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
return nil
}
@@ -45,11 +40,9 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
// set the dns server for the provided network interface
func setDNS(iface *net.Interface, nameservers []string) error {
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
if err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator")
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
return err
}
@@ -65,15 +58,13 @@ func setDNS(iface *net.Interface, nameservers []string) error {
if sds, err := searchDomains(); err == nil {
osConfig.SearchDomains = sds
} else {
mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list")
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
}
if err := r.SetDNS(osConfig); err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to set DNS")
mainLog.Load().Error().Err(err).Msg("failed to set DNS")
return err
}
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
return nil
}
@@ -82,22 +73,17 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
return resetDNS(iface)
}
// resetDNS resets DNS servers for the specified interface
func resetDNS(iface *net.Interface) error {
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
if err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator")
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
return err
}
if err := r.Close(); err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to rollback DNS setting")
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
return err
}
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully")
return nil
}
@@ -107,9 +93,8 @@ func restoreDNS(iface *net.Interface) (err error) {
return err
}
// currentDNS returns the current DNS servers for the specified interface
func currentDNS(_ *net.Interface) []string {
return ctrld.CurrentNameserversFromResolvconf()
return resolvconffile.NameServers()
}
// currentStaticDNS returns the current static DNS settings of given interface.

View File

@@ -21,36 +21,30 @@ import (
"tailscale.com/health"
"tailscale.com/util/dnsname"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dns"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
type getDNS func(iface string) []string
// allocate loopback ip
// sudo ip a add 127.0.0.2/24 dev lo
func allocateIP(ip string) error {
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
cmd := exec.Command("ip", "a", "add", ip+"/24", "dev", "lo")
if out, err := cmd.CombinedOutput(); err != nil {
mainLog.Load().Error().Err(err).Msgf("AllocateIP failed: %s", string(out))
mainLog.Load().Error().Err(err).Msgf("allocateIP failed: %s", string(out))
return err
}
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
return nil
}
func deAllocateIP(ip string) error {
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
cmd := exec.Command("ip", "a", "del", ip+"/24", "dev", "lo")
if err := cmd.Run(); err != nil {
mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed")
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
return err
}
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
return nil
}
@@ -62,11 +56,9 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
}
func setDNS(iface *net.Interface, nameservers []string) error {
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
if err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to create dns os configurator")
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
return err
}
@@ -80,9 +72,17 @@ func setDNS(iface *net.Interface, nameservers []string) error {
SearchDomains: []dnsname.FQDN{},
}
if sds, err := searchDomains(); err == nil {
osConfig.SearchDomains = sds
// Filter the root domain, since it's not allowed by systemd.
// See https://github.com/systemd/systemd/issues/9515
filteredSds := slices.DeleteFunc(sds, func(s dnsname.FQDN) bool {
return s == "" || s == "."
})
if len(filteredSds) != len(sds) {
mainLog.Load().Debug().Msg(`Removed root domain "." from search domains list`)
}
osConfig.SearchDomains = filteredSds
} else {
mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list")
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
}
trySystemdResolve := false
if err := r.SetDNS(osConfig); err != nil {
@@ -125,8 +125,6 @@ systemdResolve:
}
mainLog.Load().Debug().Msg("DNS was not set for some reason")
}
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
return nil
}
@@ -136,8 +134,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
}
func resetDNS(iface *net.Interface) (err error) {
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
defer func() {
if err == nil {
return
@@ -149,7 +145,7 @@ func resetDNS(iface *net.Interface) (err error) {
if r, oerr := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name); oerr == nil {
_ = r.SetDNS(dns.OSConfig{})
if err := r.Close(); err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to rollback dns setting")
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
return
}
err = nil
@@ -177,18 +173,18 @@ func resetDNS(iface *net.Interface) (err error) {
}
// TODO(cuonglm): handle DHCPv6 properly.
mainLog.Load().Debug().Msg("Checking for ipv6 availability")
mainLog.Load().Debug().Msg("checking for IPv6 availability")
if ctrldnet.IPv6Available(ctx) {
c := client6.NewClient()
conversation, err := c.Exchange(iface.Name)
if err != nil && !errAddrInUse(err) {
mainLog.Load().Debug().Err(err).Msg("Could not exchange dhcpv6")
mainLog.Load().Debug().Err(err).Msg("could not exchange DHCPv6")
}
for _, packet := range conversation {
if packet.Type() == dhcpv6.MessageTypeReply {
msg, err := packet.GetInnerMessage()
if err != nil {
mainLog.Load().Debug().Err(err).Msg("Could not get inner dhcpv6 message")
mainLog.Load().Debug().Err(err).Msg("could not get inner DHCPv6 message")
return nil
}
nameservers := msg.Options.DNS()
@@ -213,7 +209,7 @@ func restoreDNS(iface *net.Interface) (err error) {
}
func currentDNS(iface *net.Interface) []string {
resolvconfFunc := func(_ string) []string { return ctrld.CurrentNameserversFromResolvconf() }
resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() }
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
if ns := fn(iface.Name); len(ns) > 0 {
return ns

View File

@@ -2,12 +2,12 @@
package cli
// allocateIP allocates an IP address on the specified interface
// TODO(cuonglm): implement.
func allocateIP(ip string) error {
return nil
}
// deAllocateIP deallocates an IP address from the specified interface
// TODO(cuonglm): implement.
func deAllocateIP(ip string) error {
return nil
}

View File

@@ -1,18 +1,21 @@
package cli
import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"slices"
"strings"
"sync"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/Control-D-Inc/ctrld"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
@@ -21,6 +24,11 @@ const (
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
)
var (
setDNSOnce sync.Once
resetDNSOnce sync.Once
)
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
return setDNS(iface, nameservers)
@@ -31,7 +39,49 @@ func setDNS(iface *net.Interface, nameservers []string) error {
if len(nameservers) == 0 {
return errors.New("empty DNS nameservers")
}
setDNSOnce.Do(func() {
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
// Configuring the Dns server to forward queries to ctrld instead.
if hasLocalDnsServerRunning() {
mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders")
file := absHomeDir(windowsForwardersFilename)
mainLog.Load().Debug().Msgf("Using forwarders file: %s", file)
oldForwardersContent, err := os.ReadFile(file)
if err != nil {
mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file")
} else {
mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent))
}
hasLocalIPv6Listener := needLocalIPv6Listener()
mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status")
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
if !hasLocalIPv6Listener {
return false
}
return s == "::1"
})
mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list")
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
} else {
mainLog.Load().Debug().Msg("Successfully wrote new forwarders file")
}
oldForwarders := strings.Split(string(oldForwardersContent), ",")
mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders")
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
} else {
mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders")
}
}
})
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return fmt.Errorf("setDNS: %w", err)
@@ -75,8 +125,25 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
return resetDNS(iface)
}
// resetDNS resets DNS servers for the specified interface
// TODO(cuonglm): should we use system API?
func resetDNS(iface *net.Interface) error {
resetDNSOnce.Do(func() {
// See corresponding comment in setDNS.
if hasLocalDnsServerRunning() {
file := absHomeDir(windowsForwardersFilename)
content, err := os.ReadFile(file)
if err != nil {
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
return
}
nameservers := strings.Split(string(content), ",")
if err := removeDnsServerForwarders(nameservers); err != nil {
mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings")
return
}
}
})
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
return fmt.Errorf("resetDNS: %w", err)
@@ -94,7 +161,7 @@ func resetDNS(iface *net.Interface) error {
// restoreDNS restores the DNS settings of the given interface.
// this should only be executed upon turning off the ctrld service.
func restoreDNS(iface *net.Interface) (err error) {
if nss := ctrld.SavedStaticNameservers(iface); len(nss) > 0 {
if nss := savedStaticNameservers(iface); len(nss) > 0 {
v4ns := make([]string, 0, 2)
v6ns := make([]string, 0, 2)
for _, ns := range nss {
@@ -111,24 +178,24 @@ func restoreDNS(iface *net.Interface) (err error) {
}
if len(v4ns) > 0 {
mainLog.Load().Debug().Msgf("Restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
if err := setDNS(iface, v4ns); err != nil {
return fmt.Errorf("restoreDNS (IPv4): %w", err)
}
} else {
mainLog.Load().Debug().Msgf("Restoring IPv4 DHCP for interface %q", iface.Name)
mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name)
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
}
}
if len(v6ns) > 0 {
mainLog.Load().Debug().Msgf("Restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
if err := setDNS(iface, v6ns); err != nil {
return fmt.Errorf("restoreDNS (IPv6): %w", err)
}
} else {
mainLog.Load().Debug().Msgf("Restoring IPv6 DHCP for interface %q", iface.Name)
mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name)
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
}
@@ -137,16 +204,15 @@ func restoreDNS(iface *net.Interface) (err error) {
return err
}
// currentDNS returns the current DNS servers for the specified interface
func currentDNS(iface *net.Interface) []string {
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to get interface LUID")
mainLog.Load().Error().Err(err).Msg("failed to get interface LUID")
return nil
}
nameservers, err := luid.DNS()
if err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to get interface DNS")
mainLog.Load().Error().Err(err).Msg("failed to get interface DNS")
return nil
}
ns := make([]string, 0, len(nameservers))
@@ -174,7 +240,7 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) {
interfaceKeyPath := path + guid.String()
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("Failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
continue
}
func() {
@@ -182,11 +248,11 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) {
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
value, _, err := k.GetStringValue(keyName)
if err != nil && !errors.Is(err, registry.ErrNotExist) {
mainLog.Load().Debug().Err(err).Msgf("Error reading %s registry key", keyName)
mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName)
continue
}
if len(value) > 0 {
mainLog.Load().Debug().Msgf("Found static DNS for interface %q: %s", iface.Name, value)
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
parsed := parseDNSServers(value)
for _, pns := range parsed {
if !slices.Contains(ns, pns) {
@@ -198,7 +264,7 @@ func currentStaticDNS(iface *net.Interface) ([]string, error) {
}()
}
if len(ns) == 0 {
mainLog.Load().Debug().Msgf("No static DNS values found for interface %q", iface.Name)
mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name)
}
return ns, nil
}
@@ -218,3 +284,49 @@ func parseDNSServers(val string) []string {
}
return servers
}
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
// and also removing old forwarders if provided.
func addDnsServerForwarders(nameservers, old []string) error {
newForwardersMap := make(map[string]struct{})
newForwarders := make([]string, len(nameservers))
for i := range nameservers {
newForwardersMap[nameservers[i]] = struct{}{}
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
}
oldForwarders := old[:0]
for _, fwd := range old {
if _, ok := newForwardersMap[fwd]; !ok {
oldForwarders = append(oldForwarders, fwd)
}
}
// NOTE: It is important to add new forwarder before removing old one.
// Testing on Windows Server 2022 shows that removing forwarder1
// then adding forwarder2 sometimes ends up adding both of them
// to the forwarders list.
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
if len(oldForwarders) > 0 {
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
}
if out, err := powershell(cmd); err != nil {
return fmt.Errorf("%w: %s", err, string(out))
}
return nil
}
// removeDnsServerForwarders removes given nameservers from DNS server forwarders list.
func removeDnsServerForwarders(nameservers []string) error {
for _, ns := range nameservers {
cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns)
if out, err := powershell(cmd); err != nil {
return fmt.Errorf("%w: %s", err, string(out))
}
}
return nil
}
// powershell runs the given powershell command.
func powershell(cmd string) ([]byte, error) {
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
return bytes.TrimSpace(out), err
}

View File

@@ -1,10 +1,8 @@
package cli
import (
"bytes"
"fmt"
"net"
"os/exec"
"slices"
"strings"
"testing"
@@ -68,9 +66,3 @@ func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
}
return ns, nil
}
// powershell runs the given powershell command.
func powershell(cmd string) ([]byte, error) {
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
return bytes.TrimSpace(out), err
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,10 +4,8 @@ import (
"github.com/kardianos/service"
)
// setDependencies sets service dependencies for Darwin
func setDependencies(svc *service.Config) {}
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) {
svc.WorkingDirectory = dir
}

View File

@@ -6,11 +6,9 @@ import (
"github.com/kardianos/service"
)
// setDependencies sets service dependencies for FreeBSD
func setDependencies(svc *service.Config) {
// TODO(cuonglm): remove once https://github.com/kardianos/service/issues/359 fixed.
_ = os.MkdirAll("/usr/local/etc/rc.d", 0755)
}
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) {}

View File

@@ -9,6 +9,8 @@ import (
"strings"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld/internal/router"
)
func init() {
@@ -21,7 +23,6 @@ func init() {
}
}
// setDependencies sets service dependencies for Linux
func setDependencies(svc *service.Config) {
svc.Dependencies = []string{
"Wants=network-online.target",
@@ -36,9 +37,11 @@ func setDependencies(svc *service.Config) {
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
}
}
if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 {
svc.Dependencies = append(svc.Dependencies, routerDeps...)
}
}
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) {
svc.WorkingDirectory = dir
}

View File

@@ -1,33 +0,0 @@
package cli
import "github.com/Control-D-Inc/ctrld"
// Debug starts a new message with debug level.
func (p *prog) Debug() *ctrld.LogEvent {
return p.logger.Load().Debug()
}
// Warn starts a new message with warn level.
func (p *prog) Warn() *ctrld.LogEvent {
return p.logger.Load().Warn()
}
// Info starts a new message with info level.
func (p *prog) Info() *ctrld.LogEvent {
return p.logger.Load().Info()
}
// Fatal starts a new message with fatal level.
func (p *prog) Fatal() *ctrld.LogEvent {
return p.logger.Load().Fatal()
}
// Error starts a new message with error level.
func (p *prog) Error() *ctrld.LogEvent {
return p.logger.Load().Error()
}
// Notice starts a new message with notice level.
func (p *prog) Notice() *ctrld.LogEvent {
return p.logger.Load().Notice()
}

View File

@@ -4,10 +4,8 @@ package cli
import "github.com/kardianos/service"
// setDependencies sets service dependencies for other platforms
func setDependencies(svc *service.Config) {}
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) {
// WorkingDirectory is not supported on Windows.
svc.WorkingDirectory = dir

View File

@@ -1,12 +1,13 @@
package cli
import (
"runtime"
"testing"
"time"
"github.com/Masterminds/semver/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/Control-D-Inc/ctrld"
)
@@ -173,10 +174,10 @@ func Test_shouldUpgrade(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Create test logger
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
// Call the function and capture the result
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, testLogger)
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger)
// Assert the expected result
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
@@ -185,6 +186,10 @@ func Test_shouldUpgrade(t *testing.T) {
}
func Test_selfUpgradeCheck(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipped due to Windows file locking issue on Github Action runners")
}
// Helper function to create a version
makeVersion := func(v string) *semver.Version {
ver, err := semver.NewVersion(v)
@@ -221,10 +226,10 @@ func Test_selfUpgradeCheck(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Create test logger
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
// Call the function and capture the result
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, testLogger)
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger)
// Assert the expected result
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
@@ -233,6 +238,10 @@ func Test_selfUpgradeCheck(t *testing.T) {
}
func Test_performUpgrade(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipped due to Windows file locking issue on Github Action runners")
}
tests := []struct {
name string
versionTarget string
@@ -256,10 +265,8 @@ func Test_performUpgrade(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Create test logger
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
// Call the function and capture the result
result := performUpgrade(tc.versionTarget, testLogger)
result := performUpgrade(tc.versionTarget)
assert.Equal(t, tc.expectedResult, result, tc.description)
})
}

View File

@@ -2,10 +2,12 @@ package cli
import "github.com/kardianos/service"
// setDependencies sets service dependencies for Windows
func setDependencies(svc *service.Config) {}
func setDependencies(svc *service.Config) {
if hasLocalDnsServerRunning() {
svc.Dependencies = []string{"DNS"}
}
}
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) {
// WorkingDirectory is not supported on Windows.
svc.WorkingDirectory = dir

View File

@@ -2,8 +2,6 @@ package cli
import "github.com/prometheus/client_golang/prometheus"
// Prometheus metrics label constants for consistent labeling across all metrics
// These ensure standardized metric labeling for monitoring and alerting
const (
metricsLabelListener = "listener"
metricsLabelClientSourceIP = "client_source_ip"
@@ -15,21 +13,17 @@ const (
)
// statsVersion represent ctrld version.
// This metric provides version information for monitoring and debugging
var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ctrld_build_info",
Help: "Version of ctrld process.",
}, []string{"gitref", "goversion", "version"})
// statsTimeStart represents start time of ctrld service.
// This metric tracks service uptime and helps with monitoring service restarts
var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "ctrld_time_seconds",
Help: "Start time of the ctrld process since unix epoch in seconds.",
})
// statsQueriesCountLabels defines the labels for query count metrics
// These labels provide detailed breakdown of DNS query statistics
var statsQueriesCountLabels = []string{
metricsLabelListener,
metricsLabelClientSourceIP,
@@ -41,7 +35,6 @@ var statsQueriesCountLabels = []string{
}
// statsQueriesCount counts total number of queries.
// This provides comprehensive DNS query statistics for monitoring and alerting
var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ctrld_queries_count",
Help: "Total number of queries.",
@@ -51,14 +44,12 @@ var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
//
// The labels "client_source_ip", "client_mac", "client_hostname" are unbounded,
// thus this stat is highly inefficient if there are many devices.
// This metric should be used carefully in high-client environments
var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ctrld_client_queries_count",
Help: "Total number queries of a client.",
}, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname})
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
// This provides conditional metric collection to avoid performance impact when metrics are disabled
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
if p.metricsQueryStats.Load() {
c.WithLabelValues(lvs...).Inc()

View File

@@ -8,12 +8,10 @@ import (
"syscall"
)
// notifyReloadSigCh sends reload signal to the channel
func notifyReloadSigCh(ch chan os.Signal) {
signal.Notify(ch, syscall.SIGUSR1)
}
// sendReloadSignal sends a reload signal to the current process
func (p *prog) sendReloadSignal() error {
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
}

View File

@@ -6,10 +6,8 @@ import (
"time"
)
// notifyReloadSigCh is a no-op on Windows platforms
func notifyReloadSigCh(ch chan os.Signal) {}
// sendReloadSignal sends a reload signal to the program
func (p *prog) sendReloadSignal() error {
select {
case p.reloadCh <- struct{}{}:

View File

@@ -3,45 +3,59 @@ package cli
import (
"net"
"net/netip"
"os"
"path/filepath"
"strings"
"time"
"github.com/fsnotify/fsnotify"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
// Returns nil if no nameservers are found.
// This function parses the system DNS configuration to understand current nameserver settings
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
return resolvconffile.NameserversFromFile(path)
content, err := os.ReadFile(path)
if err != nil {
return nil, err
}
// Parse the file for "nameserver" lines
var currentNS []string
lines := strings.Split(string(content), "\n")
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "nameserver") {
parts := strings.Fields(trimmed)
if len(parts) >= 2 {
currentNS = append(currentNS, parts[1])
}
}
}
return currentNS, nil
}
// watchResolvConf watches any changes to /etc/resolv.conf file,
// and reverting to the original config set by ctrld.
// This ensures that DNS settings are not overridden by other applications or system processes
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
resolvConfPath := "/etc/resolv.conf"
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
// This handles systems where resolv.conf is a symlink to another location
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
resolvConfPath = rp
}
p.Debug().Msgf("Start watching %s file", resolvConfPath)
mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath)
watcher, err := fsnotify.NewWatcher()
if err != nil {
p.Warn().Err(err).Msg("Could not create watcher for /etc/resolv.conf")
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
return
}
defer watcher.Close()
// We watch /etc instead of /etc/resolv.conf directly,
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
// This is necessary because some systems don't properly notify on file changes
watchDir := filepath.Dir(resolvConfPath)
if err := watcher.Add(watchDir); err != nil {
p.Warn().Err(err).Msgf("Could not add %s to watcher list", watchDir)
mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
return
}
@@ -50,7 +64,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
case <-p.dnsWatcherStopCh:
return
case <-p.stopCh:
p.Debug().Msgf("Stopping watcher for %s", resolvConfPath)
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
return
case event, ok := <-watcher.Events:
if p.recoveryRunning.Load() {
@@ -63,10 +77,9 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
continue
}
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
// Convert expected nameservers to strings for comparison
// This allows us to detect when the resolv.conf has been modified
expectedNS := make([]string, len(ns))
for i, addr := range ns {
expectedNS[i] = addr.String()
@@ -79,20 +92,18 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
for retry := 0; retry < maxRetries; retry++ {
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
if err != nil {
p.Error().Err(err).Msg("Failed to read resolv.conf content")
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
break
}
// If we found nameservers, break out of retry loop
// This handles cases where the file is being written but not yet complete
if len(foundNS) > 0 {
break
}
// Only retry if we found no nameservers
// This handles temporary file states during updates
if retry < maxRetries-1 {
p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
select {
case <-p.stopCh:
return
@@ -102,7 +113,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
continue
}
} else {
p.Debug().Msg("resolv.conf remained empty after all retries")
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
}
}
@@ -119,7 +130,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
}
}
p.Debug().
mainLog.Load().Debug().
Strs("found", foundNS).
Strs("expected", expectedNS).
Bool("matches", matches).
@@ -128,16 +139,16 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
// Only revert if the nameservers don't match
if !matches {
if err := watcher.Remove(watchDir); err != nil {
p.Error().Err(err).Msg("Failed to pause watcher")
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
continue
}
if err := setDnsFn(iface, ns); err != nil {
p.Error().Err(err).Msg("Failed to revert /etc/resolv.conf changes")
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
}
if err := watcher.Add(watchDir); err != nil {
p.Error().Err(err).Msg("Failed to continue running watcher")
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
return
}
}
@@ -147,7 +158,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
if !ok {
return
}
p.Error().Err(err).Msg("Could not get event for /etc/resolv.conf")
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
}
}
}

View File

@@ -12,7 +12,7 @@ import (
const resolvConfPath = "/etc/resolv.conf"
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
servers := make([]string, len(ns))
for i := range ns {
servers[i] = ns[i].String()

View File

@@ -14,7 +14,7 @@ import (
)
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
r, err := newLoopbackOSConfigurator()
if err != nil {
return err
@@ -27,7 +27,7 @@ func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
if sds, err := searchDomains(); err == nil {
oc.SearchDomains = sds
} else {
p.Debug().Err(err).Msg("Failed to get search domains list when reverting resolv.conf file")
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file")
}
return r.SetDNS(oc)
}

View File

@@ -1,51 +0,0 @@
//go:build unix
package cli
import (
"os"
"slices"
"strings"
"testing"
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
)
func oldParseResolvConfNameservers(path string) ([]string, error) {
content, err := os.ReadFile(path)
if err != nil {
return nil, err
}
// Parse the file for "nameserver" lines
var currentNS []string
lines := strings.Split(string(content), "\n")
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "nameserver") {
parts := strings.Fields(trimmed)
if len(parts) >= 2 {
currentNS = append(currentNS, parts[1])
}
}
}
return currentNS, nil
}
// Test_prog_parseResolvConfNameservers tests the parsing of nameservers from resolv.conf content.
// Note: The previous implementation was removed to reduce code duplication and consolidate
// the resolv.conf handling logic into a single unified approach. All resolv.conf parsing
// is now handled by the resolvconffile package, which provides a consistent interface
// for both reading and modifying resolv.conf files across different platforms.
func Test_prog_parseResolvConfNameservers(t *testing.T) {
oldNss, _ := oldParseResolvConfNameservers(resolvconffile.Path)
p := &prog{}
nss, _ := p.parseResolvConfNameservers(resolvconffile.Path)
slices.Sort(oldNss)
slices.Sort(nss)
if !slices.Equal(oldNss, nss) {
t.Errorf("result mismatched, old: %v, new: %v", oldNss, nss)
}
t.Logf("result: %v", nss)
}

View File

@@ -6,7 +6,7 @@ import (
)
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
func (p *prog) setResolvConf(_ *net.Interface, _ []netip.Addr) error {
func setResolvConf(_ *net.Interface, _ []netip.Addr) error {
return nil
}

View File

@@ -33,7 +33,7 @@ func searchDomains() ([]dnsname.FQDN, error) {
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
d, err := dnsname.ToFQDN(a.String())
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("Failed to parse domain: %s", a.String())
mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String())
continue
}
sds = append(sds, d)

View File

@@ -4,5 +4,4 @@ package cli
var supportedSelfDelete = true
// selfDeleteExe performs self-deletion on non-Windows platforms
func selfDeleteExe() error { return nil }

View File

@@ -33,7 +33,6 @@ type FILE_DISPOSITION_INFO struct {
DeleteFile bool
}
// dsOpenHandle opens a handle to the specified file with DELETE access
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
handle, err := windows.CreateFile(
pwPath,
@@ -52,7 +51,6 @@ func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
return handle, nil
}
// dsRenameHandle renames a file handle to a stream name
func dsRenameHandle(hHandle windows.Handle) error {
var fRename FILE_RENAME_INFO
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
@@ -84,7 +82,6 @@ func dsRenameHandle(hHandle windows.Handle) error {
return nil
}
// dsDepositeHandle marks a file handle for deletion
func dsDepositeHandle(hHandle windows.Handle) error {
var fDelete FILE_DISPOSITION_INFO
fDelete.DeleteFile = true
@@ -103,7 +100,6 @@ func dsDepositeHandle(hHandle windows.Handle) error {
return nil
}
// selfDeleteExe performs self-deletion on Windows platforms
func selfDeleteExe() error {
var wcPath [windows.MAX_PATH + 1]uint16
var hCurrent windows.Handle

View File

@@ -5,13 +5,12 @@ package cli
import (
"os"
"github.com/Control-D-Inc/ctrld"
"github.com/rs/zerolog"
)
// selfUninstall performs self-uninstallation on non-Unix platforms
func selfUninstall(p *prog, logger *ctrld.Logger) {
func selfUninstall(p *prog, logger zerolog.Logger) {
if uninstallInvalidCdUID(p, logger, false) {
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
os.Exit(0)
}
}

View File

@@ -9,18 +9,17 @@ import (
"runtime"
"syscall"
"github.com/Control-D-Inc/ctrld"
"github.com/rs/zerolog"
)
// selfUninstall performs self-uninstallation on Unix platforms
func selfUninstall(p *prog, logger *ctrld.Logger) {
func selfUninstall(p *prog, logger zerolog.Logger) {
if runtime.GOOS == "linux" {
selfUninstallLinux(p, logger)
}
bin, err := os.Executable()
if err != nil {
logger.Fatal().Err(err).Msg("Could not determine executable")
logger.Fatal().Err(err).Msg("could not determine executable")
}
args := []string{"uninstall"}
if deactivationPinSet() {
@@ -29,19 +28,18 @@ func selfUninstall(p *prog, logger *ctrld.Logger) {
cmd := exec.Command(bin, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
if err := cmd.Start(); err != nil {
logger.Fatal().Err(err).Msg("Could not start self uninstall command")
logger.Fatal().Err(err).Msg("could not start self uninstall command")
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
_ = cmd.Wait()
os.Exit(0)
}
// selfUninstallLinux performs self-uninstallation on Linux platforms
func selfUninstallLinux(p *prog, logger *ctrld.Logger) {
func selfUninstallLinux(p *prog, logger zerolog.Logger) {
if uninstallInvalidCdUID(p, logger, true) {
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
os.Exit(0)
}
}

View File

@@ -1,31 +1,24 @@
package cli
// semaphore provides a simple synchronization mechanism
type semaphore interface {
acquire()
release()
}
// noopSemaphore is a no-operation implementation of semaphore
type noopSemaphore struct{}
// acquire performs a no-operation for the noop semaphore
func (n noopSemaphore) acquire() {}
// release performs a no-operation for the noop semaphore
func (n noopSemaphore) release() {}
// chanSemaphore is a channel-based implementation of semaphore
type chanSemaphore struct {
ready chan struct{}
}
// acquire blocks until a slot is available in the semaphore
func (c *chanSemaphore) acquire() {
c.ready <- struct{}{}
}
// release signals that a slot has been freed in the semaphore
func (c *chanSemaphore) release() {
<-c.ready
}

View File

@@ -11,6 +11,9 @@ import (
"github.com/coreos/go-systemd/v22/unit"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld/internal/router"
"github.com/Control-D-Inc/ctrld/internal/router/openwrt"
)
// newService wraps service.New call to return service.Service
@@ -21,6 +24,10 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
return nil, err
}
switch {
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
return &procd{sysV: &sysV{s}, svcConfig: c}, nil
case router.IsGLiNet():
return &sysV{s}, nil
case s.Platform() == "unix-systemv":
return &sysV{s}, nil
case s.Platform() == "linux-systemd":
@@ -35,7 +42,7 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
// sysV wraps a service.Service, and provide start/stop/status command
// base on "/etc/init.d/<service_name>".
//
// Use this on system where "service" command is not available.
// Use this on system where "service" command is not available, like GL.iNET router.
type sysV struct {
service.Service
}
@@ -82,6 +89,37 @@ func (s *sysV) Status() (service.Status, error) {
return unixSystemVServiceStatus()
}
// procd wraps a service.Service, and provide start/stop command
// base on "/etc/init.d/<service_name>", status command base on parsing "ps" command output.
//
// Use this on system where "/etc/init.d/<service_name> status" command is not available,
// like old GL.iNET Opal router.
type procd struct {
*sysV
svcConfig *service.Config
}
func (s *procd) Status() (service.Status, error) {
if !s.installed() {
return service.StatusUnknown, service.ErrNotInstalled
}
bin := s.svcConfig.Executable
if bin == "" {
exe, err := os.Executable()
if err != nil {
return service.StatusUnknown, nil
}
bin = exe
}
// Looking for something like "/sbin/ctrld run ".
shellCmd := fmt.Sprintf("ps | grep -q %q", bin+" [r]un ")
if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil {
return service.StatusStopped, nil
}
return service.StatusRunning, nil
}
// systemd wraps a service.Service, and provide status command to
// report the status correctly.
type systemd struct {
@@ -115,7 +153,7 @@ func (s *systemd) Start() error {
if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil {
return fmt.Errorf("systemctl daemon-reload failed: %w\n%s", err, string(out))
}
mainLog.Load().Debug().Msg("Set KillMode=process successfully")
mainLog.Load().Debug().Msg("set KillMode=process successfully")
}
return s.Service.Start()
}
@@ -125,7 +163,7 @@ func (s *systemd) Start() error {
func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
opts, err := unit.DeserializeOptions(r)
if err != nil {
mainLog.Load().Error().Err(err).Msg("Failed to deserialize options")
mainLog.Load().Error().Err(err).Msg("failed to deserialize options")
return
}
change = true
@@ -149,7 +187,6 @@ func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
return opts, change
}
// newLaunchd creates a new launchd service wrapper
func newLaunchd(s service.Service) *launchd {
return &launchd{
Service: s,
@@ -179,30 +216,28 @@ type task struct {
Name string
}
// doTasks executes a list of tasks and returns success status
func doTasks(tasks []task) bool {
for _, task := range tasks {
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
if err := task.f(); err != nil {
if task.abortOnError {
mainLog.Load().Error().Msgf("Error running task %s: %v", task.Name, err)
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
return false
}
// if this is darwin stop command, dont print debug
// since launchctl complains on every start
if runtime.GOOS != "darwin" || task.Name != "Stop" {
mainLog.Load().Debug().Msgf("Error running task %s: %v", task.Name, err)
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
}
}
}
return true
}
// checkHasElevatedPrivilege checks if the process has elevated privileges and exits if not
func checkHasElevatedPrivilege() {
ok, err := hasElevatedPrivilege()
if err != nil {
mainLog.Load().Error().Msgf("Could not detect user privilege: %v", err)
mainLog.Load().Error().Msgf("could not detect user privilege: %v", err)
return
}
if !ok {
@@ -211,10 +246,16 @@ func checkHasElevatedPrivilege() {
}
}
// unixSystemVServiceStatus checks the status of a Unix System V service
func unixSystemVServiceStatus() (service.Status, error) {
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
if err != nil {
// Specific case for openwrt >= 24.10, it returns non-success code
// for above status command, which may not right.
if router.Name() == openwrt.Name {
if string(bytes.ToLower(bytes.TrimSpace(out))) == "inactive" {
return service.StatusStopped, nil
}
}
return service.StatusUnknown, nil
}

View File

@@ -6,15 +6,17 @@ import (
"os"
)
// hasElevatedPrivilege checks if the current process has elevated privileges
func hasElevatedPrivilege() (bool, error) {
return os.Geteuid() == 0, nil
}
// openLogFile opens a log file with the specified flags
func openLogFile(path string, flags int) (*os.File, error) {
return os.OpenFile(path, flags, os.FileMode(0o600))
}
// ConfigureWindowsServiceFailureActions is a no-op on non-Windows platforms
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
func hasLocalDnsServerRunning() bool { return false }
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 }

View File

@@ -2,16 +2,22 @@ package cli
import (
"os"
"reflect"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"unsafe"
"github.com/microsoft/wmi/pkg/base/host"
"github.com/microsoft/wmi/pkg/base/instance"
"github.com/microsoft/wmi/pkg/base/query"
"github.com/microsoft/wmi/pkg/constant"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc/mgr"
)
// hasElevatedPrivilege checks if the current process has elevated privileges on Windows
func hasElevatedPrivilege() (bool, error) {
var sid *windows.SID
if err := windows.AllocateAndInitializeSid(
@@ -94,7 +100,6 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error {
return nil
}
// openLogFile opens a log file with the specified mode on Windows
func openLogFile(path string, mode int) (*os.File, error) {
if len(path) == 0 {
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
@@ -146,3 +151,77 @@ func openLogFile(path string, mode int) (*os.File, error) {
return os.NewFile(uintptr(handle), path), nil
}
const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
func hasLocalDnsServerRunning() bool {
h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
if e != nil {
return false
}
p := windows.ProcessEntry32{Size: processEntrySize}
for {
e := windows.Process32Next(h, &p)
if e != nil {
return false
}
if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" {
return true
}
}
}
func isRunningOnDomainControllerWindows() (bool, int) {
whost := host.NewWmiLocalHost()
q := query.NewWmiQuery("Win32_ComputerSystem")
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
if err != nil {
mainLog.Load().Debug().Err(err).Msg("WMI query failed")
return false, 0
}
if instances == nil {
mainLog.Load().Debug().Msg("WMI query returned nil instances")
return false, 0
}
defer instances.Close()
if len(instances) == 0 {
mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem")
return false, 0
}
val, err := instances[0].GetProperty("DomainRole")
if err != nil {
mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property")
return false, 0
}
if val == nil {
mainLog.Load().Debug().Msg("DomainRole property is nil")
return false, 0
}
// Safely handle varied types: string or integer
var roleInt int
switch v := val.(type) {
case string:
// "4", "5", etc.
parsed, parseErr := strconv.Atoi(v)
if parseErr != nil {
mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v)
return false, 0
}
roleInt = parsed
case int8, int16, int32, int64:
roleInt = int(reflect.ValueOf(v).Int())
case uint8, uint16, uint32, uint64:
roleInt = int(reflect.ValueOf(v).Uint())
default:
mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v)
return false, 0
}
// Check if role indicates a domain controller
isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController
return isDC, roleInt
}

View File

@@ -0,0 +1,25 @@
package cli
import (
"testing"
"time"
)
func Test_hasLocalDnsServerRunning(t *testing.T) {
start := time.Now()
hasDns := hasLocalDnsServerRunning()
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
start = time.Now()
hasDnsPowershell := hasLocalDnsServerRunningPowershell()
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
if hasDns != hasDnsPowershell {
t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns)
}
}
func hasLocalDnsServerRunningPowershell() bool {
_, err := powershell("Get-Process -Name DNS")
return err == nil
}

View File

@@ -2,7 +2,6 @@ package cli
import (
"sync"
"sync/atomic"
"time"
"github.com/Control-D-Inc/ctrld"
@@ -17,8 +16,7 @@ const (
// upstreamMonitor performs monitoring upstreams health.
type upstreamMonitor struct {
cfg *ctrld.Config
logger atomic.Pointer[ctrld.Logger]
cfg *ctrld.Config
mu sync.RWMutex
checking map[string]bool
@@ -30,8 +28,7 @@ type upstreamMonitor struct {
failureTimerActive map[string]bool
}
// newUpstreamMonitor creates a new upstream monitor instance
func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor {
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
um := &upstreamMonitor{
cfg: cfg,
checking: make(map[string]bool),
@@ -40,7 +37,6 @@ func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonito
recovered: make(map[string]bool),
failureTimerActive: make(map[string]bool),
}
um.logger.Store(logger)
for n := range cfg.Upstream {
upstream := upstreamPrefix + n
um.reset(upstream)
@@ -57,7 +53,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) {
defer um.mu.Unlock()
if um.recovered[upstream] {
um.logger.Load().Debug().Msgf("Upstream %q is recovered, skipping failure count increase", upstream)
mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream)
return
}
@@ -65,7 +61,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) {
failedCount := um.failureReq[upstream]
// Log the updated failure count.
um.logger.Load().Debug().Msgf("Upstream %q failure count updated to %d", upstream, failedCount)
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
// If this is the first failure and no timer is running, start a 10-second timer.
if failedCount == 1 && !um.failureTimerActive[upstream] {
@@ -78,7 +74,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) {
// and the upstream is not in a recovered state, mark it as down.
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
um.down[upstream] = true
um.logger.Load().Warn().Msgf("Upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
}
// Reset the timer flag so that a new timer can be spawned if needed.
um.failureTimerActive[upstream] = false
@@ -88,7 +84,7 @@ func (um *upstreamMonitor) increaseFailureCount(upstream string) {
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
if failedCount >= maxFailureRequest {
um.down[upstream] = true
um.logger.Load().Warn().Msgf("Upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
}
}

View File

@@ -45,7 +45,7 @@ func (c *Controller) Start(CdUID string, ProvisionID string, CustomHostname stri
}
}
// mapCallback maps the AppCallback interface to cli.AppCallback to avoid circular dependency
// As workaround to avoid circular dependency between cli and ctrld_library module
func mapCallback(callback AppCallback) cli.AppCallback {
return cli.AppCallback{
HostName: func() string {

151
config.go
View File

@@ -114,10 +114,6 @@ func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
// InitConfig initializes default config values for given *viper.Viper instance.
func InitConfig(v *viper.Viper, name string) {
ctx := context.Background()
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "Config initialization started")
v.SetDefault("listener", map[string]*ListenerConfig{
"0": {
IP: "",
@@ -156,8 +152,6 @@ func InitConfig(v *viper.Viper, name string) {
Timeout: 3000,
},
})
Log(ctx, logger.Debug(), "Config initialization completed")
}
// Config represents ctrld supported configuration.
@@ -315,20 +309,14 @@ func (lc *ListenerConfig) IsDirectDnsListener() bool {
}
}
// MatchingConfig defines the configuration for rule matching behavior
type MatchingConfig struct {
Order []string `mapstructure:"order" toml:"order,omitempty"`
}
// ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests.
type ListenerPolicyConfig struct {
Name string `mapstructure:"name" toml:"name,omitempty"`
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
Matching *MatchingConfig `mapstructure:"-" toml:"-"`
Name string `mapstructure:"name" toml:"name,omitempty"`
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
}
// Rule is a map from source to list of upstreams.
@@ -337,13 +325,12 @@ type ListenerPolicyConfig struct {
type Rule map[string][]string
// Init initialized necessary values for an UpstreamConfig.
func (uc *UpstreamConfig) Init(ctx context.Context) {
logger := LoggerFromCtx(ctx)
func (uc *UpstreamConfig) Init() {
if err := uc.initDnsStamps(); err != nil {
logger.Fatal().Err(err).Msg("Invalid dns stamps")
ProxyLogger.Load().Fatal().Err(err).Msg("invalid DNS Stamps")
}
uc.initDoHScheme()
uc.uid = upstreamUID(ctx)
uc.uid = upstreamUID()
if u, err := url.Parse(uc.Endpoint); err == nil {
uc.Domain = u.Hostname()
switch uc.Type {
@@ -363,9 +350,6 @@ func (uc *UpstreamConfig) Init(ctx context.Context) {
}
}
if uc.IPStack == "" {
// Set default IP stack based on upstream type
// Control-D upstreams use split stack for better IPv4/IPv6 handling,
// while other upstreams use both stacks for maximum compatibility
if uc.IsControlD() {
uc.IPStack = IpStackSplit
} else {
@@ -427,7 +411,7 @@ func (uc *UpstreamConfig) IsDiscoverable() bool {
return *uc.Discoverable
}
switch uc.Type {
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate:
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate, ResolverTypeLocal:
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
}
@@ -450,18 +434,21 @@ func (uc *UpstreamConfig) UID() string {
return uc.uid
}
// SetupBootstrapIP sets up bootstrap IPs for the upstream config.
// SetupBootstrapIP manually find all available IPs of the upstream.
// The first usable IP will be used as bootstrap IP of the upstream.
// The upstream domain will be looked up using following orders:
//
// - Current system DNS settings.
// - Direct IPs table for ControlD upstreams.
// - ControlD Bootstrap DNS 76.76.2.22
//
// The setup process will block until there's usable IPs found.
func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) {
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "Setting up bootstrap IPs for upstream: %s", uc.Name)
func (uc *UpstreamConfig) SetupBootstrapIP() {
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
isControlD := uc.IsControlD()
nss := initDefaultOsResolver(ctx)
nss := initDefaultOsResolver()
for {
Log(ctx, logger.Debug(), "Looking up bootstrap IPs for domain: %s", uc.Domain)
uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, nss)
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss)
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
// filtering them out here to prevent weird behavior.
if isControlD {
@@ -476,18 +463,18 @@ func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) {
uc.bootstrapIPs = uc.bootstrapIPs[:n]
if len(uc.bootstrapIPs) == 0 {
uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain)
logger.Warn().Msgf("No record found for %q, lookup from direct ip table", uc.Domain)
ProxyLogger.Load().Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain)
}
}
if len(uc.bootstrapIPs) == 0 {
logger.Warn().Msgf("No record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP)
uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")})
ProxyLogger.Load().Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP)
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")})
}
if len(uc.bootstrapIPs) > 0 {
break
}
logger.Warn().Msg("Could not resolve bootstrap ips, retrying...")
ProxyLogger.Load().Warn().Msg("could not resolve bootstrap IPs, retrying...")
b.BackOff(context.Background(), errors.New("no bootstrap IPs"))
}
for _, ip := range uc.bootstrapIPs {
@@ -497,12 +484,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) {
uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip)
}
}
logger.Debug().Msgf("Bootstrap ips: %v", uc.bootstrapIPs)
Log(ctx, logger.Debug(), "Bootstrap IP setup completed for upstream: %s", uc.Name)
ProxyLogger.Load().Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs)
}
// ReBootstrap re-setup the bootstrap IP and the transport.
func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
func (uc *UpstreamConfig) ReBootstrap() {
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
default:
@@ -510,8 +496,7 @@ func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
}
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
if uc.rebootstrap.CompareAndSwap(false, true) {
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "Re-bootstrapping upstream: %s", uc.Name)
ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
}
return true, nil
})
@@ -519,35 +504,35 @@ func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
// SetupTransport initializes the network transport used to connect to upstream server.
// For now, only DoH upstream is supported.
func (uc *UpstreamConfig) SetupTransport(ctx context.Context) {
func (uc *UpstreamConfig) SetupTransport() {
switch uc.Type {
case ResolverTypeDOH:
uc.setupDOHTransport(ctx)
uc.setupDOHTransport()
case ResolverTypeDOH3:
uc.setupDOH3Transport(ctx)
uc.setupDOH3Transport()
}
}
func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) {
func (uc *UpstreamConfig) setupDOHTransport() {
switch uc.IPStack {
case IpStackBoth, "":
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
case IpStackV4:
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
case IpStackV6:
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
case IpStackSplit:
uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
if HasIPv6(ctx) {
uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
if HasIPv6() {
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
} else {
uc.transport6 = uc.transport4
}
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
}
}
func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport {
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.MaxIdleConnsPerHost = 100
transport.TLSClientConfig = &tls.Config{
@@ -567,13 +552,12 @@ func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *
dialerTimeoutMs = uc.Timeout
}
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
logger := LoggerFromCtx(ctx)
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, _ := net.SplitHostPort(addr)
if uc.BootstrapIP != "" {
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
addr := net.JoinHostPort(uc.BootstrapIP, port)
Log(ctx, logger.Debug(), "Sending doh request to: %s", addr)
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", addr)
return dialer.DialContext(ctx, network, addr)
}
pd := &ctrldnet.ParallelDialer{}
@@ -583,11 +567,11 @@ func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
conn, err := pd.DialContext(ctx, network, dialAddrs, logger.Logger)
conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load())
if err != nil {
return nil, err
}
Log(ctx, logger.Debug(), "Sending doh request to: %s", conn.RemoteAddr())
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", conn.RemoteAddr())
return conn, nil
}
runtime.SetFinalizer(transport, func(transport *http.Transport) {
@@ -597,20 +581,19 @@ func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *
}
// Ping warms up the connection to DoH/DoH3 upstream.
func (uc *UpstreamConfig) Ping(ctx context.Context) {
if err := uc.ping(ctx); err != nil {
logger := LoggerFromCtx(ctx)
logger.Debug().Err(err).Msgf("Upstream ping failed: %s", uc.Endpoint)
_ = uc.FallbackToDirectIP(ctx)
func (uc *UpstreamConfig) Ping() {
if err := uc.ping(); err != nil {
ProxyLogger.Load().Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint)
_ = uc.FallbackToDirectIP()
}
}
// ErrorPing is like Ping, but return an error if any.
func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error {
return uc.ping(ctx)
func (uc *UpstreamConfig) ErrorPing() error {
return uc.ping()
}
func (uc *UpstreamConfig) ping(ctx context.Context) error {
func (uc *UpstreamConfig) ping() error {
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
default:
@@ -639,11 +622,11 @@ func (uc *UpstreamConfig) ping(ctx context.Context) error {
for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} {
switch uc.Type {
case ResolverTypeDOH:
if err := ping(uc.dohTransport(ctx, typ)); err != nil {
if err := ping(uc.dohTransport(typ)); err != nil {
return err
}
case ResolverTypeDOH3:
if err := ping(uc.doh3Transport(ctx, typ)); err != nil {
if err := ping(uc.doh3Transport(typ)); err != nil {
return err
}
}
@@ -678,12 +661,12 @@ func (uc *UpstreamConfig) isNextDNS() bool {
return domain == "dns.nextdns.io"
}
func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper {
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
uc.transportOnce.Do(func() {
uc.SetupTransport(ctx)
uc.SetupTransport()
})
if uc.rebootstrap.CompareAndSwap(true, false) {
uc.SetupTransport(ctx)
uc.SetupTransport()
}
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
@@ -699,7 +682,7 @@ func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http
return uc.transport
}
func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string {
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
switch uc.IPStack {
case IpStackBoth:
return pick(uc.bootstrapIPs)
@@ -712,7 +695,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uin
case dns.TypeA:
return pick(uc.bootstrapIPs4)
default:
if HasIPv6(ctx) {
if HasIPv6() {
return pick(uc.bootstrapIPs6)
}
return pick(uc.bootstrapIPs4)
@@ -721,7 +704,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uin
return pick(uc.bootstrapIPs)
}
func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) {
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
switch uc.IPStack {
case IpStackBoth:
return "tcp-tls", "udp"
@@ -734,7 +717,7 @@ func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (st
case dns.TypeA:
return "tcp4-tls", "udp4"
default:
if HasIPv6(ctx) {
if HasIPv6() {
return "tcp6-tls", "udp6"
}
return "tcp4-tls", "udp4"
@@ -815,7 +798,7 @@ func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context
}
// FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain.
func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool {
func (uc *UpstreamConfig) FallbackToDirectIP() bool {
if !uc.IsControlD() {
return false
}
@@ -834,8 +817,7 @@ func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool {
default:
return
}
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Warn(), "Using direct IP for %q: %s", uc.Endpoint, ip)
ProxyLogger.Load().Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip)
uc.u.Host = ip
done = true
})
@@ -844,18 +826,12 @@ func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool {
// Init initialized necessary values for an ListenerConfig.
func (lc *ListenerConfig) Init() {
logger := LoggerFromCtx(context.Background())
Log(context.Background(), logger.Debug(), "Initializing listener config")
if lc.Policy != nil {
lc.Policy.FailoverRcodeNumbers = make([]int, len(lc.Policy.FailoverRcodes))
for i, rcode := range lc.Policy.FailoverRcodes {
lc.Policy.FailoverRcodeNumbers[i] = dnsrcode.FromString(rcode)
}
Log(context.Background(), logger.Debug(), "Listener policy initialized with %d failover rcodes", len(lc.Policy.FailoverRcodes))
}
Log(context.Background(), logger.Debug(), "Listener config initialization completed")
}
// ValidateConfig validates the given config.
@@ -975,12 +951,11 @@ func pick(s []string) string {
}
// upstreamUID generates an unique identifier for an upstream.
func upstreamUID(ctx context.Context) string {
logger := LoggerFromCtx(ctx)
func upstreamUID() string {
b := make([]byte, 4)
for {
if _, err := crand.Read(b); err != nil {
logger.Warn().Err(err).Msg("Could not generate uid for upstream, retrying...")
ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
continue
}
return hex.EncodeToString(b)

View File

@@ -1,7 +1,6 @@
package ctrld
import (
"context"
"net/url"
"testing"
@@ -37,10 +36,10 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed.
// t.Parallel()
tc.uc.Init(context.Background())
tc.uc.SetupBootstrapIP(context.Background())
tc.uc.Init()
tc.uc.SetupBootstrapIP()
if len(tc.uc.bootstrapIPs) == 0 {
t.Log(defaultNameservers(context.Background()))
t.Log(defaultNameservers())
t.Fatalf("could not bootstrap ip: %s", tc.uc.String())
}
})
@@ -356,7 +355,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init(context.Background())
tc.uc.Init()
tc.uc.uid = "" // we don't care about the uid.
assert.Equal(t, tc.expected, tc.uc)
})
@@ -498,7 +497,7 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init(context.Background())
tc.uc.Init()
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
}

View File

@@ -14,35 +14,34 @@ import (
"github.com/quic-go/quic-go/http3"
)
func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) {
func (uc *UpstreamConfig) setupDOH3Transport() {
switch uc.IPStack {
case IpStackBoth, "":
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
case IpStackV4:
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
case IpStackV6:
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
case IpStackSplit:
uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
if HasIPv6(ctx) {
uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
if HasIPv6() {
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
} else {
uc.http3RoundTripper6 = uc.http3RoundTripper4
}
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
}
}
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
rt := &http3.Transport{}
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
logger := LoggerFromCtx(ctx)
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
_, port, _ := net.SplitHostPort(addr)
// if we have a bootstrap ip set, use it to avoid DNS lookup
if uc.BootstrapIP != "" {
addr = net.JoinHostPort(uc.BootstrapIP, port)
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr)
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", addr)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
@@ -62,7 +61,7 @@ func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string)
if err != nil {
return nil, err
}
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr())
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
return conn, err
}
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
@@ -71,12 +70,12 @@ func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string)
return rt
}
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
uc.transportOnce.Do(func() {
uc.SetupTransport(ctx)
uc.SetupTransport()
})
if uc.rebootstrap.CompareAndSwap(true, false) {
uc.SetupTransport(ctx)
uc.SetupTransport()
}
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:

View File

@@ -5,6 +5,3 @@ package ctrld
func IsDesktopPlatform() bool {
return true
}
// SelfDiscover reports whether ctrld should only do self discover.
func SelfDiscover() bool { return true }

View File

@@ -7,6 +7,3 @@ package ctrld
func IsDesktopPlatform() bool {
return false
}
// SelfDiscover reports whether ctrld should only do self discover.
func SelfDiscover() bool { return false }

View File

@@ -1,22 +1,7 @@
package ctrld
import "golang.org/x/sys/windows"
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
// currently defined as macOS or Windows workstation.
func IsDesktopPlatform() bool {
return isWindowsWorkStation()
}
// SelfDiscover reports whether ctrld should only do self discover.
func SelfDiscover() bool {
return isWindowsWorkStation()
}
// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine.
func isWindowsWorkStation() bool {
// From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa
const VER_NT_WORKSTATION = 0x0000001
osvi := windows.RtlGetVersion()
return osvi.ProductType == VER_NT_WORKSTATION
}

View File

@@ -18,6 +18,10 @@ The config file allows for advanced configuration of the `ctrld` utility to cove
- `/etc/controld` on *nix.
- User's home directory on Windows.
- Same directory with `ctrld` binary on these routers:
- `ddwrt`
- `merlin`
- `freshtomato`
- Current directory.
The user can choose to override default value using command line `--config` or `-c`:
@@ -289,7 +293,7 @@ If a remote upstream fails to resolve a query or is unreachable, `ctrld` will fo
- Type: boolean
- Required: no
- Default: true on Windows, MacOS and Linux.
- Default: true on Windows, MacOS and non-router Linux.
## Upstream
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.

View File

@@ -1,135 +0,0 @@
# ctrld v2.0.0 Breaking Changes
This document outlines the breaking changes introduced in ctrld v2.0.0 and provides migration guidance for affected users.
## Overview
ctrld v2.0.0 removes automatic configuration support for router and server platforms. This means ctrld will no longer perform "magic" configuration to automatically set itself up as an upstream for existing DNS software on these platforms.
## What's Changing
### Removed Platform Support
**Router Platforms:**
- ctrld will no longer automatically configure itself as an upstream for dnsmasq or other DNS software
- No automatic detection and configuration of router-specific DNS settings
**Server Platforms:**
- ctrld will no longer automatically configure Windows Server DNS forwarder settings
- No automatic integration with server DNS services
### What Remains Supported
**Desktop Platforms:**
- Windows Desktop
- macOS Desktop
- Linux Desktop
These platforms continue to receive full automatic configuration support.
## Stay on v1.x.x
ctrld v1.x.x will continue to be supported for router and server platforms:
- Important bug fixes (regression or security) will be cherry-picked to v1.x.x branch
- New features may still be added (but may take longer to implement)
- Long-term support for these platforms
## Migration Path for Router and Server Users
If you're currently using ctrld v1.x.x on router or server platforms, you need to follow these steps to migrate to v2.0.0:
### Step 1: Downloading ctrld v2 binary
To download ctrld v2.0.0, follow these steps:
Stop the current ctrld service:
```sh
ctrld stop
```
Or uninstall the current version:
```sh
ctrld uninstall
```
Download the appropriate binary for your platform: https://dl.controld.com/v2/linux-amd64/ctrld
> **Note**: Replace `amd64` with your platform architecture as needed.
Verify that the binary was updated correctly:
```sh
ctrld --version
```
Expected output:
```
ctrld version v2.0.0
```
### Step 2: Start ctrld without self-checking
You have two ways to start ctrld:
**Option A: Use Remote Configuration (Recommended)**
1. **Export your current configuration:**
- Copy the contents of your current `ctrld.toml` file
2. **Import to Control D Dashboard:**
- Log into your Control D dashboard
- Use the remote configuration feature to upload your configuration
3. **Start ctrld with remote config:**
```bash
sudo ctrld service start --cd=<your_uid> --skip_self_checks
```
> **Note**: You must use `ctrld service start` to prevent DNS being set automatically by ctrld.
**Option B: Use Local Configuration**
```bash
sudo ctrld service start --skip_self_checks
```
### Step 3: Configure DNS Software to Use ctrld as Upstream
**For dnsmasq users:**
1. Configure dnsmasq to use ctrld as upstream:
```bash
# Add to dnsmasq.conf
no-resolv
server=127.0.0.1#5354
add-mac
add-subnet=32,128
# Disable cache or set max-cache-ttl=0
# to prevent queries from caching
cache-size=0
# max-cache-ttl=0
```
2. Restart dnsmasq:
```bash
sudo service dnsmasq restart
```
**For Windows Server users:**
1. Configure DNS forwarder in Windows Server:
- Open DNS Manager
- Right-click on your server name
- Select "Properties" → "Forwarders" tab
- Add `<ctrld listener IP>` as a forwarder
## Getting Help
If you encounter any issues during migration or have questions about the v2.0.0 changes:
1. **File an issue:** [GitHub Issues](https://github.com/Control-D-Inc/ctrld/issues)
2. **Contact support:** Email help@controld.com.
3. **Check documentation:** Review the [configuration documentation](config.md) for detailed setup instructions
## Summary
While ctrld v2.0.0 removes automatic configuration for router and server platforms, it provides a more focused experience for desktop users while still allowing router/server users to continue using ctrld with manual configuration or by staying on the v1.x.x branch.
The migration path is designed to be straightforward, with multiple options to suit different use cases and technical comfort levels.

34
doh.go
View File

@@ -53,9 +53,6 @@ var EncodeArchNameMap = map[string]string{
var DecodeArchNameMap = map[string]string{}
func init() {
// Create reverse mappings for OS and architecture names
// This is needed because the API expects encoded values, but we need to decode
// them back to their original form for processing
for k, v := range EncodeOsNameMap {
DecodeOsNameMap[v] = k
}
@@ -88,12 +85,8 @@ type dohResolver struct {
// Resolve performs DNS query with given DNS message using DOH protocol.
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "DoH resolver query started")
data, err := msg.Pack()
if err != nil {
Log(ctx, logger.Error().Err(err), "Failed to pack DNS message")
return nil, err
}
@@ -105,7 +98,6 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
endpoint.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
if err != nil {
Log(ctx, logger.Error().Err(err), "Could not create HTTP request")
return nil, fmt.Errorf("could not create request: %w", err)
}
addHeader(ctx, req, r.uc)
@@ -113,55 +105,40 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
c := http.Client{Transport: r.uc.dohTransport(ctx, dnsTyp)}
c := http.Client{Transport: r.uc.dohTransport(dnsTyp)}
if r.isDoH3 {
transport := r.uc.doh3Transport(ctx, dnsTyp)
transport := r.uc.doh3Transport(dnsTyp)
if transport == nil {
Log(ctx, logger.Error(), "DoH3 is not supported")
return nil, errors.New("DoH3 is not supported")
}
c.Transport = transport
}
Log(ctx, logger.Debug(), "Sending DoH request to: %s", endpoint.String())
resp, err := c.Do(req)
if err != nil && r.uc.FallbackToDirectIP(ctx) {
if err != nil && r.uc.FallbackToDirectIP() {
retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx))
defer cancel()
logger := LoggerFromCtx(ctx)
logger.Warn().Err(err).Msg("Retrying request after fallback to direct ip")
Log(ctx, ProxyLogger.Load().Warn().Err(err), "retrying request after fallback to direct ip")
resp, err = c.Do(req.Clone(retryCtx))
}
if err != nil {
err = wrapUrlError(err)
if r.isDoH3 {
if closer, ok := c.Transport.(io.Closer); ok {
closer.Close()
}
}
Log(ctx, logger.Error().Err(err), "DoH request failed")
return nil, fmt.Errorf("could not perform request: %w", err)
}
defer resp.Body.Close()
buf, err := io.ReadAll(resp.Body)
if err != nil {
Log(ctx, logger.Error().Err(err), "Could not read response body")
return nil, fmt.Errorf("could not read message from response: %w", err)
}
if resp.StatusCode != http.StatusOK {
Log(ctx, logger.Error(), "Wrong response from DOH server, got: %s, status: %d", string(buf), resp.StatusCode)
return nil, fmt.Errorf("wrong response from DOH server, got: %s, status: %d", string(buf), resp.StatusCode)
}
answer := new(dns.Msg)
if err := answer.Unpack(buf); err != nil {
Log(ctx, logger.Error().Err(err), "Failed to unpack DNS answer")
return nil, fmt.Errorf("answer.Unpack: %w", err)
}
Log(ctx, logger.Debug(), "DoH resolver query successful")
return answer, nil
}
@@ -181,8 +158,7 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
}
}
if printed {
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "Sending request header: %v", dohHeader)
Log(ctx, ProxyLogger.Load().Debug(), "sending request header: %v", dohHeader)
}
dohHeader.Set("Content-Type", headerApplicationDNS)
dohHeader.Set("Accept", headerApplicationDNS)

View File

@@ -157,21 +157,20 @@ func Test_ClientCertificateVerificationError(t *testing.T) {
},
}
ctx := context.Background()
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init(ctx)
tc.uc.SetupBootstrapIP(ctx)
r, err := NewResolver(ctx, tc.uc)
tc.uc.Init()
tc.uc.SetupBootstrapIP()
r, err := NewResolver(tc.uc)
if err != nil {
t.Fatal(err)
}
msg := new(dns.Msg)
msg.SetQuestion("verify.controld.com.", dns.TypeA)
msg.RecursionDesired = true
_, err = r.Resolve(ctx, msg)
_, err = r.Resolve(context.Background(), msg)
// Verify the error contains the expected certificate information
if err == nil {
t.Fatal("expected certificate verification error, got nil")

15
doq.go
View File

@@ -18,9 +18,6 @@ type doqResolver struct {
}
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "DoQ resolver query started")
endpoint := r.uc.Endpoint
tlsConfig := &tls.Config{NextProtos: []string{"doq"}}
ip := r.uc.BootstrapIP
@@ -29,20 +26,12 @@ func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
ip = r.uc.bootstrapIPForDNSType(ctx, dnsTyp)
ip = r.uc.bootstrapIPForDNSType(dnsTyp)
}
tlsConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(ip, port)
Log(ctx, logger.Debug(), "Sending DoQ request to: %s", endpoint)
answer, err := resolve(ctx, msg, endpoint, tlsConfig)
if err != nil {
Log(ctx, logger.Error().Err(err), "DoQ request failed")
} else {
Log(ctx, logger.Debug(), "DoQ resolver query successful")
}
return answer, err
return resolve(ctx, msg, endpoint, tlsConfig)
}
func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) {

11
dot.go
View File

@@ -13,9 +13,6 @@ type dotResolver struct {
}
func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
logger := LoggerFromCtx(ctx)
Log(ctx, logger.Debug(), "DoT resolver query started")
// The dialer is used to prevent bootstrapping cycle.
// If r.endpoint is set to dns.controld.dev, we need to resolve
// dns.controld.dev first. By using a dialer with custom resolver,
@@ -26,7 +23,7 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
tcpNet, _ := r.uc.netForDNSType(ctx, dnsTyp)
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: tcpNet,
Dialer: dialer,
@@ -40,12 +37,6 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
}
Log(ctx, logger.Debug(), "Sending DoT request to: %s", endpoint)
answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint)
if err != nil {
Log(ctx, logger.Error().Err(err), "DoT request failed")
} else {
Log(ctx, logger.Debug(), "DoT resolver query successful")
}
return answer, wrapCertificateVerificationError(err)
}

14
go.mod
View File

@@ -5,6 +5,7 @@ go 1.24
require (
github.com/Masterminds/semver/v3 v3.2.1
github.com/ameshkov/dnsstamps v1.0.3
github.com/brunogui0812/sysprofiler v0.5.0
github.com/coreos/go-systemd/v22 v22.5.0
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf
github.com/docker/go-units v0.5.0
@@ -15,6 +16,7 @@ require (
github.com/hashicorp/golang-lru/v2 v2.0.1
github.com/illarion/gonotify/v2 v2.0.3
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2
github.com/jaypipes/ghw v0.21.0
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
github.com/kardianos/service v1.2.1
@@ -28,11 +30,12 @@ require (
github.com/prometheus/client_model v0.5.0
github.com/prometheus/prom2json v1.3.3
github.com/quic-go/quic-go v0.57.1
github.com/rs/zerolog v1.28.0
github.com/spf13/cobra v1.9.1
github.com/spf13/pflag v1.0.6
github.com/spf13/viper v1.16.0
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.2.1-beta.2
go.uber.org/zap v1.27.0
golang.org/x/net v0.43.0
golang.org/x/sync v0.16.0
golang.org/x/sys v0.35.0
@@ -54,13 +57,17 @@ require (
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/groob/plist v0.0.0-20200425180238-0f631f258c01 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jaypipes/pcidb v1.1.1 // indirect
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
@@ -75,14 +82,14 @@ require (
github.com/quic-go/qpack v0.6.0 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/spakin/awk v1.0.0 // indirect
github.com/spf13/afero v1.9.5 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/subosito/gotenv v1.4.2 // indirect
github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e // indirect
github.com/vishvananda/netns v0.0.4 // indirect
go.uber.org/multierr v1.11.0 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/crypto v0.41.0 // indirect
@@ -93,6 +100,7 @@ require (
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
howett.net/plist v1.0.2-0.20250314012144-ee69052608d9 // indirect
)
replace github.com/mr-karan/doggo => github.com/Windscribe/doggo v0.0.0-20220919152748-2c118fc391f8

35
go.sum
View File

@@ -42,12 +42,16 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
github.com/ameshkov/dnsstamps v1.0.3/go.mod h1:Ii3eUu73dx4Vw5O4wjzmT5+lkCwovjzaEZZ4gKyIH5A=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/brunogui0812/sysprofiler v0.5.0 h1:AUekplOKG/VKH6sPSBRxsKOA9Uv5OsI8qolXM73dXPU=
github.com/brunogui0812/sysprofiler v0.5.0/go.mod h1:lLd7gvylgd4nsTSC8exq1YY6qhLWXkgnalxjVzdlbEM=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -89,6 +93,7 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg=
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
@@ -163,6 +168,8 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
github.com/groob/plist v0.0.0-20200425180238-0f631f258c01 h1:0T3XGXebqLj7zSVLng9wX9axQzTEnvj/h6eT7iLfUas=
github.com/groob/plist v0.0.0-20200425180238-0f631f258c01/go.mod h1:itkABA+w2cw7x5nYUS/pLRef6ludkZKOigbROmCTaFw=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru/v2 v2.0.1 h1:5pv5N1lT1fjLg2VQ5KWc7kmucp2x/kvFOnxuVTqZ6x4=
@@ -179,8 +186,13 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA=
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI=
github.com/jaypipes/ghw v0.21.0 h1:ClG2xWtYY0c1ud9jZYwVGdSgfCI7AbmZmZyw3S5HHz8=
github.com/jaypipes/ghw v0.21.0/go.mod h1:GPrvwbtPoxYUenr74+nAnWbardIZq600vJDD5HnPsPE=
github.com/jaypipes/pcidb v1.1.1 h1:QmPhpsbmmnCwZmHeYAATxEaoRuiMAJusKYkUncMC0ro=
github.com/jaypipes/pcidb v1.1.1/go.mod h1:x27LT2krrUgjf875KxQXKB0Ha/YXLdZRVmw6hH0G7g8=
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c h1:kbTQ8oGf+BVFvt/fM+ECI+NbZDCqoi0vtZTfB2p2hrI=
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c/go.mod h1:k6+89xKz7BSMJ+DzIerBdtpEUeTlBMugO/hcVSzahog=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk=
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8=
@@ -205,6 +217,12 @@ github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
@@ -264,7 +282,10 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spakin/awk v1.0.0 h1:5ulBVgJhdN3XoFGNVv/MOHOIUfPVPvMCIlLH6O6ZqU4=
github.com/spakin/awk v1.0.0/go.mod h1:e7FnxcIEcRqdKwStPYWonox4n9DpharWk+3nnn1IqJs=
github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM=
github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
@@ -305,20 +326,16 @@ github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8=
go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
@@ -436,6 +453,7 @@ golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -469,9 +487,12 @@ golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
@@ -655,6 +676,8 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
howett.net/plist v1.0.2-0.20250314012144-ee69052608d9 h1:eeH1AIcPvSc0Z25ThsYF+Xoqbn0CI/YnXVYoTLFdGQw=
howett.net/plist v1.0.2-0.20250314012144-ee69052608d9/go.mod h1:fyFX5Hj5tP1Mpk8obqA9MZgXT416Q5711SDT7dQLTLk=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=

View File

@@ -20,8 +20,6 @@ func (a *arpDiscover) scan() {
}
// trim brackets
// Unix "arp -an" output formats IP addresses with parentheses like "(192.168.1.1)"
// We need to remove these brackets for proper IP parsing
ip := strings.ReplaceAll(fields[1], "(", "")
ip = strings.ReplaceAll(ip, ")", "")

View File

@@ -17,14 +17,10 @@ func (a *arpDiscover) scan() {
continue // empty lines
}
if line[0] != ' ' {
// Mark that we've found an interface header line
// Windows "arp -a" output has interface headers followed by ARP entries
header = true // "Interface:" lines, next is header line.
continue
}
if header {
// Skip the header line that follows interface names
// These lines contain column headers like "Internet Address" and "Physical Address"
header = false // header lines
continue
}

View File

@@ -79,9 +79,10 @@ type Table struct {
initOnce sync.Once
stopOnce sync.Once
refreshInterval int
logger *ctrld.Logger
dhcp *dhcp
merlin *merlinDiscover
ubios *ubiosDiscover
arp *arpDiscover
ndp *ndpDiscover
ptr *ptrDiscover
@@ -97,18 +98,11 @@ type Table struct {
ptrNameservers []string
}
func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string, logger *ctrld.Logger) *Table {
func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
refreshInterval := cfg.Service.DiscoverRefreshInterval
// Set default refresh interval if not configured
// This ensures client discovery continues to work even without explicit configuration
if refreshInterval <= 0 {
refreshInterval = 2 * 60 // 2 minutes
}
// Use no-op logger if none provided
// This prevents nil pointer dereferences when logging is not configured
if logger == nil {
logger = ctrld.NopLogger
}
return &Table{
svcCfg: cfg.Service,
quitCh: make(chan struct{}),
@@ -117,7 +111,6 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string, logger *ctrl
cdUID: cdUID,
ptrNameservers: ns,
refreshInterval: refreshInterval,
logger: logger,
}
}
@@ -186,7 +179,7 @@ func (t *Table) SetSelfIP(ip string) {
// initSelfDiscover initializes necessary client metadata for self query.
func (t *Table) initSelfDiscover() {
t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger}
t.dhcp = &dhcp{selfIP: t.selfIP}
t.dhcp.addSelf()
t.ipResolvers = append(t.ipResolvers, t.dhcp)
t.macResolvers = append(t.macResolvers, t.dhcp)
@@ -196,24 +189,48 @@ func (t *Table) initSelfDiscover() {
func (t *Table) init() {
// Custom client ID presents, use it as the only source.
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
t.logger.Debug().Msg("Start self discovery with custom client id")
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery with custom client id")
t.initSelfDiscover()
return
}
// If we are running on platforms that should only do self discover, use it as the only source, too.
if ctrld.SelfDiscover() {
t.logger.Debug().Msg("Start self discovery on desktop platforms")
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery on desktop platforms")
t.initSelfDiscover()
return
}
// Otherwise, process all possible sources in order, that means
// the first result of IP/MAC/Hostname lookup will be used.
//
// Routers custom clients:
// - Merlin
// - Ubios
if t.discoverDHCP() || t.discoverARP() {
t.merlin = &merlinDiscover{}
t.ubios = &ubiosDiscover{}
discovers := map[string]interface {
refresher
HostnameResolver
}{
"Merlin": t.merlin,
"Ubios": t.ubios,
}
for platform, discover := range discovers {
if err := discover.refresh(); err != nil {
ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform)
}
t.hostnameResolvers = append(t.hostnameResolvers, discover)
t.refreshers = append(t.refreshers, discover)
}
}
// Hosts file mapping.
if t.discoverHosts() {
t.hf = &hostsFile{logger: t.logger}
t.logger.Debug().Msg("Start hosts file discovery")
t.hf = &hostsFile{}
ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery")
if err := t.hf.init(); err != nil {
t.logger.Error().Err(err).Msg("Could not init hosts file discover")
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init hosts file discover")
} else {
t.hostnameResolvers = append(t.hostnameResolvers, t.hf)
t.refreshers = append(t.refreshers, t.hf)
@@ -222,10 +239,10 @@ func (t *Table) init() {
}
// DHCP lease files.
if t.discoverDHCP() {
t.dhcp = &dhcp{selfIP: t.selfIP, logger: t.logger}
t.logger.Debug().Msg("Start dhcp discovery")
t.dhcp = &dhcp{selfIP: t.selfIP}
ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery")
if err := t.dhcp.init(); err != nil {
t.logger.Error().Err(err).Msg("Could not init dhcp discover")
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init DHCP discover")
} else {
t.ipResolvers = append(t.ipResolvers, t.dhcp)
t.macResolvers = append(t.macResolvers, t.dhcp)
@@ -236,8 +253,8 @@ func (t *Table) init() {
// ARP/NDP table.
if t.discoverARP() {
t.arp = &arpDiscover{}
t.ndp = &ndpDiscover{logger: t.logger}
t.logger.Debug().Msg("Start arp discovery")
t.ndp = &ndpDiscover{}
ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery")
discovers := map[string]interface {
refresher
IpResolver
@@ -249,7 +266,7 @@ func (t *Table) init() {
for protocol, discover := range discovers {
if err := discover.refresh(); err != nil {
t.logger.Error().Err(err).Msgf("Could not init %s discover", protocol)
ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", protocol)
} else {
t.ipResolvers = append(t.ipResolvers, discover)
t.macResolvers = append(t.macResolvers, discover)
@@ -266,10 +283,7 @@ func (t *Table) init() {
}
// PTR lookup.
if t.discoverPTR() {
t.ptr = &ptrDiscover{
resolver: ctrld.NewPrivateResolver(context.Background()),
logger: t.logger,
}
t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()}
if len(t.ptrNameservers) > 0 {
nss := make([]string, 0, len(t.ptrNameservers))
for _, ns := range t.ptrNameservers {
@@ -278,22 +292,21 @@ func (t *Table) init() {
host, port = h, p
}
// Only use valid ip:port pair.
// Invalid nameservers can cause PTR discovery to fail silently
if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil {
nss = append(nss, net.JoinHostPort(host, port))
} else {
t.logger.Warn().Msgf("Ignoring invalid nameserver for ptr discover: %q", ns)
ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns)
}
}
if len(nss) > 0 {
t.ptr.resolver = ctrld.NewResolverWithNameserver(nss)
t.logger.Debug().Msgf("Using nameservers %v for ptr discovery", nss)
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss)
}
}
t.logger.Debug().Msg("Start ptr discovery")
ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery")
if err := t.ptr.refresh(); err != nil {
t.logger.Error().Err(err).Msg("Could not init ptr discover")
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover")
} else {
t.hostnameResolvers = append(t.hostnameResolvers, t.ptr)
t.refreshers = append(t.refreshers, t.ptr)
@@ -301,10 +314,10 @@ func (t *Table) init() {
}
// mdns.
if t.discoverMDNS() {
t.mdns = &mdns{logger: t.logger}
t.logger.Debug().Msg("Start mdns discovery")
t.mdns = &mdns{}
ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery")
if err := t.mdns.init(t.quitCh); err != nil {
t.logger.Error().Err(err).Msg("Could not init mdns discover")
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init mDNS discover")
} else {
t.hostnameResolvers = append(t.hostnameResolvers, t.mdns)
}
@@ -470,7 +483,6 @@ func (t *Table) ListClients() []*Client {
for _, c := range ipMap {
// If we found a client with empty hostname, use hostname from
// an existed client which has the same MAC address.
// This helps fill in missing hostnames when multiple IPs share the same MAC
if cFromMac := clientsByMAC[c.Mac]; cFromMac != nil && c.Hostname == "" {
c.Hostname = cFromMac.Hostname
}

View File

@@ -2,8 +2,6 @@ package clientinfo
import (
"testing"
"github.com/Control-D-Inc/ctrld"
)
func Test_normalizeIP(t *testing.T) {
@@ -30,9 +28,8 @@ func Test_normalizeIP(t *testing.T) {
func TestTable_LookupRFC1918IPv4(t *testing.T) {
table := &Table{
dhcp: &dhcp{},
arp: &arpDiscover{},
logger: ctrld.NopLogger,
dhcp: &dhcp{},
arp: &arpDiscover{},
}
table.ipResolvers = append(table.ipResolvers, table.dhcp)

Some files were not shown because too many files have changed in this diff Show More