diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 4e0f829..6fac8b4 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -421,32 +421,28 @@ func initCLI() { } } -func writeConfigFile() { +func writeConfigFile() error { if cfu := v.ConfigFileUsed(); cfu != "" { defaultConfigFile = cfu } f, err := os.OpenFile(defaultConfigFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644)) if err != nil { - log.Printf("failed to open config file: %v\n", err) - os.Exit(1) + return err } defer f.Close() if cdUID != "" { if _, err := f.WriteString("# AUTO-GENERATED VIA CD FLAG - DO NOT MODIFY\n\n"); err != nil { - log.Printf("failed to write header to config file: %v\n", err) - os.Exit(1) + return err } } enc := toml.NewEncoder(f).SetIndentTables(true) if err := enc.Encode(v.AllSettings()); err != nil { - log.Printf("failed to encode config file: %v\n", err) - os.Exit(1) + return err } if err := f.Close(); err != nil { - log.Printf("failed to write config file: %v\n", err) - os.Exit(1) + return err } - fmt.Println("writing config file to:", defaultConfigFile) + return nil } func readConfigFile(writeDefaultConfig bool) bool { @@ -464,7 +460,11 @@ func readConfigFile(writeDefaultConfig bool) bool { // If error is viper.ConfigFileNotFoundError, write default config. if _, ok := err.(viper.ConfigFileNotFoundError); ok { - writeConfigFile() + if err := writeConfigFile(); err != nil { + log.Fatalf("failed to write default config file: %v", err) + } else { + fmt.Println("writing default config file to: " + defaultConfigFile) + } defaultConfigWritten = true return false } @@ -526,6 +526,7 @@ func processCDFlags() { iface = "auto" } logger := mainLog.With().Str("mode", "cd").Logger() + logger.Info().Msg("fetching Controld-D configuration") resolverConfig, err := controld.FetchResolverConfig(cdUID) if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode { s, err := service.New(&prog{}, svcConfig) @@ -533,10 +534,8 @@ func processCDFlags() { logger.Warn().Err(err).Msg("failed to create new service") return } - if iface == "auto" { - iface = defaultIfaceName() - } - if netIface, _ := netIfaceFromName(iface); netIface != nil { + + if netIface, _ := netInterface(iface); netIface != nil { if err := resetDNS(netIface); err != nil { logger.Warn().Err(err).Msg("something went wrong while restoring DNS") } @@ -552,6 +551,7 @@ func processCDFlags() { return } + logger.Info().Msg("generating ctrld config from Controld-D configuration") cfg = ctrld.Config{} cfg.Network = make(map[string]*ctrld.NetworkConfig) cfg.Network["0"] = &ctrld.NetworkConfig{ @@ -583,7 +583,11 @@ func processCDFlags() { v.Set("upstream", cfg.Upstream) v.Set("listener", cfg.Listener) processLogAndCacheFlags() - writeConfigFile() + if err := writeConfigFile(); err != nil { + logger.Fatal().Err(err).Msg("failed to write config file") + } else { + logger.Info().Msg("writing config file to: " + defaultConfigFile) + } } func processListenFlag() { @@ -620,7 +624,10 @@ func processLogAndCacheFlags() { v.Set("service", cfg.Service) } -func netIfaceFromName(ifaceName string) (*net.Interface, error) { +func netInterface(ifaceName string) (*net.Interface, error) { + if ifaceName == "auto" { + ifaceName = defaultIfaceName() + } var iface *net.Interface err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) { if i.Name == ifaceName { diff --git a/cmd/ctrld/main.go b/cmd/ctrld/main.go index c336a2b..ff6a4ad 100644 --- a/cmd/ctrld/main.go +++ b/cmd/ctrld/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "io" + "net" "os" "path/filepath" "time" @@ -35,6 +36,7 @@ var ( cdUID string iface string + netIface *net.Interface ifaceStartStop string ) diff --git a/cmd/ctrld/os_linux.go b/cmd/ctrld/os_linux.go index a6bc39e..a8ff7f7 100644 --- a/cmd/ctrld/os_linux.go +++ b/cmd/ctrld/os_linux.go @@ -81,7 +81,7 @@ func resetDNS(iface *net.Interface) error { c := client6.NewClient() conversation, err := c.Exchange(iface.Name) if err != nil { - mainLog.Warn().Err(err).Msg("failed to exchange DHCPv6") + mainLog.Warn().Err(err).Msg("could not exchange DHCPv6") } for _, packet := range conversation { if packet.Type() == dhcpv6.MessageTypeReply { diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 30a0688..bde892f 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -34,17 +34,7 @@ func (p *prog) Start(s service.Service) error { } func (p *prog) run() { - if iface != "" { - netIface, err := netIfaceFromName(iface) - if err != nil { - mainLog.Error().Err(err).Msg("could not get interface") - } else { - if err := setDNS(netIface, []string{cfg.Listener["0"].IP}); err != nil { - mainLog.Error().Err(err).Str("iface", iface).Msgf("could not set DNS for interface") - } - } - } - + p.setDNS() if p.cfg.Service.CacheEnable { cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize) if err != nil { @@ -150,7 +140,11 @@ func (p *prog) run() { Port: port, }, }) - writeConfigFile() + if err := writeConfigFile(); err != nil { + proxyLog.Fatal().Err(err).Msg("failed to write config file") + } else { + mainLog.Info().Msg("writing config file to: " + defaultConfigFile) + } mainLog.Info().Msgf("Starting DNS server on listener.%s: %s", listenerNum, pc.LocalAddr()) // There can be a race between closing the listener and start our own UDP server, but it's // rare, and we only do this once, so let conservative here. @@ -175,15 +169,7 @@ func (p *prog) Stop(s service.Service) error { mainLog.Error().Err(err).Msg("de-allocate ip failed") return err } - if iface != "" { - if netIface, err := netIfaceFromName(iface); err == nil { - if err := resetDNS(netIface); err != nil { - mainLog.Error().Err(err).Str("iface", iface).Msgf("could not reset DNS") - } - } else { - mainLog.Error().Err(err).Msg("could not get interface") - } - } + p.resetDNS() return nil } @@ -205,3 +191,35 @@ func (p *prog) deAllocateIP() error { } return nil } + +func (p *prog) setDNS() { + if iface == "" { + return + } + logger := mainLog.With().Str("iface", iface).Logger() + var err error + netIface, err = netInterface(iface) + if err != nil { + logger.Error().Err(err).Msg("could not get interface") + return + } + logger.Debug().Msg("setting DNS for interface") + if err := setDNS(netIface, []string{p.cfg.Listener["0"].IP}); err != nil { + logger.Error().Err(err).Msgf("could not set DNS for interface") + return + } + logger.Debug().Msg("setting DNS successfully") +} + +func (p *prog) resetDNS() { + if netIface == nil { + return + } + logger := mainLog.With().Str("iface", iface).Logger() + logger.Debug().Msg("Restoring DNS for interface") + if err := resetDNS(netIface); err != nil { + logger.Error().Err(err).Msgf("could not reset DNS") + return + } + logger.Debug().Msg("Restoring DNS successfully") +}