cmd/ctrld: do set/reset DNS only when start/stop/uninstall

This commit is contained in:
Cuong Manh Le
2023-01-31 11:03:50 +07:00
committed by Cuong Manh Le
parent 4ea1e64795
commit 149941f17f
6 changed files with 89 additions and 58 deletions
+18 -20
View File
@@ -117,10 +117,6 @@ func initCLI() {
}
initCache()
if iface == "auto" {
iface = defaultIfaceName()
}
if daemon {
exe, err := os.Executable()
if err != nil {
@@ -173,8 +169,6 @@ func initCLI() {
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
_ = runCmd.Flags().MarkHidden("homedir")
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
_ = runCmd.Flags().MarkHidden("iface")
rootCmd.AddCommand(runCmd)
@@ -214,7 +208,8 @@ func initCLI() {
}
}
s, err := service.New(&prog{}, sc)
prog := &prog{}
s, err := service.New(prog, sc)
if err != nil {
stderrMsg(err.Error())
return
@@ -226,12 +221,12 @@ func initCLI() {
{s.Start, true},
}
if doTasks(tasks) {
stdoutMsg("Service started")
return
mainLog.Info().Msg("Service started")
prog.setDNS()
}
},
}
// Keep these flags in sync with runCmd above, except for "-d".
// Keep these flags in sync with runCmd above, except for "-d", "--iface".
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
@@ -249,13 +244,16 @@ func initCLI() {
Short: "Stop the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := service.New(&prog{}, svcConfig)
prog := &prog{}
s, err := service.New(prog, svcConfig)
if err != nil {
stderrMsg(err.Error())
return
}
initLogging()
if doTasks([]task{{s.Stop, true}}) {
stdoutMsg("Service stopped")
mainLog.Info().Msg("Service stopped")
prog.resetDNS()
}
},
}
@@ -272,6 +270,7 @@ func initCLI() {
stderrMsg(err.Error())
return
}
initLogging()
if doTasks([]task{{s.Restart, true}}) {
stdoutMsg("Service restarted")
}
@@ -310,7 +309,8 @@ func initCLI() {
Short: "Uninstall the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := service.New(&prog{}, svcConfig)
prog := &prog{}
s, err := service.New(prog, svcConfig)
if err != nil {
stderrMsg(err.Error())
return
@@ -319,8 +319,10 @@ func initCLI() {
{s.Stop, false},
{s.Uninstall, true},
}
initLogging()
if doTasks(tasks) {
stdoutMsg("Service uninstalled")
mainLog.Info().Msg("Service uninstalled")
prog.resetDNS()
return
}
},
@@ -391,9 +393,7 @@ func initCLI() {
Use: "start",
Short: "Quick start service and configure DNS on interface",
Run: func(cmd *cobra.Command, args []string) {
if !cmd.Flags().Changed("iface") {
os.Args = append(os.Args, "--iface="+ifaceStartStop)
}
iface = ifaceStartStop
startCmd.Run(cmd, args)
},
}
@@ -405,9 +405,7 @@ func initCLI() {
Use: "stop",
Short: "Quick stop service and remove DNS from interface",
Run: func(cmd *cobra.Command, args []string) {
if !cmd.Flags().Changed("iface") {
os.Args = append(os.Args, "--iface="+ifaceStartStop)
}
iface = ifaceStartStop
stopCmd.Run(cmd, args)
},
}
-2
View File
@@ -3,7 +3,6 @@ package main
import (
"fmt"
"io"
"net"
"os"
"path/filepath"
"time"
@@ -36,7 +35,6 @@ var (
cdUID string
iface string
netIface *net.Interface
ifaceStartStop string
)
+46 -28
View File
@@ -3,16 +3,18 @@ package main
import (
"bufio"
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
"syscall"
"time"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/client4"
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/dhcpv6/client6"
"github.com/insomniacslk/dhcp/dhcpv6/nclient6"
"tailscale.com/net/dns"
"tailscale.com/util/dnsname"
@@ -63,40 +65,56 @@ func setDNS(iface *net.Interface, nameservers []string) error {
func resetDNS(iface *net.Interface) error {
var ns []string
c := client4.NewClient()
conversation, err := c.Exchange(iface.Name)
c, err := nclient4.New(iface.Name)
if err != nil {
return err
return fmt.Errorf("nclient4.New: %w", err)
}
for _, packet := range conversation {
if packet.MessageType() == dhcpv4.MessageTypeAck {
nameservers := packet.DNS()
for _, nameserver := range nameservers {
ns = append(ns, nameserver.String())
}
defer c.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
lease, err := c.Request(ctx)
if err != nil {
return fmt.Errorf("nclient4.Request: %w", err)
}
for _, nameserver := range lease.ACK.DNS() {
if nameserver.Equal(net.IPv4zero) {
continue
}
ns = append(ns, nameserver.String())
}
if supportsIPv6() {
c := client6.NewClient()
conversation, err := c.Exchange(iface.Name)
c, err := nclient6.New(iface.Name)
if err != nil {
mainLog.Warn().Err(err).Msg("could not exchange DHCPv6")
mainLog.Warn().Err(err).Msg("could not create DHCPv6 client")
return nil
}
for _, packet := range conversation {
if packet.Type() == dhcpv6.MessageTypeReply {
msg, err := packet.GetInnerMessage()
if err != nil {
return err
}
nameservers := msg.Options.DNS()
for _, nameserver := range nameservers {
ns = append(ns, nameserver.String())
}
}
}
}
defer c.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
solicit, err := dhcpv6.NewSolicit(iface.HardwareAddr)
if err != nil {
return fmt.Errorf("dhcpv6.NewSolicit: %w", err)
}
advertise, err := dhcpv6.NewAdvertiseFromSolicit(solicit)
if err != nil {
return fmt.Errorf("dhcpv6.NewAdvertiseFromSolicit: %w", err)
}
msg, err := c.Request(ctx, advertise)
if err != nil {
return fmt.Errorf("nclient6.Request: %w", err)
}
nameservers := msg.Options.DNS()
for _, nameserver := range nameservers {
if nameserver.Equal(net.IPv6zero) {
continue
}
ns = append(ns, nameserver.String())
}
}
return ignoringEINTR(func() error {
return setDNS(iface, ns)
})
+17 -6
View File
@@ -35,7 +35,6 @@ func (p *prog) Start(s service.Service) error {
}
func (p *prog) run() {
p.setDNS()
if p.cfg.Service.CacheEnable {
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
if err != nil {
@@ -170,7 +169,6 @@ func (p *prog) Stop(s service.Service) error {
mainLog.Error().Err(err).Msg("de-allocate ip failed")
return err
}
p.resetDNS()
mainLog.Info().Msg("Service stopped")
return nil
}
@@ -195,18 +193,23 @@ func (p *prog) deAllocateIP() error {
}
func (p *prog) setDNS() {
if cfg.Listener == nil || cfg.Listener["0"] == nil {
return
}
if iface == "" {
return
}
if iface == "auto" {
iface = defaultIfaceName()
}
logger := mainLog.With().Str("iface", iface).Logger()
var err error
netIface, err = netInterface(iface)
netIface, err := netInterface(iface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
return
}
logger.Debug().Msg("setting DNS for interface")
if err := setDNS(netIface, []string{p.cfg.Listener["0"].IP}); err != nil {
if err := setDNS(netIface, []string{cfg.Listener["0"].IP}); err != nil {
logger.Error().Err(err).Msgf("could not set DNS for interface")
return
}
@@ -214,10 +217,18 @@ func (p *prog) setDNS() {
}
func (p *prog) resetDNS() {
if netIface == nil {
if iface == "" {
return
}
if iface == "auto" {
iface = defaultIfaceName()
}
logger := mainLog.With().Str("iface", iface).Logger()
netIface, err := netInterface(iface)
if err != nil {
logger.Error().Err(err).Msg("could not get interface")
return
}
logger.Debug().Msg("Restoring DNS for interface")
if err := resetDNS(netIface); err != nil {
logger.Error().Err(err).Msgf("could not reset DNS")