mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-05-15 00:50:25 +02:00
cmd/ctrld: do set/reset DNS only when start/stop/uninstall
This commit is contained in:
committed by
Cuong Manh Le
parent
4ea1e64795
commit
149941f17f
+18
-20
@@ -117,10 +117,6 @@ func initCLI() {
|
||||
}
|
||||
initCache()
|
||||
|
||||
if iface == "auto" {
|
||||
iface = defaultIfaceName()
|
||||
}
|
||||
|
||||
if daemon {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
@@ -173,8 +169,6 @@ func initCLI() {
|
||||
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
@@ -214,7 +208,8 @@ func initCLI() {
|
||||
}
|
||||
}
|
||||
|
||||
s, err := service.New(&prog{}, sc)
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, sc)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
return
|
||||
@@ -226,12 +221,12 @@ func initCLI() {
|
||||
{s.Start, true},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
stdoutMsg("Service started")
|
||||
return
|
||||
mainLog.Info().Msg("Service started")
|
||||
prog.setDNS()
|
||||
}
|
||||
},
|
||||
}
|
||||
// Keep these flags in sync with runCmd above, except for "-d".
|
||||
// Keep these flags in sync with runCmd above, except for "-d", "--iface".
|
||||
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
@@ -249,13 +244,16 @@ func initCLI() {
|
||||
Short: "Stop the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, svcConfig)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
return
|
||||
}
|
||||
initLogging()
|
||||
if doTasks([]task{{s.Stop, true}}) {
|
||||
stdoutMsg("Service stopped")
|
||||
mainLog.Info().Msg("Service stopped")
|
||||
prog.resetDNS()
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -272,6 +270,7 @@ func initCLI() {
|
||||
stderrMsg(err.Error())
|
||||
return
|
||||
}
|
||||
initLogging()
|
||||
if doTasks([]task{{s.Restart, true}}) {
|
||||
stdoutMsg("Service restarted")
|
||||
}
|
||||
@@ -310,7 +309,8 @@ func initCLI() {
|
||||
Short: "Uninstall the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, svcConfig)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
return
|
||||
@@ -319,8 +319,10 @@ func initCLI() {
|
||||
{s.Stop, false},
|
||||
{s.Uninstall, true},
|
||||
}
|
||||
initLogging()
|
||||
if doTasks(tasks) {
|
||||
stdoutMsg("Service uninstalled")
|
||||
mainLog.Info().Msg("Service uninstalled")
|
||||
prog.resetDNS()
|
||||
return
|
||||
}
|
||||
},
|
||||
@@ -391,9 +393,7 @@ func initCLI() {
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
startCmd.Run(cmd, args)
|
||||
},
|
||||
}
|
||||
@@ -405,9 +405,7 @@ func initCLI() {
|
||||
Use: "stop",
|
||||
Short: "Quick stop service and remove DNS from interface",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
stopCmd.Run(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
@@ -36,7 +35,6 @@ var (
|
||||
|
||||
cdUID string
|
||||
iface string
|
||||
netIface *net.Interface
|
||||
ifaceStartStop string
|
||||
)
|
||||
|
||||
|
||||
+46
-28
@@ -3,16 +3,18 @@ package main
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/client4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/client6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/nclient6"
|
||||
"tailscale.com/net/dns"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
@@ -63,40 +65,56 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
var ns []string
|
||||
c := client4.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
c, err := nclient4.New(iface.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("nclient4.New: %w", err)
|
||||
}
|
||||
for _, packet := range conversation {
|
||||
if packet.MessageType() == dhcpv4.MessageTypeAck {
|
||||
nameservers := packet.DNS()
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
lease, err := c.Request(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nclient4.Request: %w", err)
|
||||
}
|
||||
for _, nameserver := range lease.ACK.DNS() {
|
||||
if nameserver.Equal(net.IPv4zero) {
|
||||
continue
|
||||
}
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
|
||||
if supportsIPv6() {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
c, err := nclient6.New(iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not exchange DHCPv6")
|
||||
mainLog.Warn().Err(err).Msg("could not create DHCPv6 client")
|
||||
return nil
|
||||
}
|
||||
for _, packet := range conversation {
|
||||
if packet.Type() == dhcpv6.MessageTypeReply {
|
||||
msg, err := packet.GetInnerMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nameservers := msg.Options.DNS()
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
defer c.Close()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
solicit, err := dhcpv6.NewSolicit(iface.HardwareAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dhcpv6.NewSolicit: %w", err)
|
||||
}
|
||||
advertise, err := dhcpv6.NewAdvertiseFromSolicit(solicit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dhcpv6.NewAdvertiseFromSolicit: %w", err)
|
||||
}
|
||||
msg, err := c.Request(ctx, advertise)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nclient6.Request: %w", err)
|
||||
}
|
||||
nameservers := msg.Options.DNS()
|
||||
for _, nameserver := range nameservers {
|
||||
if nameserver.Equal(net.IPv6zero) {
|
||||
continue
|
||||
}
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
|
||||
}
|
||||
return ignoringEINTR(func() error {
|
||||
return setDNS(iface, ns)
|
||||
})
|
||||
|
||||
+17
-6
@@ -35,7 +35,6 @@ func (p *prog) Start(s service.Service) error {
|
||||
}
|
||||
|
||||
func (p *prog) run() {
|
||||
p.setDNS()
|
||||
if p.cfg.Service.CacheEnable {
|
||||
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
||||
if err != nil {
|
||||
@@ -170,7 +169,6 @@ func (p *prog) Stop(s service.Service) error {
|
||||
mainLog.Error().Err(err).Msg("de-allocate ip failed")
|
||||
return err
|
||||
}
|
||||
p.resetDNS()
|
||||
mainLog.Info().Msg("Service stopped")
|
||||
return nil
|
||||
}
|
||||
@@ -195,18 +193,23 @@ func (p *prog) deAllocateIP() error {
|
||||
}
|
||||
|
||||
func (p *prog) setDNS() {
|
||||
if cfg.Listener == nil || cfg.Listener["0"] == nil {
|
||||
return
|
||||
}
|
||||
if iface == "" {
|
||||
return
|
||||
}
|
||||
if iface == "auto" {
|
||||
iface = defaultIfaceName()
|
||||
}
|
||||
logger := mainLog.With().Str("iface", iface).Logger()
|
||||
var err error
|
||||
netIface, err = netInterface(iface)
|
||||
netIface, err := netInterface(iface)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("could not get interface")
|
||||
return
|
||||
}
|
||||
logger.Debug().Msg("setting DNS for interface")
|
||||
if err := setDNS(netIface, []string{p.cfg.Listener["0"].IP}); err != nil {
|
||||
if err := setDNS(netIface, []string{cfg.Listener["0"].IP}); err != nil {
|
||||
logger.Error().Err(err).Msgf("could not set DNS for interface")
|
||||
return
|
||||
}
|
||||
@@ -214,10 +217,18 @@ func (p *prog) setDNS() {
|
||||
}
|
||||
|
||||
func (p *prog) resetDNS() {
|
||||
if netIface == nil {
|
||||
if iface == "" {
|
||||
return
|
||||
}
|
||||
if iface == "auto" {
|
||||
iface = defaultIfaceName()
|
||||
}
|
||||
logger := mainLog.With().Str("iface", iface).Logger()
|
||||
netIface, err := netInterface(iface)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("could not get interface")
|
||||
return
|
||||
}
|
||||
logger.Debug().Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
logger.Error().Err(err).Msgf("could not reset DNS")
|
||||
|
||||
Reference in New Issue
Block a user