Add comprehensive documentation to CLI components and core functionality

This commit extends the documentation effort by adding detailed explanatory
comments to key CLI components and core functionality throughout the cmd/
directory. The changes focus on explaining WHY certain logic is needed,
not just WHAT the code does, improving code maintainability and helping
developers understand complex business decisions.

Key improvements:
- Main entry points: Document CLI initialization, logging setup, and cache
  configuration with reasoning for design decisions
- DNS proxy core: Explain DNS proxy constants, data structures, and core
  processing pipeline for handling DNS queries
- Service management: Document service command structure, configuration
  patterns, and platform-specific service handling
- Logging infrastructure: Explain log buffer management, level encoders,
  and log formatting decisions for different use cases
- Metrics and monitoring: Document Prometheus metrics structure, HTTP
  endpoints, and conditional metric collection for performance
- Network handling: Explain Linux-specific network interface filtering,
  virtual interface detection, and DNS configuration management
- Hostname validation: Document RFC1123 compliance and DNS naming
  standards for system compatibility
- Mobile integration: Explain HTTP retry logic, fallback mechanisms, and
  mobile platform integration patterns
- Connection management: Document connection wrapper design to prevent
  log pollution during process lifecycle

Technical details:
- Added explanatory comments to 11 additional files in cmd/cli/
- Maintained consistent documentation style and format
- Preserved all existing functionality while improving code clarity
- Enhanced understanding of complex business logic and platform-specific
  behavior

These comments help future developers understand the reasoning behind
complex decisions, making the codebase more maintainable and reducing
the risk of incorrect modifications during maintenance.
This commit is contained in:
Cuong Manh Le
2025-08-07 15:49:20 +07:00
committed by Cuong Manh Le
parent d88c860cac
commit 4792183c0d
39 changed files with 249 additions and 22 deletions

View File

@@ -67,6 +67,7 @@ var (
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"} var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"}
// isNoConfigStart checks if the command is using no-config start mode
func isNoConfigStart(cmd *cobra.Command) bool { func isNoConfigStart(cmd *cobra.Command) bool {
for _, flagName := range basicModeFlags { for _, flagName := range basicModeFlags {
if cmd.Flags().Lookup(flagName).Changed { if cmd.Flags().Lookup(flagName).Changed {
@@ -85,6 +86,7 @@ _/ ___\ __\_ __ \ | / __ |
\/ dns forwarding proxy \/ \/ dns forwarding proxy \/
` `
// curVersion returns the current version string
func curVersion() string { func curVersion() string {
// Ensure version has proper "v" prefix for semantic versioning // Ensure version has proper "v" prefix for semantic versioning
// This is needed because some build systems may provide version without the "v" prefix // This is needed because some build systems may provide version without the "v" prefix
@@ -429,6 +431,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
<-stopCh <-stopCh
} }
// writeConfigFile writes the configuration to a file
func writeConfigFile(cfg *ctrld.Config) error { func writeConfigFile(cfg *ctrld.Config) error {
if cfu := v.ConfigFileUsed(); cfu != "" { if cfu := v.ConfigFileUsed(); cfu != "" {
defaultConfigFile = cfu defaultConfigFile = cfu
@@ -544,6 +547,7 @@ func readBase64Config(configBase64 string) error {
return v.ReadConfig(bytes.NewReader(configStr)) return v.ReadConfig(bytes.NewReader(configStr))
} }
// processNoConfigFlags processes flags for no-config mode
func processNoConfigFlags(noConfigStart bool) { func processNoConfigFlags(noConfigStart bool) {
if !noConfigStart { if !noConfigStart {
return return
@@ -607,6 +611,7 @@ func deactivationPinSet() bool {
return cdDeactivationPin.Load() != defaultDeactivationPin return cdDeactivationPin.Load() != defaultDeactivationPin
} }
// processCDFlags processes Control D related flags
func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) { func processCDFlags(cfg *ctrld.Config) (*controld.ResolverConfig, error) {
logger := mainLog.Load().With().Str("mode", "cd") logger := mainLog.Load().With().Str("mode", "cd")
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID) logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
@@ -743,6 +748,7 @@ func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) erro
return v.Unmarshal(&cfg) return v.Unmarshal(&cfg)
} }
// processListenFlag processes the listen flag
func processListenFlag() { func processListenFlag() {
if listenAddress == "" { if listenAddress == "" {
return return
@@ -764,6 +770,7 @@ func processListenFlag() {
}) })
} }
// processLogAndCacheFlags processes log and cache related flags
func processLogAndCacheFlags() { func processLogAndCacheFlags() {
if logPath != "" { if logPath != "" {
cfg.Service.LogPath = logPath cfg.Service.LogPath = logPath
@@ -779,6 +786,7 @@ func processLogAndCacheFlags() {
v.Set("service", cfg.Service) v.Set("service", cfg.Service)
} }
// netInterface returns the network interface by name
func netInterface(ifaceName string) (*net.Interface, error) { func netInterface(ifaceName string) (*net.Interface, error) {
if ifaceName == "auto" { if ifaceName == "auto" {
ifaceName = defaultIfaceName() ifaceName = defaultIfaceName()
@@ -798,6 +806,7 @@ func netInterface(ifaceName string) (*net.Interface, error) {
return iface, err return iface, err
} }
// defaultIfaceName returns the default interface name
func defaultIfaceName() string { func defaultIfaceName() string {
dri, err := netmon.DefaultRouteInterface() dri, err := netmon.DefaultRouteInterface()
if err != nil { if err != nil {
@@ -948,6 +957,7 @@ func selfCheckResolveDomain(ctx context.Context, addr, scope string, domain stri
return errSelfCheckNoAnswer return errSelfCheckNoAnswer
} }
// userHomeDir returns the user's home directory
func userHomeDir() (string, error) { func userHomeDir() (string, error) {
// Mobile platform should provide a rw dir path for this. // Mobile platform should provide a rw dir path for this.
if isMobile() { if isMobile() {
@@ -1394,6 +1404,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, notifyFunc func(), fatal bool) (
return return
} }
// dirWritable checks if a directory is writable
func dirWritable(dir string) (bool, error) { func dirWritable(dir string) (bool, error) {
f, err := os.CreateTemp(dir, "") f, err := os.CreateTemp(dir, "")
if err != nil { if err != nil {
@@ -1403,6 +1414,7 @@ func dirWritable(dir string) (bool, error) {
return true, f.Close() return true, f.Close()
} }
// osVersion returns the operating system version
func osVersion() string { func osVersion() string {
oi := osinfo.New() oi := osinfo.New()
if runtime.GOOS == "freebsd" { if runtime.GOOS == "freebsd" {
@@ -1544,6 +1556,7 @@ func checkStrFlagEmpty(cmd *cobra.Command, flagName string) {
} }
} }
// validateCdUpstreamProtocol validates the Control D upstream protocol
func validateCdUpstreamProtocol() { func validateCdUpstreamProtocol() {
if cdUID == "" { if cdUID == "" {
return return
@@ -1555,6 +1568,7 @@ func validateCdUpstreamProtocol() {
} }
} }
// validateCdAndNextDNSFlags validates that Control D and NextDNS flags are not used together
func validateCdAndNextDNSFlags() { func validateCdAndNextDNSFlags() {
if (cdUID != "" || cdOrg != "") && nextdns != "" { if (cdUID != "" || cdOrg != "") && nextdns != "" {
mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName) mainLog.Load().Fatal().Msgf("--%s/--%s could not be used with --%s", cdUidFlagName, cdOrgFlagName, nextdnsFlagName)
@@ -1595,6 +1609,7 @@ func doGenerateNextDNSConfig(uid string) error {
return writeConfigFile(&cfg) return writeConfigFile(&cfg)
} }
// noticeWritingControlDConfig logs on notice level that a Control D config is being written
func noticeWritingControlDConfig() error { func noticeWritingControlDConfig() error {
if cdUID != "" { if cdUID != "" {
mainLog.Load().Notice().Msgf("Generating controld config: %s", defaultConfigFile) mainLog.Load().Notice().Msgf("Generating controld config: %s", defaultConfigFile)

View File

@@ -10,6 +10,7 @@ import (
) )
// filterEmptyStrings removes empty strings from a slice // filterEmptyStrings removes empty strings from a slice
// This is used to clean up command line arguments and configuration values
func filterEmptyStrings(slice []string) []string { func filterEmptyStrings(slice []string) []string {
var result []string var result []string
for _, s := range slice { for _, s := range slice {
@@ -21,17 +22,20 @@ func filterEmptyStrings(slice []string) []string {
} }
// ServiceCommand handles service-related operations // ServiceCommand handles service-related operations
// This encapsulates all service management functionality for the CLI
type ServiceCommand struct { type ServiceCommand struct {
serviceManager *ServiceManager serviceManager *ServiceManager
} }
// initializeServiceManager creates a service manager with default configuration // initializeServiceManager creates a service manager with default configuration
// This sets up the basic service infrastructure needed for all service operations
func (sc *ServiceCommand) initializeServiceManager() (service.Service, *prog, error) { func (sc *ServiceCommand) initializeServiceManager() (service.Service, *prog, error) {
svcConfig := sc.createServiceConfig() svcConfig := sc.createServiceConfig()
return sc.initializeServiceManagerWithServiceConfig(svcConfig) return sc.initializeServiceManagerWithServiceConfig(svcConfig)
} }
// initializeServiceManagerWithServiceConfig creates a service manager with the given configuration // initializeServiceManagerWithServiceConfig creates a service manager with the given configuration
// This allows for custom service configuration while maintaining the same initialization pattern
func (sc *ServiceCommand) initializeServiceManagerWithServiceConfig(svcConfig *service.Config) (service.Service, *prog, error) { func (sc *ServiceCommand) initializeServiceManagerWithServiceConfig(svcConfig *service.Config) (service.Service, *prog, error) {
p := &prog{} p := &prog{}
@@ -45,6 +49,7 @@ func (sc *ServiceCommand) initializeServiceManagerWithServiceConfig(svcConfig *s
} }
// newService creates a new service instance using the provided program and configuration. // newService creates a new service instance using the provided program and configuration.
// This abstracts the service creation process for different operating systems
func (sc *ServiceCommand) newService(p *prog, svcConfig *service.Config) (service.Service, error) { func (sc *ServiceCommand) newService(p *prog, svcConfig *service.Config) (service.Service, error) {
s, err := newService(p, svcConfig) s, err := newService(p, svcConfig)
if err != nil { if err != nil {
@@ -54,11 +59,13 @@ func (sc *ServiceCommand) newService(p *prog, svcConfig *service.Config) (servic
} }
// NewServiceCommand creates a new service command handler // NewServiceCommand creates a new service command handler
// This provides a clean factory method for creating service command instances
func NewServiceCommand() *ServiceCommand { func NewServiceCommand() *ServiceCommand {
return &ServiceCommand{} return &ServiceCommand{}
} }
// createServiceConfig creates a properly initialized service configuration // createServiceConfig creates a properly initialized service configuration
// This ensures consistent service naming and description across all platforms
func (sc *ServiceCommand) createServiceConfig() *service.Config { func (sc *ServiceCommand) createServiceConfig() *service.Config {
return &service.Config{ return &service.Config{
Name: ctrldServiceName, Name: ctrldServiceName,
@@ -69,6 +76,7 @@ func (sc *ServiceCommand) createServiceConfig() *service.Config {
} }
// InitServiceCmd creates the service command with proper logic and aliases // InitServiceCmd creates the service command with proper logic and aliases
// This sets up all service-related subcommands with appropriate permissions and flags
func InitServiceCmd(rootCmd *cobra.Command) *cobra.Command { func InitServiceCmd(rootCmd *cobra.Command) *cobra.Command {
// Create service command handlers // Create service command handlers
sc := NewServiceCommand() sc := NewServiceCommand()

View File

@@ -8,44 +8,60 @@ import (
// logConn wraps a net.Conn, override the Write behavior. // logConn wraps a net.Conn, override the Write behavior.
// runCmd uses this wrapper, so as long as startCmd finished, // runCmd uses this wrapper, so as long as startCmd finished,
// ctrld log won't be flushed with un-necessary write errors. // ctrld log won't be flushed with un-necessary write errors.
// This prevents log pollution when the parent process closes the connection
type logConn struct { type logConn struct {
conn net.Conn conn net.Conn
} }
// Read delegates to the underlying connection
// This maintains normal read behavior for the wrapped connection
func (lc *logConn) Read(b []byte) (n int, err error) { func (lc *logConn) Read(b []byte) (n int, err error) {
return lc.conn.Read(b) return lc.conn.Read(b)
} }
// Close delegates to the underlying connection
// This ensures proper cleanup of the wrapped connection
func (lc *logConn) Close() error { func (lc *logConn) Close() error {
return lc.conn.Close() return lc.conn.Close()
} }
// LocalAddr delegates to the underlying connection
// This provides access to local address information
func (lc *logConn) LocalAddr() net.Addr { func (lc *logConn) LocalAddr() net.Addr {
return lc.conn.LocalAddr() return lc.conn.LocalAddr()
} }
// RemoteAddr delegates to the underlying connection
// This provides access to remote address information
func (lc *logConn) RemoteAddr() net.Addr { func (lc *logConn) RemoteAddr() net.Addr {
return lc.conn.RemoteAddr() return lc.conn.RemoteAddr()
} }
// SetDeadline delegates to the underlying connection
// This maintains timeout functionality for the wrapped connection
func (lc *logConn) SetDeadline(t time.Time) error { func (lc *logConn) SetDeadline(t time.Time) error {
return lc.conn.SetDeadline(t) return lc.conn.SetDeadline(t)
} }
// SetReadDeadline delegates to the underlying connection
// This maintains read timeout functionality for the wrapped connection
func (lc *logConn) SetReadDeadline(t time.Time) error { func (lc *logConn) SetReadDeadline(t time.Time) error {
return lc.conn.SetReadDeadline(t) return lc.conn.SetReadDeadline(t)
} }
// SetWriteDeadline delegates to the underlying connection
// This maintains write timeout functionality for the wrapped connection
func (lc *logConn) SetWriteDeadline(t time.Time) error { func (lc *logConn) SetWriteDeadline(t time.Time) error {
return lc.conn.SetWriteDeadline(t) return lc.conn.SetWriteDeadline(t)
} }
// Write performs writes with underlying net.Conn, ignore any errors happen.
// "ctrld run" command use this wrapper to report errors to "ctrld start".
// If no error occurred, "ctrld start" may finish before "ctrld run" attempt
// to close the connection, so ignore errors conservatively here, prevent
// un-necessary error "write to closed connection" flushed to ctrld log.
// This prevents log pollution when the parent process closes the connection prematurely
func (lc *logConn) Write(b []byte) (int, error) { func (lc *logConn) Write(b []byte) (int, error) {
// Write performs writes with underlying net.Conn, ignore any errors happen.
// "ctrld run" command use this wrapper to report errors to "ctrld start".
// If no error occurred, "ctrld start" may finish before "ctrld run" attempt
// to close the connection, so ignore errors conservatively here, prevent
// un-necessary error "write to closed connection" flushed to ctrld log.
_, _ = lc.conn.Write(b) _, _ = lc.conn.Write(b)
return len(b), nil return len(b), nil
} }

View File

@@ -8,10 +8,12 @@ import (
"time" "time"
) )
// controlClient represents an HTTP client for communicating with the control server
type controlClient struct { type controlClient struct {
c *http.Client c *http.Client
} }
// newControlClient creates a new control client with Unix socket transport
func newControlClient(addr string) *controlClient { func newControlClient(addr string) *controlClient {
return &controlClient{c: &http.Client{ return &controlClient{c: &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{

View File

@@ -37,12 +37,14 @@ type ifaceResponse struct {
OK bool `json:"ok"` OK bool `json:"ok"`
} }
// controlServer represents an HTTP server for handling control requests
type controlServer struct { type controlServer struct {
server *http.Server server *http.Server
mux *http.ServeMux mux *http.ServeMux
addr string addr string
} }
// newControlServer creates a new control server instance
func newControlServer(addr string) (*controlServer, error) { func newControlServer(addr string) (*controlServer, error) {
mux := http.NewServeMux() mux := http.NewServeMux()
s := &controlServer{ s := &controlServer{
@@ -338,6 +340,7 @@ func (p *prog) registerControlServerHandler() {
})) }))
} }
// jsonResponse wraps an HTTP handler to set JSON content type
func jsonResponse(next http.Handler) http.Handler { func jsonResponse(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")

View File

@@ -27,24 +27,37 @@ import (
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net" ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
) )
// DNS proxy constants for configuration and behavior control
const ( const (
// staleTTL is the TTL for stale cache entries
// This allows serving cached responses even when upstreams are temporarily unavailable
staleTTL = 60 * time.Second staleTTL = 60 * time.Second
// localTTL is the TTL for local network responses
// Longer TTL for local queries reduces unnecessary repeated lookups
localTTL = 3600 * time.Second localTTL = 3600 * time.Second
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option. // EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81 // https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification. // This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
// This enables MAC address-based client identification for policy routing
EDNS0_OPTION_MAC = 0xFDE9 EDNS0_OPTION_MAC = 0xFDE9
// selfUninstallMaxQueries is number of REFUSED queries seen before checking for self-uninstallation. // selfUninstallMaxQueries is number of REFUSED queries seen before checking for self-uninstallation.
// This prevents premature self-uninstallation due to temporary network issues
selfUninstallMaxQueries = 32 selfUninstallMaxQueries = 32
) )
// osUpstreamConfig defines the default OS resolver configuration
// This is used as a fallback when all configured upstreams fail
var osUpstreamConfig = &ctrld.UpstreamConfig{ var osUpstreamConfig = &ctrld.UpstreamConfig{
Name: "OS resolver", Name: "OS resolver",
Type: ctrld.ResolverTypeOS, Type: ctrld.ResolverTypeOS,
Timeout: 3000, Timeout: 3000,
} }
// privateUpstreamConfig defines the default private resolver configuration
// This is used for internal network queries that should not go to public resolvers
var privateUpstreamConfig = &ctrld.UpstreamConfig{ var privateUpstreamConfig = &ctrld.UpstreamConfig{
Name: "Private resolver", Name: "Private resolver",
Type: ctrld.ResolverTypePrivate, Type: ctrld.ResolverTypePrivate,
@@ -52,6 +65,7 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{
} }
// proxyRequest contains data for proxying a DNS query to upstream. // proxyRequest contains data for proxying a DNS query to upstream.
// This structure encapsulates all the information needed to process a DNS request
type proxyRequest struct { type proxyRequest struct {
msg *dns.Msg msg *dns.Msg
ci *ctrld.ClientInfo ci *ctrld.ClientInfo
@@ -63,6 +77,7 @@ type proxyRequest struct {
} }
// proxyResponse contains data for proxying a DNS response from upstream. // proxyResponse contains data for proxying a DNS response from upstream.
// This structure encapsulates the response and metadata for logging and metrics
type proxyResponse struct { type proxyResponse struct {
answer *dns.Msg answer *dns.Msg
upstream string upstream string
@@ -72,6 +87,7 @@ type proxyResponse struct {
} }
// upstreamForResult represents the result of processing rules for a request. // upstreamForResult represents the result of processing rules for a request.
// This contains the matched policy information for logging and debugging
type upstreamForResult struct { type upstreamForResult struct {
upstreams []string upstreams []string
matchedPolicy string matchedPolicy string
@@ -81,7 +97,9 @@ type upstreamForResult struct {
srcAddr string srcAddr string
} }
func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error { // serveDNS sets up and starts a DNS server on the specified listener, handling DNS queries and network monitoring.
// This is the main entry point for DNS server functionality
func (p *prog) serveDNS(ctx context.Context, listenerNum string) error {
listenerConfig := p.cfg.Listener[listenerNum] listenerConfig := p.cfg.Listener[listenerNum]
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip") p.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
@@ -92,11 +110,12 @@ func (p *prog) serveDNS(mainCtx context.Context, listenerNum string) error {
p.handleDNSQuery(w, m, listenerNum, listenerConfig) p.handleDNSQuery(w, m, listenerNum, listenerConfig)
}) })
return p.startListeners(mainCtx, listenerConfig, handler) return p.startListeners(ctx, listenerConfig, handler)
} }
// startListeners starts DNS listeners on specified configurations, supporting UDP and TCP protocols. // startListeners starts DNS listeners on specified configurations, supporting UDP and TCP protocols.
// It handles local IPv6, RFC 1918, and specified IP listeners, reacting to stop signals or errors. // It handles local IPv6, RFC 1918, and specified IP listeners, reacting to stop signals or errors.
// This function manages the lifecycle of DNS server listeners
func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, handler dns.Handler) error { func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, handler dns.Handler) error {
g, gctx := errgroup.WithContext(ctx) g, gctx := errgroup.WithContext(ctx)
@@ -153,6 +172,7 @@ func (p *prog) startListeners(ctx context.Context, cfg *ctrld.ListenerConfig, ha
} }
// handleDNSQuery processes incoming DNS queries, validates client access, and routes the query to appropriate handlers. // handleDNSQuery processes incoming DNS queries, validates client access, and routes the query to appropriate handlers.
// This is the main entry point for all DNS query processing
func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum string, listenerConfig *ctrld.ListenerConfig) { func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum string, listenerConfig *ctrld.ListenerConfig) {
p.sema.acquire() p.sema.acquire()
defer p.sema.release() defer p.sema.release()
@@ -191,6 +211,7 @@ func (p *prog) handleDNSQuery(w dns.ResponseWriter, m *dns.Msg, listenerNum stri
} }
// handleSpecialDomains processes special domain queries, handles errors, purges cache if necessary, and returns a bool status. // handleSpecialDomains processes special domain queries, handles errors, purges cache if necessary, and returns a bool status.
// This handles internal test domains and cache management commands
func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m *dns.Msg, domain string) bool { func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m *dns.Msg, domain string) bool {
switch { switch {
case domain == "": case domain == "":
@@ -211,6 +232,7 @@ func (p *prog) handleSpecialDomains(ctx context.Context, w dns.ResponseWriter, m
} }
// standardQueryRequest represents a standard DNS query request with associated context and configuration. // standardQueryRequest represents a standard DNS query request with associated context and configuration.
// This encapsulates all the data needed to process a standard DNS query
type standardQueryRequest struct { type standardQueryRequest struct {
ctx context.Context ctx context.Context
writer dns.ResponseWriter writer dns.ResponseWriter
@@ -221,6 +243,7 @@ type standardQueryRequest struct {
} }
// processStandardQuery handles a standard DNS query by routing it through appropriate upstreams and writing a DNS response. // processStandardQuery handles a standard DNS query by routing it through appropriate upstreams and writing a DNS response.
// This is the main processing pipeline for normal DNS queries
func (p *prog) processStandardQuery(req *standardQueryRequest) { func (p *prog) processStandardQuery(req *standardQueryRequest) {
remoteIP, _, _ := net.SplitHostPort(req.writer.RemoteAddr().String()) remoteIP, _, _ := net.SplitHostPort(req.writer.RemoteAddr().String())
ci := p.getClientInfo(remoteIP, req.msg) ci := p.getClientInfo(remoteIP, req.msg)

View File

@@ -4,11 +4,15 @@ import "regexp"
// validHostname reports whether hostname is a valid hostname. // validHostname reports whether hostname is a valid hostname.
// A valid hostname contains 3 -> 64 characters and conform to RFC1123. // A valid hostname contains 3 -> 64 characters and conform to RFC1123.
// This function validates hostnames to ensure they meet DNS naming standards
// and prevents invalid hostnames from being used in DNS configurations
func validHostname(hostname string) bool { func validHostname(hostname string) bool {
hostnameLen := len(hostname) hostnameLen := len(hostname)
if hostnameLen < 3 || hostnameLen > 64 { if hostnameLen < 3 || hostnameLen > 64 {
return false return false
} }
// RFC1123 regex pattern ensures hostnames follow DNS naming conventions
// This prevents issues with DNS resolution and system compatibility
validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`) validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
return validHostnameRfc1123.MatchString(hostname) return validHostnameRfc1123.MatchString(hostname)
} }

View File

@@ -9,6 +9,7 @@ import (
// AppCallback provides hooks for injecting certain functionalities // AppCallback provides hooks for injecting certain functionalities
// from mobile platforms to main ctrld cli. // from mobile platforms to main ctrld cli.
// This allows mobile applications to customize behavior without modifying core CLI code
type AppCallback struct { type AppCallback struct {
HostName func() string HostName func() string
LanIp func() string LanIp func() string
@@ -17,6 +18,7 @@ type AppCallback struct {
} }
// AppConfig allows overwriting ctrld cli flags from mobile platforms. // AppConfig allows overwriting ctrld cli flags from mobile platforms.
// This provides a clean interface for mobile apps to configure ctrld behavior
type AppConfig struct { type AppConfig struct {
CdUID string CdUID string
HomeDir string HomeDir string
@@ -25,18 +27,29 @@ type AppConfig struct {
LogPath string LogPath string
} }
// Network and HTTP configuration constants
const ( const (
// defaultHTTPTimeout provides reasonable timeout for HTTP operations
// This prevents hanging requests while allowing sufficient time for network delays
defaultHTTPTimeout = 30 * time.Second defaultHTTPTimeout = 30 * time.Second
defaultMaxRetries = 3
downloadServerIp = "23.171.240.151" // defaultMaxRetries provides retry attempts for failed HTTP requests
// This improves reliability in unstable network conditions
defaultMaxRetries = 3
// downloadServerIp is the fallback IP for download operations
// This ensures downloads work even when DNS resolution fails
downloadServerIp = "23.171.240.151"
) )
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback // httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
// This ensures reliable HTTP operations by preferring IPv4 and handling timeouts gracefully
func httpClientWithFallback(timeout time.Duration) *http.Client { func httpClientWithFallback(timeout time.Duration) *http.Client {
return &http.Client{ return &http.Client{
Timeout: timeout, Timeout: timeout,
Transport: &http.Transport{ Transport: &http.Transport{
// Prefer IPv4 over IPv6 // Prefer IPv4 over IPv6
// This improves compatibility with networks that have IPv6 issues
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
@@ -47,6 +60,7 @@ func httpClientWithFallback(timeout time.Duration) *http.Client {
} }
// doWithRetry performs an HTTP request with retries // doWithRetry performs an HTTP request with retries
// This improves reliability by automatically retrying failed requests with exponential backoff
func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) { func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) {
var lastErr error var lastErr error
client := httpClientWithFallback(defaultHTTPTimeout) client := httpClientWithFallback(defaultHTTPTimeout)
@@ -58,7 +72,8 @@ func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response,
} }
for attempt := 0; attempt < maxRetries; attempt++ { for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 { if attempt > 0 {
time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff // Linear backoff reduces server load and improves success rate
time.Sleep(time.Second * time.Duration(attempt+1))
} }
resp, err := client.Do(req) resp, err := client.Do(req)
@@ -84,6 +99,7 @@ func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response,
} }
// Helper for making GET requests with retries // Helper for making GET requests with retries
// This provides a simplified interface for common GET operations with built-in retry logic
func getWithRetry(url string, ip string) (*http.Response, error) { func getWithRetry(url string, ip string) (*http.Response, error) {
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil { if err != nil {

View File

@@ -16,12 +16,30 @@ import (
"github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld"
) )
// Log writer constants for buffer management and log formatting
const ( const (
// logWriterSize is the default buffer size for log writers
// This provides sufficient space for runtime logs without excessive memory usage
logWriterSize = 1024 * 1024 * 5 // 5 MB logWriterSize = 1024 * 1024 * 5 // 5 MB
// logWriterSmallSize is used for memory-constrained environments
// This reduces memory footprint while still maintaining log functionality
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
// logWriterInitialSize is the initial buffer allocation
// This provides immediate space for early log entries
logWriterInitialSize = 32 * 1024 // 32 KB logWriterInitialSize = 32 * 1024 // 32 KB
// logWriterSentInterval controls how often logs are sent to external systems
// This balances real-time logging with system performance
logWriterSentInterval = time.Minute logWriterSentInterval = time.Minute
// logWriterInitEndMarker marks the end of initialization logs
// This helps separate startup logs from runtime logs
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n" logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
// logWriterLogEndMarker marks the end of log sections
// This provides clear boundaries for log parsing and analysis
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n" logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
) )
@@ -31,6 +49,8 @@ const (
// Note: WARN messages will also display as "NOTICE" because they share the same level value. // Note: WARN messages will also display as "NOTICE" because they share the same level value.
// This is the intended behavior for visual distinction. // This is the intended behavior for visual distinction.
// noticeLevelEncoder provides custom level encoding for NOTICE level
// This ensures NOTICE messages are clearly distinguished from other log levels
func noticeLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { func noticeLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
switch l { switch l {
case ctrld.NoticeLevel: case ctrld.NoticeLevel:
@@ -40,6 +60,8 @@ func noticeLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
} }
} }
// noticeColorLevelEncoder provides colored level encoding for NOTICE level
// This uses cyan color to make NOTICE messages visually distinct in terminal output
func noticeColorLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) { func noticeColorLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
switch l { switch l {
case ctrld.NoticeLevel: case ctrld.NoticeLevel:
@@ -49,21 +71,28 @@ func noticeColorLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder)
} }
} }
// logViewResponse represents the response structure for log viewing requests
// This provides a consistent JSON format for log data retrieval
type logViewResponse struct { type logViewResponse struct {
Data string `json:"data"` Data string `json:"data"`
} }
// logSentResponse represents the response structure for log sending operations
// This includes size information and error details for debugging
type logSentResponse struct { type logSentResponse struct {
Size int64 `json:"size"` Size int64 `json:"size"`
Error string `json:"error"` Error string `json:"error"`
} }
// logReader provides read access to log data with size information
// This encapsulates the log reading functionality for external consumers
type logReader struct { type logReader struct {
r io.ReadCloser r io.ReadCloser
size int64 size int64
} }
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled. // logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
// This provides in-memory log storage for debugging and monitoring purposes
type logWriter struct { type logWriter struct {
mu sync.Mutex mu sync.Mutex
buf bytes.Buffer buf bytes.Buffer
@@ -71,30 +100,37 @@ type logWriter struct {
} }
// newLogWriter creates an internal log writer. // newLogWriter creates an internal log writer.
// This provides the default log writer with standard buffer size
func newLogWriter() *logWriter { func newLogWriter() *logWriter {
return newLogWriterWithSize(logWriterSize) return newLogWriterWithSize(logWriterSize)
} }
// newSmallLogWriter creates an internal log writer with small buffer size. // newSmallLogWriter creates an internal log writer with small buffer size.
// This is used in memory-constrained environments or for temporary logging
func newSmallLogWriter() *logWriter { func newSmallLogWriter() *logWriter {
return newLogWriterWithSize(logWriterSmallSize) return newLogWriterWithSize(logWriterSmallSize)
} }
// newLogWriterWithSize creates an internal log writer with a given buffer size. // newLogWriterWithSize creates an internal log writer with a given buffer size.
// This allows customization of log buffer size based on specific requirements
func newLogWriterWithSize(size int) *logWriter { func newLogWriterWithSize(size int) *logWriter {
lw := &logWriter{size: size} lw := &logWriter{size: size}
return lw return lw
} }
// Write implements io.Writer interface for logWriter
// This manages buffer overflow by discarding old data while preserving important markers
func (lw *logWriter) Write(p []byte) (int, error) { func (lw *logWriter) Write(p []byte) (int, error) {
lw.mu.Lock() lw.mu.Lock()
defer lw.mu.Unlock() defer lw.mu.Unlock()
// If writing p causes overflows, discard old data. // If writing p causes overflows, discard old data.
// This prevents unbounded memory growth while maintaining recent logs
if lw.buf.Len()+len(p) > lw.size { if lw.buf.Len()+len(p) > lw.size {
buf := lw.buf.Bytes() buf := lw.buf.Bytes()
haveEndMarker := false haveEndMarker := false
// If there's init end marker already, preserve the data til the marker. // If there's init end marker already, preserve the data til the marker.
// This ensures initialization logs are always available for debugging
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 { if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
buf = buf[:idx+len(logWriterInitEndMarker)] buf = buf[:idx+len(logWriterInitEndMarker)]
haveEndMarker = true haveEndMarker = true

View File

@@ -138,7 +138,7 @@ func (p *prog) checkDnsLoopTicker(ctx context.Context) {
} }
} }
// loopTestMsg generates DNS message for checking loop. // loopTestMsg creates a DNS test message for loop detection
func loopTestMsg(uid string) *dns.Msg { func loopTestMsg(uid string) *dns.Msg {
msg := new(dns.Msg) msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype) msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)

View File

@@ -13,6 +13,8 @@ import (
"github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld"
) )
// Global variables for CLI configuration and state management
// These are used across multiple commands and need to persist throughout the application lifecycle
var ( var (
configPath string configPath string
configBase64 string configBase64 string
@@ -46,6 +48,8 @@ var (
noConfigStart bool noConfigStart bool
) )
// Flag name constants for consistent reference across the codebase
// Using constants prevents typos and makes refactoring easier
const ( const (
cdUidFlagName = "cd" cdUidFlagName = "cd"
cdOrgFlagName = "cd-org" cdOrgFlagName = "cd-org"
@@ -53,11 +57,15 @@ const (
nextdnsFlagName = "nextdns" nextdnsFlagName = "nextdns"
) )
// init initializes the default logger before any CLI commands are executed
// This ensures logging is available even during early initialization phases
func init() { func init() {
l := zap.NewNop() l := zap.NewNop()
mainLog.Store(&ctrld.Logger{Logger: l}) mainLog.Store(&ctrld.Logger{Logger: l})
} }
// Main is the entry point for the CLI application
// It initializes configuration, sets up the CLI structure, and executes the root command
func Main() { func Main() {
ctrld.InitConfig(v, "ctrld") ctrld.InitConfig(v, "ctrld")
rootCmd := initCLI() rootCmd := initCLI()
@@ -67,6 +75,8 @@ func Main() {
} }
} }
// normalizeLogFilePath converts relative log file paths to absolute paths
// This ensures log files are created in predictable locations regardless of working directory
func normalizeLogFilePath(logFilePath string) string { func normalizeLogFilePath(logFilePath string) string {
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() { if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
return logFilePath return logFilePath
@@ -82,18 +92,19 @@ func normalizeLogFilePath(logFilePath string) string {
} }
// initConsoleLogging initializes console logging, then storing to mainLog. // initConsoleLogging initializes console logging, then storing to mainLog.
// This sets up human-readable logging output for interactive use
func initConsoleLogging() { func initConsoleLogging() {
consoleWriterLevel = ctrld.NoticeLevel consoleWriterLevel = ctrld.NoticeLevel
switch { switch {
case silent: case silent:
// For silent mode, use a no-op logger // For silent mode, use a no-op logger to suppress all output
l := zap.NewNop() l := zap.NewNop()
mainLog.Store(&ctrld.Logger{Logger: l}) mainLog.Store(&ctrld.Logger{Logger: l})
case verbose == 1: case verbose == 1:
// Info level // Info level provides basic operational information
consoleWriterLevel = zapcore.InfoLevel consoleWriterLevel = zapcore.InfoLevel
case verbose > 1: case verbose > 1:
// Debug level // Debug level provides detailed diagnostic information
consoleWriterLevel = zapcore.DebugLevel consoleWriterLevel = zapcore.DebugLevel
} }
consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel) consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel)
@@ -105,6 +116,7 @@ func initConsoleLogging() {
// to be used for all interactive commands. // to be used for all interactive commands.
// //
// Current log file config will also be ignored. // Current log file config will also be ignored.
// This prevents log file conflicts during interactive command execution
func initInteractiveLogging() { func initInteractiveLogging() {
old := cfg.Service.LogPath old := cfg.Service.LogPath
cfg.Service.LogPath = "" cfg.Service.LogPath = ""
@@ -122,19 +134,23 @@ func initLoggingWithBackup(doBackup bool) []zapcore.Core {
var writers []io.Writer var writers []io.Writer
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" { if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
// Create parent directory if necessary. // Create parent directory if necessary.
// This ensures log files can be created even if the directory doesn't exist
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil { if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
mainLog.Load().Error().Msgf("failed to create log path: %v", err) mainLog.Load().Error().Msgf("failed to create log path: %v", err)
os.Exit(1) os.Exit(1)
} }
// Default open log file in append mode. // Default open log file in append mode.
// This preserves existing log entries across restarts
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
if doBackup { if doBackup {
// Backup old log file with .1 suffix. // Backup old log file with .1 suffix.
// This prevents log file corruption during rotation
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) { if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
mainLog.Load().Error().Msgf("could not backup old log file: %v", err) mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
} else { } else {
// Backup was created, set flags for truncating old log file. // Backup was created, set flags for truncating old log file.
// This ensures a clean start for the new log file
flags = os.O_CREATE | os.O_RDWR flags = os.O_CREATE | os.O_RDWR
} }
} }
@@ -147,14 +163,16 @@ func initLoggingWithBackup(doBackup bool) []zapcore.Core {
} }
// Create zap cores for different writers // Create zap cores for different writers
// Multiple cores allow logging to both console and file simultaneously
var cores []zapcore.Core var cores []zapcore.Core
cores = append(cores, consoleWriter) cores = append(cores, consoleWriter)
// Determine log level // Determine log level based on verbosity and configuration
// This provides flexible logging control for different use cases
logLevel := cfg.Service.LogLevel logLevel := cfg.Service.LogLevel
switch { switch {
case silent: case silent:
// For silent mode, use a no-op logger // For silent mode, use a no-op logger to suppress all output
l := zap.NewNop() l := zap.NewNop()
mainLog.Store(&ctrld.Logger{Logger: l}) mainLog.Store(&ctrld.Logger{Logger: l})
return cores return cores
@@ -164,7 +182,8 @@ func initLoggingWithBackup(doBackup bool) []zapcore.Core {
logLevel = "debug" logLevel = "debug"
} }
// Parse log level // Parse log level string to zapcore.Level
// This provides human-readable log level configuration
var level zapcore.Level var level zapcore.Level
switch logLevel { switch logLevel {
case "debug": case "debug":
@@ -183,12 +202,14 @@ func initLoggingWithBackup(doBackup bool) []zapcore.Core {
consoleWriter.Enabled(level) consoleWriter.Enabled(level)
// Add cores for all writers // Add cores for all writers
// This enables multi-destination logging (console + file)
for _, writer := range writers { for _, writer := range writers {
core := newMachineFriendlyZapCore(writer, level) core := newMachineFriendlyZapCore(writer, level)
cores = append(cores, core) cores = append(cores, core)
} }
// Create a multi-core logger // Create a multi-core logger
// This allows simultaneous logging to multiple destinations
multiCore := zapcore.NewTee(cores...) multiCore := zapcore.NewTee(cores...)
logger := zap.New(multiCore) logger := zap.New(multiCore)
mainLog.Store(&ctrld.Logger{Logger: logger}) mainLog.Store(&ctrld.Logger{Logger: logger})
@@ -196,11 +217,14 @@ func initLoggingWithBackup(doBackup bool) []zapcore.Core {
return cores return cores
} }
// initCache initializes DNS cache configuration
// This improves performance by caching frequently requested DNS responses
func initCache() { func initCache() {
if !cfg.Service.CacheEnable { if !cfg.Service.CacheEnable {
return return
} }
if cfg.Service.CacheSize == 0 { if cfg.Service.CacheSize == 0 {
// Default cache size provides good balance between memory usage and performance
cfg.Service.CacheSize = 4096 cfg.Service.CacheSize = 4096
} }
} }

View File

@@ -15,6 +15,7 @@ import (
) )
// metricsServer represents a server to expose Prometheus metrics via HTTP. // metricsServer represents a server to expose Prometheus metrics via HTTP.
// This provides monitoring and observability for the DNS proxy service
type metricsServer struct { type metricsServer struct {
server *http.Server server *http.Server
mux *http.ServeMux mux *http.ServeMux
@@ -24,6 +25,7 @@ type metricsServer struct {
} }
// newMetricsServer returns new metrics server. // newMetricsServer returns new metrics server.
// This initializes the HTTP server for exposing Prometheus metrics
func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) { func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) {
mux := http.NewServeMux() mux := http.NewServeMux()
ms := &metricsServer{ ms := &metricsServer{
@@ -37,11 +39,13 @@ func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, er
} }
// register adds handlers for given pattern. // register adds handlers for given pattern.
// This provides a clean interface for adding HTTP endpoints to the metrics server
func (ms *metricsServer) register(pattern string, handler http.Handler) { func (ms *metricsServer) register(pattern string, handler http.Handler) {
ms.mux.Handle(pattern, handler) ms.mux.Handle(pattern, handler)
} }
// registerMetricsServerHandler adds handlers for metrics server. // registerMetricsServerHandler adds handlers for metrics server.
// This sets up both Prometheus format and JSON format endpoints for metrics
func (ms *metricsServer) registerMetricsServerHandler() { func (ms *metricsServer) registerMetricsServerHandler() {
ms.register("/metrics", promhttp.HandlerFor( ms.register("/metrics", promhttp.HandlerFor(
ms.reg, ms.reg,
@@ -74,6 +78,7 @@ func (ms *metricsServer) registerMetricsServerHandler() {
} }
// start runs the metricsServer. // start runs the metricsServer.
// This starts the HTTP server for metrics exposure
func (ms *metricsServer) start() error { func (ms *metricsServer) start() error {
listener, err := net.Listen("tcp", ms.addr) listener, err := net.Listen("tcp", ms.addr)
if err != nil { if err != nil {
@@ -85,6 +90,7 @@ func (ms *metricsServer) start() error {
} }
// stop shutdowns the metricsServer within 2 seconds timeout. // stop shutdowns the metricsServer within 2 seconds timeout.
// This ensures graceful shutdown of the metrics server
func (ms *metricsServer) stop() error { func (ms *metricsServer) stop() error {
if !ms.started { if !ms.started {
return nil return nil
@@ -95,6 +101,7 @@ func (ms *metricsServer) stop() error {
} }
// runMetricsServer initializes metrics stats and runs the metrics server if enabled. // runMetricsServer initializes metrics stats and runs the metrics server if enabled.
// This sets up the complete metrics infrastructure including Prometheus collectors
func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) { func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
if !p.metricsEnabled() { if !p.metricsEnabled() {
return return

View File

@@ -12,16 +12,20 @@ import (
"github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld"
) )
// patchNetIfaceName patches network interface names on Linux
// This is a no-op on Linux as interface names don't need special handling
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
// validInterface reports whether the *net.Interface is a valid one. // validInterface reports whether the *net.Interface is a valid one.
// Only non-virtual interfaces are considered valid. // Only non-virtual interfaces are considered valid.
// This prevents DNS configuration on virtual interfaces like docker, veth, etc.
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
_, ok := validIfacesMap[iface.Name] _, ok := validIfacesMap[iface.Name]
return ok return ok
} }
// validInterfacesMap returns a set containing non virtual interfaces. // validInterfacesMap returns a set containing non virtual interfaces.
// This filters out virtual interfaces to ensure DNS is only configured on physical interfaces
func validInterfacesMap(ctx context.Context) map[string]struct{} { func validInterfacesMap(ctx context.Context) map[string]struct{} {
m := make(map[string]struct{}) m := make(map[string]struct{})
vis := virtualInterfaces(ctx) vis := virtualInterfaces(ctx)
@@ -32,6 +36,7 @@ func validInterfacesMap(ctx context.Context) map[string]struct{} {
m[i.Name] = struct{}{} m[i.Name] = struct{}{}
}) })
// Fallback to the default route interface if found nothing. // Fallback to the default route interface if found nothing.
// This ensures we always have at least one interface to configure
if len(m) == 0 { if len(m) == 0 {
defaultRoute, err := netmon.DefaultRoute() defaultRoute, err := netmon.DefaultRoute()
if err != nil { if err != nil {
@@ -43,6 +48,8 @@ func validInterfacesMap(ctx context.Context) map[string]struct{} {
} }
// virtualInterfaces returns a map of virtual interfaces on the current machine. // virtualInterfaces returns a map of virtual interfaces on the current machine.
// This reads from /sys/devices/virtual/net to identify virtual network interfaces
// Virtual interfaces should not have DNS configured as they don't represent physical network connections
func virtualInterfaces(ctx context.Context) map[string]struct{} { func virtualInterfaces(ctx context.Context) map[string]struct{} {
logger := ctrld.LoggerFromCtx(ctx) logger := ctrld.LoggerFromCtx(ctx)
s := make(map[string]struct{}) s := make(map[string]struct{})

View File

@@ -9,8 +9,10 @@ import (
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
) )
// patchNetIfaceName patches network interface names on non-Linux/Darwin platforms
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil } func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
// validInterface checks if an interface is valid on non-Linux/Darwin platforms
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true } func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
// validInterfacesMap returns a set containing only default route interfaces. // validInterfacesMap returns a set containing only default route interfaces.

View File

@@ -23,6 +23,7 @@ systemd-resolved=false
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename) var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
// hasNetworkManager reports whether NetworkManager executable found. // hasNetworkManager reports whether NetworkManager executable found.
// hasNetworkManager checks if NetworkManager is available on the system
func hasNetworkManager() bool { func hasNetworkManager() bool {
exe, _ := exec.LookPath("NetworkManager") exe, _ := exec.LookPath("NetworkManager")
return exe != "" return exe != ""

View File

@@ -8,6 +8,7 @@ import (
const nextdnsURL = "https://dns.nextdns.io" const nextdnsURL = "https://dns.nextdns.io"
// generateNextDNSConfig generates NextDNS configuration for the given UID
func generateNextDNSConfig(uid string) { func generateNextDNSConfig(uid string) {
if uid == "" { if uid == "" {
return return

View File

@@ -11,7 +11,7 @@ import (
"github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld"
) )
// allocate loopback ip // allocateIP allocates an IP address on the specified interface
// sudo ifconfig lo0 alias 127.0.0.2 up // sudo ifconfig lo0 alias 127.0.0.2 up
func allocateIP(ip string) error { func allocateIP(ip string) error {
cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up") cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up")
@@ -22,6 +22,7 @@ func allocateIP(ip string) error {
return nil return nil
} }
// deAllocateIP deallocates an IP address from the specified interface
func deAllocateIP(ip string) error { func deAllocateIP(ip string) error {
cmd := exec.Command("ifconfig", "lo0", "-alias", ip) cmd := exec.Command("ifconfig", "lo0", "-alias", ip)
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
@@ -90,6 +91,7 @@ func restoreDNS(iface *net.Interface) (err error) {
return err return err
} }
// currentDNS returns the current DNS servers for the specified interface
func currentDNS(_ *net.Interface) []string { func currentDNS(_ *net.Interface) []string {
return ctrld.CurrentNameserversFromResolvconf() return ctrld.CurrentNameserversFromResolvconf()
} }

View File

@@ -13,7 +13,7 @@ import (
"github.com/Control-D-Inc/ctrld/internal/dns" "github.com/Control-D-Inc/ctrld/internal/dns"
) )
// allocate loopback ip // allocateIP allocates an IP address on the specified interface
// sudo ifconfig lo0 127.0.0.53 alias // sudo ifconfig lo0 127.0.0.53 alias
func allocateIP(ip string) error { func allocateIP(ip string) error {
cmd := exec.Command("ifconfig", "lo0", ip, "alias") cmd := exec.Command("ifconfig", "lo0", ip, "alias")
@@ -24,6 +24,7 @@ func allocateIP(ip string) error {
return nil return nil
} }
// deAllocateIP deallocates an IP address from the specified interface
func deAllocateIP(ip string) error { func deAllocateIP(ip string) error {
cmd := exec.Command("ifconfig", "lo0", ip, "-alias") cmd := exec.Command("ifconfig", "lo0", ip, "-alias")
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
@@ -73,6 +74,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
return resetDNS(iface) return resetDNS(iface)
} }
// resetDNS resets DNS servers for the specified interface
func resetDNS(iface *net.Interface) error { func resetDNS(iface *net.Interface) error {
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name) r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
if err != nil { if err != nil {
@@ -93,6 +95,7 @@ func restoreDNS(iface *net.Interface) (err error) {
return err return err
} }
// currentDNS returns the current DNS servers for the specified interface
func currentDNS(_ *net.Interface) []string { func currentDNS(_ *net.Interface) []string {
return ctrld.CurrentNameserversFromResolvconf() return ctrld.CurrentNameserversFromResolvconf()
} }

View File

@@ -2,12 +2,12 @@
package cli package cli
// TODO(cuonglm): implement. // allocateIP allocates an IP address on the specified interface
func allocateIP(ip string) error { func allocateIP(ip string) error {
return nil return nil
} }
// TODO(cuonglm): implement. // deAllocateIP deallocates an IP address from the specified interface
func deAllocateIP(ip string) error { func deAllocateIP(ip string) error {
return nil return nil
} }

View File

@@ -75,6 +75,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
return resetDNS(iface) return resetDNS(iface)
} }
// resetDNS resets DNS servers for the specified interface
func resetDNS(iface *net.Interface) error { func resetDNS(iface *net.Interface) error {
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil { if err != nil {
@@ -136,6 +137,7 @@ func restoreDNS(iface *net.Interface) (err error) {
return err return err
} }
// currentDNS returns the current DNS servers for the specified interface
func currentDNS(iface *net.Interface) []string { func currentDNS(iface *net.Interface) []string {
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
if err != nil { if err != nil {

View File

@@ -4,8 +4,10 @@ import (
"github.com/kardianos/service" "github.com/kardianos/service"
) )
// setDependencies sets service dependencies for Darwin
func setDependencies(svc *service.Config) {} func setDependencies(svc *service.Config) {}
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) { func setWorkingDirectory(svc *service.Config, dir string) {
svc.WorkingDirectory = dir svc.WorkingDirectory = dir
} }

View File

@@ -6,9 +6,11 @@ import (
"github.com/kardianos/service" "github.com/kardianos/service"
) )
// setDependencies sets service dependencies for FreeBSD
func setDependencies(svc *service.Config) { func setDependencies(svc *service.Config) {
// TODO(cuonglm): remove once https://github.com/kardianos/service/issues/359 fixed. // TODO(cuonglm): remove once https://github.com/kardianos/service/issues/359 fixed.
_ = os.MkdirAll("/usr/local/etc/rc.d", 0755) _ = os.MkdirAll("/usr/local/etc/rc.d", 0755)
} }
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) {} func setWorkingDirectory(svc *service.Config, dir string) {}

View File

@@ -21,6 +21,7 @@ func init() {
} }
} }
// setDependencies sets service dependencies for Linux
func setDependencies(svc *service.Config) { func setDependencies(svc *service.Config) {
svc.Dependencies = []string{ svc.Dependencies = []string{
"Wants=network-online.target", "Wants=network-online.target",
@@ -37,6 +38,7 @@ func setDependencies(svc *service.Config) {
} }
} }
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) { func setWorkingDirectory(svc *service.Config, dir string) {
svc.WorkingDirectory = dir svc.WorkingDirectory = dir
} }

View File

@@ -4,8 +4,10 @@ package cli
import "github.com/kardianos/service" import "github.com/kardianos/service"
// setDependencies sets service dependencies for other platforms
func setDependencies(svc *service.Config) {} func setDependencies(svc *service.Config) {}
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) { func setWorkingDirectory(svc *service.Config, dir string) {
// WorkingDirectory is not supported on Windows. // WorkingDirectory is not supported on Windows.
svc.WorkingDirectory = dir svc.WorkingDirectory = dir

View File

@@ -2,8 +2,10 @@ package cli
import "github.com/kardianos/service" import "github.com/kardianos/service"
// setDependencies sets service dependencies for Windows
func setDependencies(svc *service.Config) {} func setDependencies(svc *service.Config) {}
// setWorkingDirectory sets the working directory for the service
func setWorkingDirectory(svc *service.Config, dir string) { func setWorkingDirectory(svc *service.Config, dir string) {
// WorkingDirectory is not supported on Windows. // WorkingDirectory is not supported on Windows.
svc.WorkingDirectory = dir svc.WorkingDirectory = dir

View File

@@ -2,6 +2,8 @@ package cli
import "github.com/prometheus/client_golang/prometheus" import "github.com/prometheus/client_golang/prometheus"
// Prometheus metrics label constants for consistent labeling across all metrics
// These ensure standardized metric labeling for monitoring and alerting
const ( const (
metricsLabelListener = "listener" metricsLabelListener = "listener"
metricsLabelClientSourceIP = "client_source_ip" metricsLabelClientSourceIP = "client_source_ip"
@@ -13,17 +15,21 @@ const (
) )
// statsVersion represent ctrld version. // statsVersion represent ctrld version.
// This metric provides version information for monitoring and debugging
var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{ var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ctrld_build_info", Name: "ctrld_build_info",
Help: "Version of ctrld process.", Help: "Version of ctrld process.",
}, []string{"gitref", "goversion", "version"}) }, []string{"gitref", "goversion", "version"})
// statsTimeStart represents start time of ctrld service. // statsTimeStart represents start time of ctrld service.
// This metric tracks service uptime and helps with monitoring service restarts
var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{ var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "ctrld_time_seconds", Name: "ctrld_time_seconds",
Help: "Start time of the ctrld process since unix epoch in seconds.", Help: "Start time of the ctrld process since unix epoch in seconds.",
}) })
// statsQueriesCountLabels defines the labels for query count metrics
// These labels provide detailed breakdown of DNS query statistics
var statsQueriesCountLabels = []string{ var statsQueriesCountLabels = []string{
metricsLabelListener, metricsLabelListener,
metricsLabelClientSourceIP, metricsLabelClientSourceIP,
@@ -35,6 +41,7 @@ var statsQueriesCountLabels = []string{
} }
// statsQueriesCount counts total number of queries. // statsQueriesCount counts total number of queries.
// This provides comprehensive DNS query statistics for monitoring and alerting
var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{ var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ctrld_queries_count", Name: "ctrld_queries_count",
Help: "Total number of queries.", Help: "Total number of queries.",
@@ -44,12 +51,14 @@ var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
// //
// The labels "client_source_ip", "client_mac", "client_hostname" are unbounded, // The labels "client_source_ip", "client_mac", "client_hostname" are unbounded,
// thus this stat is highly inefficient if there are many devices. // thus this stat is highly inefficient if there are many devices.
// This metric should be used carefully in high-client environments
var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{ var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "ctrld_client_queries_count", Name: "ctrld_client_queries_count",
Help: "Total number queries of a client.", Help: "Total number queries of a client.",
}, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname}) }, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname})
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled. // WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
// This provides conditional metric collection to avoid performance impact when metrics are disabled
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) { func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
if p.metricsQueryStats.Load() { if p.metricsQueryStats.Load() {
c.WithLabelValues(lvs...).Inc() c.WithLabelValues(lvs...).Inc()

View File

@@ -8,10 +8,12 @@ import (
"syscall" "syscall"
) )
// notifyReloadSigCh sends reload signal to the channel
func notifyReloadSigCh(ch chan os.Signal) { func notifyReloadSigCh(ch chan os.Signal) {
signal.Notify(ch, syscall.SIGUSR1) signal.Notify(ch, syscall.SIGUSR1)
} }
// sendReloadSignal sends a reload signal to the current process
func (p *prog) sendReloadSignal() error { func (p *prog) sendReloadSignal() error {
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
} }

View File

@@ -6,8 +6,10 @@ import (
"time" "time"
) )
// notifyReloadSigCh is a no-op on Windows platforms
func notifyReloadSigCh(ch chan os.Signal) {} func notifyReloadSigCh(ch chan os.Signal) {}
// sendReloadSignal sends a reload signal to the program
func (p *prog) sendReloadSignal() error { func (p *prog) sendReloadSignal() error {
select { select {
case p.reloadCh <- struct{}{}: case p.reloadCh <- struct{}{}:

View File

@@ -13,15 +13,18 @@ import (
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found. // parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
// Returns nil if no nameservers are found. // Returns nil if no nameservers are found.
// This function parses the system DNS configuration to understand current nameserver settings
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) { func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
return resolvconffile.NameserversFromFile(path) return resolvconffile.NameserversFromFile(path)
} }
// watchResolvConf watches any changes to /etc/resolv.conf file, // watchResolvConf watches any changes to /etc/resolv.conf file,
// and reverting to the original config set by ctrld. // and reverting to the original config set by ctrld.
// This ensures that DNS settings are not overridden by other applications or system processes
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) { func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
resolvConfPath := "/etc/resolv.conf" resolvConfPath := "/etc/resolv.conf"
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to. // Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
// This handles systems where resolv.conf is a symlink to another location
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" { if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
resolvConfPath = rp resolvConfPath = rp
} }
@@ -35,6 +38,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
// We watch /etc instead of /etc/resolv.conf directly, // We watch /etc instead of /etc/resolv.conf directly,
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well // see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
// This is necessary because some systems don't properly notify on file changes
watchDir := filepath.Dir(resolvConfPath) watchDir := filepath.Dir(resolvConfPath)
if err := watcher.Add(watchDir); err != nil { if err := watcher.Add(watchDir); err != nil {
p.Warn().Err(err).Msgf("could not add %s to watcher list", watchDir) p.Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
@@ -62,6 +66,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...") p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
// Convert expected nameservers to strings for comparison // Convert expected nameservers to strings for comparison
// This allows us to detect when the resolv.conf has been modified
expectedNS := make([]string, len(ns)) expectedNS := make([]string, len(ns))
for i, addr := range ns { for i, addr := range ns {
expectedNS[i] = addr.String() expectedNS[i] = addr.String()
@@ -79,11 +84,13 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
} }
// If we found nameservers, break out of retry loop // If we found nameservers, break out of retry loop
// This handles cases where the file is being written but not yet complete
if len(foundNS) > 0 { if len(foundNS) > 0 {
break break
} }
// Only retry if we found no nameservers // Only retry if we found no nameservers
// This handles temporary file states during updates
if retry < maxRetries-1 { if retry < maxRetries-1 {
p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries) p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
select { select {

View File

@@ -4,4 +4,5 @@ package cli
var supportedSelfDelete = true var supportedSelfDelete = true
// selfDeleteExe performs self-deletion on non-Windows platforms
func selfDeleteExe() error { return nil } func selfDeleteExe() error { return nil }

View File

@@ -33,6 +33,7 @@ type FILE_DISPOSITION_INFO struct {
DeleteFile bool DeleteFile bool
} }
// dsOpenHandle opens a handle to the specified file with DELETE access
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) { func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
handle, err := windows.CreateFile( handle, err := windows.CreateFile(
pwPath, pwPath,
@@ -51,6 +52,7 @@ func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
return handle, nil return handle, nil
} }
// dsRenameHandle renames a file handle to a stream name
func dsRenameHandle(hHandle windows.Handle) error { func dsRenameHandle(hHandle windows.Handle) error {
var fRename FILE_RENAME_INFO var fRename FILE_RENAME_INFO
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef") DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
@@ -82,6 +84,7 @@ func dsRenameHandle(hHandle windows.Handle) error {
return nil return nil
} }
// dsDepositeHandle marks a file handle for deletion
func dsDepositeHandle(hHandle windows.Handle) error { func dsDepositeHandle(hHandle windows.Handle) error {
var fDelete FILE_DISPOSITION_INFO var fDelete FILE_DISPOSITION_INFO
fDelete.DeleteFile = true fDelete.DeleteFile = true
@@ -100,6 +103,7 @@ func dsDepositeHandle(hHandle windows.Handle) error {
return nil return nil
} }
// selfDeleteExe performs self-deletion on Windows platforms
func selfDeleteExe() error { func selfDeleteExe() error {
var wcPath [windows.MAX_PATH + 1]uint16 var wcPath [windows.MAX_PATH + 1]uint16
var hCurrent windows.Handle var hCurrent windows.Handle

View File

@@ -8,6 +8,7 @@ import (
"github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld"
) )
// selfUninstall performs self-uninstallation on non-Unix platforms
func selfUninstall(p *prog, logger *ctrld.Logger) { func selfUninstall(p *prog, logger *ctrld.Logger) {
if uninstallInvalidCdUID(p, logger, false) { if uninstallInvalidCdUID(p, logger, false) {
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)

View File

@@ -12,6 +12,7 @@ import (
"github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld"
) )
// selfUninstall performs self-uninstallation on Unix platforms
func selfUninstall(p *prog, logger *ctrld.Logger) { func selfUninstall(p *prog, logger *ctrld.Logger) {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
selfUninstallLinux(p, logger) selfUninstallLinux(p, logger)
@@ -37,6 +38,7 @@ func selfUninstall(p *prog, logger *ctrld.Logger) {
os.Exit(0) os.Exit(0)
} }
// selfUninstallLinux performs self-uninstallation on Linux platforms
func selfUninstallLinux(p *prog, logger *ctrld.Logger) { func selfUninstallLinux(p *prog, logger *ctrld.Logger) {
if uninstallInvalidCdUID(p, logger, true) { if uninstallInvalidCdUID(p, logger, true) {
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID) logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)

View File

@@ -1,24 +1,31 @@
package cli package cli
// semaphore provides a simple synchronization mechanism
type semaphore interface { type semaphore interface {
acquire() acquire()
release() release()
} }
// noopSemaphore is a no-operation implementation of semaphore
type noopSemaphore struct{} type noopSemaphore struct{}
// acquire performs a no-operation for the noop semaphore
func (n noopSemaphore) acquire() {} func (n noopSemaphore) acquire() {}
// release performs a no-operation for the noop semaphore
func (n noopSemaphore) release() {} func (n noopSemaphore) release() {}
// chanSemaphore is a channel-based implementation of semaphore
type chanSemaphore struct { type chanSemaphore struct {
ready chan struct{} ready chan struct{}
} }
// acquire blocks until a slot is available in the semaphore
func (c *chanSemaphore) acquire() { func (c *chanSemaphore) acquire() {
c.ready <- struct{}{} c.ready <- struct{}{}
} }
// release signals that a slot has been freed in the semaphore
func (c *chanSemaphore) release() { func (c *chanSemaphore) release() {
<-c.ready <-c.ready
} }

View File

@@ -149,6 +149,7 @@ func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
return opts, change return opts, change
} }
// newLaunchd creates a new launchd service wrapper
func newLaunchd(s service.Service) *launchd { func newLaunchd(s service.Service) *launchd {
return &launchd{ return &launchd{
Service: s, Service: s,
@@ -178,6 +179,7 @@ type task struct {
Name string Name string
} }
// doTasks executes a list of tasks and returns success status
func doTasks(tasks []task) bool { func doTasks(tasks []task) bool {
for _, task := range tasks { for _, task := range tasks {
mainLog.Load().Debug().Msgf("Running task %s", task.Name) mainLog.Load().Debug().Msgf("Running task %s", task.Name)
@@ -196,6 +198,7 @@ func doTasks(tasks []task) bool {
return true return true
} }
// checkHasElevatedPrivilege checks if the process has elevated privileges and exits if not
func checkHasElevatedPrivilege() { func checkHasElevatedPrivilege() {
ok, err := hasElevatedPrivilege() ok, err := hasElevatedPrivilege()
if err != nil { if err != nil {
@@ -208,6 +211,7 @@ func checkHasElevatedPrivilege() {
} }
} }
// unixSystemVServiceStatus checks the status of a Unix System V service
func unixSystemVServiceStatus() (service.Status, error) { func unixSystemVServiceStatus() (service.Status, error) {
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput() out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
if err != nil { if err != nil {

View File

@@ -6,12 +6,15 @@ import (
"os" "os"
) )
// hasElevatedPrivilege checks if the current process has elevated privileges
func hasElevatedPrivilege() (bool, error) { func hasElevatedPrivilege() (bool, error) {
return os.Geteuid() == 0, nil return os.Geteuid() == 0, nil
} }
// openLogFile opens a log file with the specified flags
func openLogFile(path string, flags int) (*os.File, error) { func openLogFile(path string, flags int) (*os.File, error) {
return os.OpenFile(path, flags, os.FileMode(0o600)) return os.OpenFile(path, flags, os.FileMode(0o600))
} }
// ConfigureWindowsServiceFailureActions is a no-op on non-Windows platforms
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil } func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }

View File

@@ -11,6 +11,7 @@ import (
"golang.org/x/sys/windows/svc/mgr" "golang.org/x/sys/windows/svc/mgr"
) )
// hasElevatedPrivilege checks if the current process has elevated privileges on Windows
func hasElevatedPrivilege() (bool, error) { func hasElevatedPrivilege() (bool, error) {
var sid *windows.SID var sid *windows.SID
if err := windows.AllocateAndInitializeSid( if err := windows.AllocateAndInitializeSid(
@@ -93,6 +94,7 @@ func ConfigureWindowsServiceFailureActions(serviceName string) error {
return nil return nil
} }
// openLogFile opens a log file with the specified mode on Windows
func openLogFile(path string, mode int) (*os.File, error) { func openLogFile(path string, mode int) (*os.File, error) {
if len(path) == 0 { if len(path) == 0 {
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND} return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}

View File

@@ -30,6 +30,7 @@ type upstreamMonitor struct {
failureTimerActive map[string]bool failureTimerActive map[string]bool
} }
// newUpstreamMonitor creates a new upstream monitor instance
func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor { func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor {
um := &upstreamMonitor{ um := &upstreamMonitor{
cfg: cfg, cfg: cfg,

View File

@@ -43,7 +43,7 @@ func (c *Controller) Start(CdUID string, HomeDir string, UpstreamProto string, l
} }
} }
// As workaround to avoid circular dependency between cli and ctrld_library module // mapCallback maps the AppCallback interface to cli.AppCallback to avoid circular dependency
func mapCallback(callback AppCallback) cli.AppCallback { func mapCallback(callback AppCallback) cli.AppCallback {
return cli.AppCallback{ return cli.AppCallback{
HostName: func() string { HostName: func() string {