diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 9fff2f8..f83d281 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -169,6 +169,8 @@ func initCLI() { runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid") runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "") _ = runCmd.Flags().MarkHidden("homedir") + runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`) + _ = runCmd.Flags().MarkHidden("iface") rootCmd.AddCommand(runCmd) @@ -226,7 +228,7 @@ func initCLI() { } }, } - // Keep these flags in sync with runCmd above, except for "-d", "--iface". + // Keep these flags in sync with runCmd above, except for "-d". startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file") startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config") startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port") @@ -321,7 +323,6 @@ func initCLI() { } initLogging() if doTasks(tasks) { - prog.resetDNS() mainLog.Info().Msg("Service uninstalled") return } @@ -393,6 +394,9 @@ func initCLI() { Use: "start", Short: "Quick start service and configure DNS on interface", Run: func(cmd *cobra.Command, args []string) { + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } iface = ifaceStartStop startCmd.Run(cmd, args) }, @@ -405,6 +409,9 @@ func initCLI() { Use: "stop", Short: "Quick stop service and remove DNS from interface", Run: func(cmd *cobra.Command, args []string) { + if !cmd.Flags().Changed("iface") { + os.Args = append(os.Args, "--iface="+ifaceStartStop) + } iface = ifaceStartStop stopCmd.Run(cmd, args) }, diff --git a/cmd/ctrld/os_linux.go b/cmd/ctrld/os_linux.go index 53d52b9..110a23b 100644 --- a/cmd/ctrld/os_linux.go +++ b/cmd/ctrld/os_linux.go @@ -8,6 +8,7 @@ import ( "net" "net/netip" "os/exec" + "reflect" "strings" "syscall" "time" @@ -41,6 +42,8 @@ func deAllocateIP(ip string) error { return nil } +const maxSetDNSAttempts = 5 + // set the dns server for the provided network interface func setDNS(iface *net.Interface, nameservers []string) error { logf := func(format string, args ...any) { @@ -57,10 +60,22 @@ func setDNS(iface *net.Interface, nameservers []string) error { for _, nameserver := range nameservers { ns = append(ns, netip.MustParseAddr(nameserver)) } - return r.SetDNS(dns.OSConfig{ + + osConfig := dns.OSConfig{ Nameservers: ns, SearchDomains: []dnsname.FQDN{}, - }) + } + + for i := 0; i < maxSetDNSAttempts; i++ { + if err := r.SetDNS(osConfig); err != nil { + return err + } + currentNS := currentDNS(iface) + if reflect.DeepEqual(currentNS, nameservers) { + break + } + } + return nil } func resetDNS(iface *net.Interface) error { diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index b17b94c..fa11dfc 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -35,6 +35,7 @@ func (p *prog) Start(s service.Service) error { } func (p *prog) run() { + p.preRun() if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { @@ -173,6 +174,11 @@ func (p *prog) Stop(s service.Service) error { return nil } +func (p *prog) Uninstall(s service.Service) error { + p.resetDNS() + return nil +} + func (p *prog) allocateIP(ip string) error { if !p.cfg.Service.AllocateIP { return nil diff --git a/cmd/ctrld/prog_linux.go b/cmd/ctrld/prog_linux.go new file mode 100644 index 0000000..155c2fa --- /dev/null +++ b/cmd/ctrld/prog_linux.go @@ -0,0 +1,9 @@ +package main + +import "github.com/kardianos/service" + +func (p *prog) preRun() { + if !service.Interactive() { + p.setDNS() + } +} diff --git a/cmd/ctrld/prog_others.go b/cmd/ctrld/prog_others.go new file mode 100644 index 0000000..10310d5 --- /dev/null +++ b/cmd/ctrld/prog_others.go @@ -0,0 +1,6 @@ +//go:build !linux +// +build !linux + +package main + +func (p *prog) preRun() {}