diff --git a/internal/router/openwrt/openwrt.go b/internal/router/openwrt/openwrt.go index d3bc511..ad98db9 100644 --- a/internal/router/openwrt/openwrt.go +++ b/internal/router/openwrt/openwrt.go @@ -1,9 +1,12 @@ package openwrt import ( + "bytes" + "errors" "fmt" "os" "os/exec" + "strings" "github.com/kardianos/service" @@ -17,7 +20,8 @@ const ( ) type Openwrt struct { - cfg *ctrld.Config + cfg *ctrld.Config + dnsmasqCacheSize string } // New returns a router.Router for configuring/setup/run ctrld on Openwrt routers. @@ -46,6 +50,19 @@ func (o *Openwrt) Setup() error { if o.cfg.FirstListener().IsDirectDnsListener() { return nil } + + // Save current dnsmasq config cache size if present. + if cs, err := uci("get", "dhcp.@dnsmasq[0].cachesize"); err == nil { + o.dnsmasqCacheSize = cs + if _, err := uci("delete", "dhcp.@dnsmasq[0].cachesize"); err != nil { + return err + } + // Commit. + if _, err := uci("commit", "dhcp"); err != nil { + return err + } + } + data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, o.cfg) if err != nil { return err @@ -68,6 +85,18 @@ func (o *Openwrt) Cleanup() error { if err := os.Remove(openwrtDNSMasqConfigPath); err != nil { return err } + + // Restore original value if present. + if o.dnsmasqCacheSize != "" { + if _, err := uci("set", fmt.Sprintf("dhcp.@dnsmasq[0].cachesize=%s", o.dnsmasqCacheSize)); err != nil { + return err + } + // Commit. + if _, err := uci("commit", "dhcp"); err != nil { + return err + } + } + // Restart dnsmasq service. if err := restartDNSMasq(); err != nil { return err @@ -81,3 +110,19 @@ func restartDNSMasq() error { } return nil } + +var errUCIEntryNotFound = errors.New("uci: Entry not found") + +func uci(args ...string) (string, error) { + cmd := exec.Command("uci", args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + if strings.HasPrefix(stderr.String(), errUCIEntryNotFound.Error()) { + return "", errUCIEntryNotFound + } + return "", fmt.Errorf("%s:%w", stderr.String(), err) + } + return strings.TrimSpace(stdout.String()), nil +}