Merge pull request #144 from Control-D-Inc/release-branch-v1.3.6

Release branch v1.3.6
This commit is contained in:
Cuong Manh Le
2024-04-20 00:01:23 +07:00
committed by GitHub
35 changed files with 1139 additions and 235 deletions

View File

@@ -12,6 +12,7 @@ import (
"net"
"net/http"
"net/netip"
"net/url"
"os"
"os/exec"
"path/filepath"
@@ -23,11 +24,13 @@ import (
"sync"
"time"
"github.com/Masterminds/semver"
"github.com/cuonglm/osinfo"
"github.com/fsnotify/fsnotify"
"github.com/go-playground/validator/v10"
"github.com/kardianos/service"
"github.com/miekg/dns"
"github.com/minio/selfupdate"
"github.com/olekukonko/tablewriter"
"github.com/pelletier/go-toml/v2"
"github.com/rs/zerolog"
@@ -44,6 +47,9 @@ import (
"github.com/Control-D-Inc/ctrld/internal/router"
)
// selfCheckInternalTestDomain is used for testing ctrld self response to clients.
const selfCheckInternalTestDomain = "ctrld" + loopTestDomain
var (
version = "dev"
commit = "none"
@@ -170,12 +176,63 @@ func initCLI() {
}
setDependencies(sc)
sc.Arguments = append([]string{"run"}, osArgs...)
p := &prog{
router: router.New(&cfg, cdUID != ""),
cfg: &cfg,
}
s, err := newService(p, sc)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
}
status, err := s.Status()
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
// If pin code was set, do not allow running start command.
if status == service.StatusRunning {
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
os.Exit(deactivationPinInvalidExitCode)
}
}
if cdUID != "" {
if _, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); err != nil {
rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
if err != nil {
mainLog.Load().Fatal().Err(err).Msgf("failed to fetch resolver uid: %s", cdUID)
}
// validateCdRemoteConfig clobbers v, saving it here to restore later.
oldV := v
if err := validateCdRemoteConfig(rc, &ctrld.Config{}); err != nil {
if errors.As(err, &viper.ConfigParseError{}) {
if configStr, _ := base64.StdEncoding.DecodeString(rc.Ctrld.CustomConfig); len(configStr) > 0 {
tmpDir := os.TempDir()
tmpConfFile := filepath.Join(tmpDir, "ctrld.toml")
errorLogged := false
// Write remote config to a temporary file to get details error.
if we := os.WriteFile(tmpConfFile, configStr, 0600); we == nil {
if de := decoderErrorFromTomlFile(tmpConfFile); de != nil {
row, col := de.Position()
mainLog.Load().Error().Msgf("failed to parse custom config at line: %d, column: %d, error: %s", row, col, de.Error())
errorLogged = true
}
_ = os.Remove(tmpConfFile)
}
// If we could not log details error, emit what we have already got.
if !errorLogged {
mainLog.Load().Error().Msgf("failed to parse custom config: %v", err)
}
}
} else {
mainLog.Load().Error().Msgf("failed to unmarshal custom config: %v", err)
}
mainLog.Load().Warn().Msg("disregarding invalid custom config")
}
v = oldV
} else if uid := cdUIDFromProvToken(); uid != "" {
cdUID = uid
mainLog.Load().Debug().Msg("using uid from provision token")
removeProvTokenFromArgs(sc)
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
sc.Arguments = append(sc.Arguments, "--cd="+cdUID)
@@ -184,10 +241,6 @@ func initCLI() {
validateCdUpstreamProtocol()
}
p := &prog{
router: router.New(&cfg, cdUID != ""),
cfg: &cfg,
}
if err := p.router.ConfigureService(sc); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to configure service on router")
}
@@ -253,22 +306,6 @@ func initCLI() {
sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile)
}
s, err := newService(p, sc)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
}
status, err := s.Status()
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
// If pin code was set, do not allow running start command.
if status == service.StatusRunning {
if err := checkDeactivationPin(s); isCheckDeactivationPinErr(err) {
os.Exit(deactivationPinInvalidExitCode)
}
}
if router.Name() != "" && iface != "" {
mainLog.Load().Debug().Msg("cleaning up router before installing")
_ = p.router.Cleanup()
@@ -334,7 +371,7 @@ func initCLI() {
}
// If ctrld service is running but selfCheckStatus failed, it could be related
// to user's system firewall configuration, notice users about it.
if status == service.StatusRunning {
if status == service.StatusRunning && err == nil {
_, _ = mainLog.Load().Write(marker)
mainLog.Load().Write([]byte(`ctrld service was running, but a DNS query could not be sent to its listener`))
mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`))
@@ -406,14 +443,14 @@ func initCLI() {
Run: func(cmd *cobra.Command, args []string) {
readConfig(false)
v.Unmarshal(&cfg)
p := &prog{router: router.New(&cfg, cdUID != "")}
p := &prog{router: router.New(&cfg, runInCdMode())}
s, err := newService(p, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
}
initLogging()
if err := checkDeactivationPin(s); isCheckDeactivationPinErr(err) {
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
os.Exit(deactivationPinInvalidExitCode)
}
if doTasks([]task{{s.Stop, true}}) {
@@ -563,7 +600,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
Run: func(cmd *cobra.Command, args []string) {
readConfig(false)
v.Unmarshal(&cfg)
p := &prog{router: router.New(&cfg, cdUID != "")}
p := &prog{router: router.New(&cfg, runInCdMode())}
s, err := newService(p, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
@@ -572,7 +609,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
if iface == "" {
iface = "auto"
}
if err := checkDeactivationPin(s); isCheckDeactivationPinErr(err) {
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
os.Exit(deactivationPinInvalidExitCode)
}
uninstall(p, s)
@@ -816,6 +853,117 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
}
clientsCmd.AddCommand(listClientsCmd)
rootCmd.AddCommand(clientsCmd)
const (
upgradeChannelDev = "dev"
upgradeChannelProd = "prod"
upgradeChannelDefault = "default"
)
upgradeChannel := map[string]string{
upgradeChannelDefault: "https://dl.controld.dev",
upgradeChannelDev: "https://dl.controld.dev",
upgradeChannelProd: "https://dl.controld.com",
}
if isStableVersion(curVersion()) {
upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd]
}
upgradeCmd := &cobra.Command{
Use: "upgrade",
Short: "Upgrading ctrld to latest version",
ValidArgs: []string{upgradeChannelDev, upgradeChannelProd},
Args: cobra.MaximumNArgs(1),
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Run: func(cmd *cobra.Command, args []string) {
s, err := newService(&prog{}, svcConfig)
if err != nil {
mainLog.Load().Error().Msg(err.Error())
return
}
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Warn().Msg("service not installed")
return
}
bin, err := os.Executable()
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to get current ctrld binary path")
}
oldBin := bin + "_previous"
urlString := upgradeChannel[upgradeChannelDefault]
if len(args) > 0 {
channel := args[0]
switch channel {
case upgradeChannelProd, upgradeChannelDev: // ok
default:
mainLog.Load().Fatal().Msgf("uprade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev)
}
urlString = upgradeChannel[channel]
}
dlUrl := fmt.Sprintf("%s/%s-%s/ctrld", urlString, runtime.GOOS, runtime.GOARCH)
if runtime.GOOS == "windows" {
dlUrl += ".exe"
}
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl)
resp, err := http.Get(dlUrl)
if err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to download binary")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
mainLog.Load().Fatal().Msgf("could not download binary: %s", http.StatusText(resp.StatusCode))
}
mainLog.Load().Debug().Msg("Updating current binary")
if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil {
if rerr := selfupdate.RollbackError(err); rerr != nil {
mainLog.Load().Error().Err(rerr).Msg("could not rollback old binary")
}
mainLog.Load().Fatal().Err(err).Msg("failed to update current binary")
}
doRestart := func() bool {
tasks := []task{
{s.Stop, false},
{s.Start, false},
}
if doTasks(tasks) {
if dir, err := socketDir(); err == nil {
return newSocketControlClient(s, dir) != nil
}
}
return false
}
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
if doRestart() {
_ = os.Remove(oldBin)
_ = os.Chmod(bin, 0755)
ver := "unknown version"
out, err := exec.Command(bin, "--version").CombinedOutput()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version")
}
if after, found := strings.CutPrefix(string(out), "ctrld version "); found {
ver = after
}
mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver)
return
}
mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin)
if err := os.Remove(bin); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to remove new binary")
}
if err := os.Rename(oldBin, bin); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to restore old binary")
}
if doRestart() {
mainLog.Load().Notice().Msg("Restored previous binary successfully")
return
}
},
}
rootCmd.AddCommand(upgradeCmd)
}
// isMobile reports whether the current OS is a mobile platform.
@@ -828,6 +976,15 @@ func isAndroid() bool {
return runtime.GOOS == "android"
}
// isStableVersion reports whether vs is a stable semantic version.
func isStableVersion(vs string) bool {
v, err := semver.NewVersion(vs)
if err != nil {
return false
}
return v.Prerelease() == ""
}
// RunCobraCommand runs ctrld cli.
func RunCobraCommand(cmd *cobra.Command) {
noConfigStart = isNoConfigStart(cmd)
@@ -852,9 +1009,9 @@ func RunMobile(appConfig *AppConfig, appCallback *AppCallback, stopCh chan struc
}
// CheckDeactivationPin checks if deactivation pin is valid
func CheckDeactivationPin(pin int64) int {
func CheckDeactivationPin(pin int64, stopCh chan struct{}) int {
deactivationPin = pin
if err := checkDeactivationPin(nil); isCheckDeactivationPinErr(err) {
if err := checkDeactivationPin(nil, stopCh); isCheckDeactivationPinErr(err) {
return deactivationPinInvalidExitCode
}
return 0
@@ -912,7 +1069,9 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
writeDefaultConfig := !noConfigStart && configBase64 == ""
tryReadingConfig(writeDefaultConfig)
readBase64Config(configBase64)
if err := readBase64Config(configBase64); err != nil {
mainLog.Load().Fatal().Err(err).Msg("failed to read base64 config")
}
processNoConfigFlags(noConfigStart)
p.mu.Lock()
if err := v.Unmarshal(&cfg); err != nil {
@@ -935,7 +1094,7 @@ func run(appCallback *AppCallback, stopCh chan struct{}) {
}
p.router = router.New(&cfg, cdUID != "")
cs, err := newControlServer(filepath.Join(sockDir, ctrldControlUnixSock))
cs, err := newControlServer(filepath.Join(sockDir, ControlSocketName()))
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not create control server")
}
@@ -1141,7 +1300,7 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
}
// If error is viper.ConfigFileNotFoundError, write default config.
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
if errors.As(err, &viper.ConfigFileNotFoundError{}) {
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal default config: %v", err)
}
@@ -1162,13 +1321,11 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
return false
}
if _, ok := err.(viper.ConfigParseError); ok {
if f, _ := os.Open(v.ConfigFileUsed()); f != nil {
var i any
if err, ok := toml.NewDecoder(f).Decode(&i).(*toml.DecodeError); ok {
row, col := err.Position()
mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err)
}
// If error is viper.ConfigParseError, emit details line and column number.
if errors.As(err, &viper.ConfigParseError{}) {
if de := decoderErrorFromTomlFile(v.ConfigFileUsed()); de != nil {
row, col := de.Position()
mainLog.Load().Fatal().Msgf("failed to decode config file at line: %d, column: %d, error: %v", row, col, err)
}
}
@@ -1177,13 +1334,27 @@ func readConfigFile(writeDefaultConfig, notice bool) bool {
return false
}
func readBase64Config(configBase64 string) {
// decoderErrorFromTomlFile parses the invalid toml file, returning the details decoder error.
func decoderErrorFromTomlFile(cf string) *toml.DecodeError {
if f, _ := os.Open(cf); f != nil {
defer f.Close()
var i any
var de *toml.DecodeError
if err := toml.NewDecoder(f).Decode(&i); err != nil && errors.As(err, &de) {
return de
}
}
return nil
}
// readBase64Config reads ctrld config from the base64 input string.
func readBase64Config(configBase64 string) error {
if configBase64 == "" {
return
return nil
}
configStr, err := base64.StdEncoding.DecodeString(configBase64)
if err != nil {
mainLog.Load().Fatal().Msgf("invalid base64 config: %v", err)
return fmt.Errorf("invalid base64 config: %w", err)
}
// readBase64Config is called when:
@@ -1194,9 +1365,7 @@ func readBase64Config(configBase64 string) {
// So we need to re-create viper instance to discard old one.
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
v.SetConfigType("toml")
if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil {
mainLog.Load().Fatal().Msgf("failed to read base64 config: %v", err)
}
return v.ReadConfig(bytes.NewReader(configStr))
}
func processNoConfigFlags(noConfigStart bool) {
@@ -1286,42 +1455,76 @@ func processCDFlags(cfg *ctrld.Config) error {
// Fetch config, unmarshal to cfg.
if resolverConfig.Ctrld.CustomConfig != "" {
logger.Info().Msg("using defined custom config of Control-D resolver")
readBase64Config(resolverConfig.Ctrld.CustomConfig)
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Load().Fatal().Msgf("failed to unmarshal config: %v", err)
if err := validateCdRemoteConfig(resolverConfig, cfg); err == nil {
setListenerDefaultValue(cfg)
return nil
}
} else {
cfg.Network = make(map[string]*ctrld.NetworkConfig)
cfg.Network["0"] = &ctrld.NetworkConfig{
Name: "Network 0",
Cidrs: []string{"0.0.0.0/0"},
}
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
Endpoint: resolverConfig.DOH,
Type: cdUpstreamProto,
Timeout: 5000,
}
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
lc := &ctrld.ListenerConfig{
Policy: &ctrld.ListenerPolicyConfig{
Name: "My Policy",
Rules: rules,
},
}
cfg.Listener["0"] = lc
mainLog.Load().Err(err).Msg("disregarding invalid custom config")
}
bootstrapIP := func(endpoint string) string {
u, err := url.Parse(endpoint)
if err != nil {
logger.Warn().Err(err).Msgf("no bootstrap IP for invalid endpoint: %s", endpoint)
return ""
}
switch {
case dns.IsSubDomain(ctrld.FreeDnsDomain, u.Host):
return ctrld.FreeDNSBoostrapIP
case dns.IsSubDomain(ctrld.PremiumDnsDomain, u.Host):
return ctrld.PremiumDNSBoostrapIP
}
return ""
}
cfg.Network = make(map[string]*ctrld.NetworkConfig)
cfg.Network["0"] = &ctrld.NetworkConfig{
Name: "Network 0",
Cidrs: []string{"0.0.0.0/0"},
}
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
BootstrapIP: bootstrapIP(resolverConfig.DOH),
Endpoint: resolverConfig.DOH,
Type: cdUpstreamProto,
Timeout: 5000,
}
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
lc := &ctrld.ListenerConfig{
Policy: &ctrld.ListenerPolicyConfig{
Name: "My Policy",
Rules: rules,
},
}
cfg.Listener["0"] = lc
// Set default value.
setListenerDefaultValue(cfg)
return nil
}
// setListenerDefaultValue sets the default value for cfg.Listener if none existed.
func setListenerDefaultValue(cfg *ctrld.Config) {
if len(cfg.Listener) == 0 {
cfg.Listener = map[string]*ctrld.ListenerConfig{
"0": {IP: "", Port: 0},
}
}
return nil
}
// validateCdRemoteConfig validates the custom config from ControlD if defined.
func validateCdRemoteConfig(rc *controld.ResolverConfig, cfg *ctrld.Config) error {
if rc.Ctrld.CustomConfig == "" {
return nil
}
if err := readBase64Config(rc.Ctrld.CustomConfig); err != nil {
return err
}
return v.Unmarshal(&cfg)
}
func processListenFlag() {
@@ -1398,6 +1601,13 @@ func defaultIfaceName() string {
// selfCheckStatus performs the end-to-end DNS test by sending query to ctrld listener.
// It returns a boolean to indicate whether the check is succeeded, the actual status
// of ctrld service, and an additional error if any.
//
// We perform two tests:
//
// - Internal testing, ensuring query could be sent from client -> ctrld.
// - External testing, ensuring query could be sent from ctrld -> upstream.
//
// Self-check is considered success only if both tests are ok.
func selfCheckStatus(s service.Service) (bool, service.Status, error) {
status, err := s.Status()
if err != nil {
@@ -1480,8 +1690,9 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
})
v.WatchConfig()
var (
lastAnswer *dns.Msg
lastErr error
lastAnswer *dns.Msg
lastErr error
internalTested bool
)
for i := 0; i < maxAttempts; i++ {
mu.Lock()
@@ -1494,6 +1705,9 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
mu.Unlock()
lc := cfg.FirstListener()
domain = cfg.FirstUpstream().VerifyDomain()
if !internalTested {
domain = selfCheckInternalTestDomain
}
if domain == "" {
continue
}
@@ -1503,7 +1717,13 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
m.RecursionDesired = true
r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)))
if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 {
mainLog.Load().Debug().Msgf("self-check against %q succeeded", domain)
internalTested = domain == selfCheckInternalTestDomain
if internalTested {
mainLog.Load().Debug().Msgf("internal self-check against %q succeeded", domain)
continue // internal domain test ok, continue with external test.
} else {
mainLog.Load().Debug().Msgf("external self-check against %q succeeded", domain)
}
return true, status, nil
}
// Return early if this is a connection refused.
@@ -1515,6 +1735,12 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", exErr))
}
mainLog.Load().Debug().Msgf("self-check against %q failed", domain)
// Ping all upstreams to provide better error message to users.
for name, uc := range cfg.Upstream {
if err := uc.ErrorPing(); err != nil {
mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint)
}
}
lc := cfg.FirstListener()
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
marker := strings.Repeat("=", 32)
@@ -1629,6 +1855,10 @@ func readConfigWithNotice(writeDefaultConfig, notice bool) {
}
func uninstall(p *prog, s service.Service) {
if _, err := s.Status(); err != nil && errors.Is(err, service.ErrNotInstalled) {
mainLog.Load().Error().Msg(err.Error())
return
}
tasks := []task{
{s.Stop, false},
{s.Uninstall, true},
@@ -1677,6 +1907,11 @@ func fieldErrorMsg(fe validator.FieldError) string {
return fmt.Sprintf("must define at least %s element", fe.Param())
}
return fmt.Sprintf("minimum value: %q", fe.Param())
case "max":
if fe.Kind() == reflect.Map || fe.Kind() == reflect.Slice {
return fmt.Sprintf("exceeded maximum number of elements: %s", fe.Param())
}
return fmt.Sprintf("maximum value: %q", fe.Param())
case "len":
if fe.Kind() == reflect.Slice {
return fmt.Sprintf("must have at least %s element", fe.Param())
@@ -1802,10 +2037,7 @@ func tryUpdateListenerConfig(cfg *ctrld.Config, infoLogger *zerolog.Logger, fata
if cdMode {
firstLn.IP = mobileListenerIp()
firstLn.Port = mobileListenerPort()
// TODO: use clear(lcc) once upgrading to go 1.21
for k := range lcc {
delete(lcc, k)
}
clear(lcc)
updated = true
}
}
@@ -2027,6 +2259,7 @@ func cdUIDFromProvToken() string {
if cdOrg == "" {
return ""
}
// Process provision token if provided.
resolverConfig, err := controld.FetchResolverUID(cdOrg, rootCmd.Version, cdDev)
if err != nil {
@@ -2065,6 +2298,8 @@ func newSocketControlClient(s service.Service, dir string) *controlClient {
ctx := context.Background()
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
timeout := time.NewTimer(30 * time.Second)
defer timeout.Stop()
// The socket control server may not start yet, so attempt to ping
// it until we got a response. For each iteration, check ctrld status
@@ -2072,7 +2307,6 @@ func newSocketControlClient(s service.Service, dir string) *controlClient {
for {
curStatus, err := s.Status()
if err != nil {
mainLog.Load().Warn().Err(err).Msg("could not get service status while doing self-check")
return nil
}
if curStatus != service.StatusRunning {
@@ -2084,12 +2318,37 @@ func newSocketControlClient(s service.Service, dir string) *controlClient {
}
// The socket control server is not ready yet, backoff for waiting it to be ready.
bo.BackOff(ctx, err)
select {
case <-timeout.C:
return nil
default:
}
continue
}
return cc
}
func newSocketControlClientMobile(dir string, stopCh chan struct{}) *controlClient {
bo := backoff.NewBackoff("self-check", logf, 3*time.Second)
bo.LogLongerThan = 3 * time.Second
ctx := context.Background()
cc := newControlClient(filepath.Join(dir, ControlSocketName()))
for {
select {
case <-stopCh:
return nil
default:
_, err := cc.post("/", nil)
if err == nil {
return cc
} else {
bo.BackOff(ctx, err)
}
}
}
}
// checkStrFlagEmpty validates if a string flag was set to an empty string.
// If yes, emitting a fatal error message.
func checkStrFlagEmpty(cmd *cobra.Command, flagName string) {
@@ -2170,7 +2429,7 @@ var errInvalidDeactivationPin = errors.New("deactivation pin is invalid")
var errRequiredDeactivationPin = errors.New("deactivation pin is required to stop or uninstall the service")
// checkDeactivationPin validates if the deactivation pin matches one in ControlD config.
func checkDeactivationPin(s service.Service) error {
func checkDeactivationPin(s service.Service, stopCh chan struct{}) error {
dir, err := socketDir()
if err != nil {
mainLog.Load().Err(err).Msg("could not check deactivation pin")
@@ -2178,7 +2437,7 @@ func checkDeactivationPin(s service.Service) error {
}
var cc *controlClient
if s == nil {
cc = newControlClient(filepath.Join(dir, ctrldControlUnixSock))
cc = newSocketControlClientMobile(dir, stopCh)
} else {
cc = newSocketControlClient(s, dir)
}
@@ -2259,3 +2518,20 @@ func absHomeDir(filename string) string {
}
return filepath.Join(dir, filename)
}
// runInCdMode reports whether ctrld service is running in cd mode.
func runInCdMode() bool {
if s, _ := newService(&prog{}, svcConfig); s != nil {
if dir, _ := socketDir(); dir != "" {
cc := newSocketControlClient(s, dir)
if cc != nil {
resp, _ := cc.post(cdPath, nil)
if resp != nil {
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
}
}
}
return false
}

View File

@@ -21,3 +21,26 @@ func Test_writeConfigFile(t *testing.T) {
_, err = os.Stat(configPath)
require.NoError(t, err)
}
func Test_isStableVersion(t *testing.T) {
tests := []struct {
name string
ver string
isStable bool
}{
{"stable", "v1.3.5", true},
{"pre", "v1.3.5-next", false},
{"pre with commit hash", "v1.3.5-next-asdf", false},
{"dev", "dev", false},
{"empty", "dev", false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isStableVersion(tc.ver); got != tc.isStable {
t.Errorf("unexpected result for %s, want: %v, got: %v", tc.ver, tc.isStable, got)
}
})
}
}

View File

@@ -21,6 +21,7 @@ const (
startedPath = "/started"
reloadPath = "/reload"
deactivationPath = "/deactivation"
cdPath = "/cd"
)
type controlServer struct {
@@ -171,6 +172,13 @@ func (p *prog) registerControlServerHandler() {
}
w.WriteHeader(code)
}))
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
if cdUID != "" {
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusBadRequest)
}))
}
func jsonResponse(next http.Handler) http.Handler {

View File

@@ -101,6 +101,15 @@ func (p *prog) serveDNS(listenerNum string) error {
go p.detectLoop(m)
q := m.Question[0]
domain := canonicalName(q.Name)
if domain == selfCheckInternalTestDomain {
answer := resolveInternalDomainTestQuery(ctx, domain, m)
_ = w.WriteMsg(answer)
return
}
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
p.cache.Purge()
ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain)
}
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
ci := p.getClientInfo(remoteIP, m)
ci.ClientIDPref = p.cfg.Service.ClientIDPref
@@ -282,7 +291,7 @@ networkRules:
macRules:
for _, rule := range lc.Policy.Macs {
for source, targets := range rule {
if source != "" && strings.EqualFold(source, srcMac) {
if source != "" && (strings.EqualFold(source, srcMac) || wildcardMatches(strings.ToLower(source), strings.ToLower(srcMac))) {
matchedPolicy = lc.Policy.Name
matchedNetwork = source
networkTargets = targets
@@ -590,7 +599,8 @@ func canonicalName(fqdn string) string {
return q
}
func wildcardMatches(wildcard, domain string) bool {
// wildcardMatches reports whether string str matches the wildcard pattern.
func wildcardMatches(wildcard, str string) bool {
// Wildcard match.
wildCardParts := strings.Split(wildcard, "*")
if len(wildCardParts) != 2 {
@@ -600,15 +610,15 @@ func wildcardMatches(wildcard, domain string) bool {
switch {
case len(wildCardParts[0]) > 0 && len(wildCardParts[1]) > 0:
// Domain must match both prefix and suffix.
return strings.HasPrefix(domain, wildCardParts[0]) && strings.HasSuffix(domain, wildCardParts[1])
return strings.HasPrefix(str, wildCardParts[0]) && strings.HasSuffix(str, wildCardParts[1])
case len(wildCardParts[1]) > 0:
// Only suffix must match.
return strings.HasSuffix(domain, wildCardParts[1])
return strings.HasSuffix(str, wildCardParts[1])
case len(wildCardParts[0]) > 0:
// Only prefix must match.
return strings.HasPrefix(domain, wildCardParts[0])
return strings.HasPrefix(str, wildCardParts[0])
}
return false
@@ -806,6 +816,13 @@ func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
}
ci.Self = queryFromSelf(ci.IP)
// If this is a query from self, but ci.IP is not loopback IP,
// try using hostname mapping for lookback IP if presents.
if ci.Self {
if name := p.ciTable.LocalHostname(); name != "" {
ci.Hostname = name
}
}
p.spoofLoopbackIpInClientInfo(ci)
return ci
}
@@ -936,3 +953,21 @@ func isWanClient(na net.Addr) bool {
!ip.IsLinkLocalMulticast() &&
!tsaddr.CGNATRange().Contains(ip)
}
// resolveInternalDomainTestQuery resolves internal test domain query, returning the answer to the caller.
func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.Msg) *dns.Msg {
ctrld.Log(ctx, mainLog.Load().Debug(), "internal domain test query")
q := m.Question[0]
answer := new(dns.Msg)
rrStr := fmt.Sprintf("%s A %s", domain, net.IPv4zero)
if q.Qtype == dns.TypeAAAA {
rrStr = fmt.Sprintf("%s AAAA %s", domain, net.IPv6zero)
}
rr, err := dns.NewRR(rrStr)
if err == nil {
answer.Answer = append(answer.Answer, rr)
}
answer.SetReply(m)
return answer
}

View File

@@ -22,14 +22,21 @@ func Test_wildcardMatches(t *testing.T) {
domain string
match bool
}{
{"prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
{"prefix", "*.windscribe.com", "anything.windscribe.com", true},
{"prefix not match other domain", "*.windscribe.com", "example.com", false},
{"prefix not match domain in name", "*.windscribe.com", "wwindscribe.com", false},
{"suffix", "suffix.*", "suffix.windscribe.com", true},
{"suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
{"both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
{"both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
{"domain - prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
{"domain - prefix", "*.windscribe.com", "anything.windscribe.com", true},
{"domain - prefix not match other s", "*.windscribe.com", "example.com", false},
{"domain - prefix not match s in name", "*.windscribe.com", "wwindscribe.com", false},
{"domain - suffix", "suffix.*", "suffix.windscribe.com", true},
{"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
{"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
{"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
{"mac - suffix", "d4:67:98:*", "d4:67:98:05:b4:2b", true},
{"mac - suffix not match other", "d4:67:98:*", "d4:67:97:15:b4:2b", false},
{"mac - both", "d4:67:98:*:b4:2b", "d4:67:98:05:b4:2b", true},
{"mac - both not match", "d4:67:98:*:b4:2b", "d4:67:97:05:c4:2b", false},
}
for _, tc := range tests {

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"os/exec"
"strings"
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
)
@@ -30,6 +31,18 @@ func deAllocateIP(ip string) error {
return nil
}
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
if err := setDNS(iface, nameservers); err != nil {
// TODO: investiate whether we can detect this without relying on error message.
if strings.Contains(err.Error(), " is not a recognized network service") {
return nil
}
return err
}
return nil
}
// set the dns server for the provided network interface
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
// TODO(cuonglm): use system API
@@ -43,6 +56,18 @@ func setDNS(iface *net.Interface, nameservers []string) error {
return nil
}
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
if err := resetDNS(iface); err != nil {
// TODO: investiate whether we can detect this without relying on error message.
if strings.Contains(err.Error(), " is not a recognized network service") {
return nil
}
return err
}
return nil
}
// TODO(cuonglm): use system API
func resetDNS(iface *net.Interface) error {
if ns := savedStaticNameservers(iface); len(ns) > 0 {

View File

@@ -29,6 +29,11 @@ func deAllocateIP(ip string) error {
return nil
}
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
return setDNS(iface, nameservers)
}
// set the dns server for the provided network interface
func setDNS(iface *net.Interface, nameservers []string) error {
r, err := dns.NewOSConfigurator(logf, iface.Name)
@@ -49,6 +54,11 @@ func setDNS(iface *net.Interface, nameservers []string) error {
return nil
}
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
return resetDNS(iface)
}
func resetDNS(iface *net.Interface) error {
r, err := dns.NewOSConfigurator(logf, iface.Name)
if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"net"
"net/netip"
"os/exec"
"slices"
"strings"
"syscall"
"time"
@@ -45,7 +46,11 @@ func deAllocateIP(ip string) error {
const maxSetDNSAttempts = 5
// set the dns server for the provided network interface
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
return setDNS(iface, nameservers)
}
func setDNS(iface *net.Interface, nameservers []string) error {
r, err := dns.NewOSConfigurator(logf, iface.Name)
if err != nil {
@@ -115,6 +120,11 @@ func setDNS(iface *net.Interface, nameservers []string) error {
return nil
}
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
return resetDNS(iface)
}
func resetDNS(iface *net.Interface) (err error) {
defer func() {
if err == nil {
@@ -276,8 +286,7 @@ func ignoringEINTR(fn func() error) error {
func isSubSet(s1, s2 []string) bool {
ok := true
for _, ns := range s1 {
// TODO(cuonglm): use slices.Contains once upgrading to go1.21
if sliceContains(s2, ns) {
if slices.Contains(s2, ns) {
continue
}
ok = false
@@ -285,19 +294,3 @@ func isSubSet(s1, s2 []string) bool {
}
return ok
}
// sliceContains reports whether v is present in s.
func sliceContains[S ~[]E, E comparable](s S, v E) bool {
return sliceIndex(s, v) >= 0
}
// sliceIndex returns the index of the first occurrence of v in s,
// or -1 if not present.
func sliceIndex[S ~[]E, E comparable](s S, v E) int {
for i := range s {
if v == s[i] {
return i
}
}
return -1
}

View File

@@ -26,6 +26,12 @@ var (
resetDNSOnce sync.Once
)
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
return setDNS(iface, nameservers)
}
// setDNS sets the dns server for the provided network interface
func setDNS(iface *net.Interface, nameservers []string) error {
if len(nameservers) == 0 {
return errors.New("empty DNS nameservers")
@@ -61,6 +67,11 @@ func setDNS(iface *net.Interface, nameservers []string) error {
return nil
}
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
return resetDNS(iface)
}
// TODO(cuonglm): should we use system API?
func resetDNS(iface *net.Interface) error {
resetDNSOnce.Do(func() {

View File

@@ -33,11 +33,22 @@ const (
defaultSemaphoreCap = 256
ctrldLogUnixSock = "ctrld_start.sock"
ctrldControlUnixSock = "ctrld_control.sock"
upstreamPrefix = "upstream."
upstreamOS = upstreamPrefix + "os"
upstreamPrivate = upstreamPrefix + "private"
// iOS unix socket name max length is 11.
ctrldControlUnixSockMobile = "cd.sock"
upstreamPrefix = "upstream."
upstreamOS = upstreamPrefix + "os"
upstreamPrivate = upstreamPrefix + "private"
)
// ControlSocketName returns name for control unix socket.
func ControlSocketName() string {
if isMobile() {
return ctrldControlUnixSockMobile
} else {
return ctrldControlUnixSock
}
}
var logf = func(format string, args ...any) {
mainLog.Load().Debug().Msgf(format, args...)
}
@@ -59,17 +70,18 @@ type prog struct {
logConn net.Conn
cs *controlServer
cfg *ctrld.Config
localUpstreams []string
ptrNameservers []string
appCallback *AppCallback
cache dnscache.Cacher
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
ptrLoopGuard *loopGuard
lanLoopGuard *loopGuard
cfg *ctrld.Config
localUpstreams []string
ptrNameservers []string
appCallback *AppCallback
cache dnscache.Cacher
cacheFlushDomainsMap map[string]struct{}
sema semaphore
ciTable *clientinfo.Table
um *upstreamMonitor
router router.Router
ptrLoopGuard *loopGuard
lanLoopGuard *loopGuard
loopMu sync.Mutex
loop map[string]bool
@@ -242,12 +254,17 @@ func (p *prog) run(reload bool, reloadCh chan struct{}) {
p.loop = make(map[string]bool)
p.lanLoopGuard = newLoopGuard()
p.ptrLoopGuard = newLoopGuard()
p.cacheFlushDomainsMap = nil
if p.cfg.Service.CacheEnable {
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
if err != nil {
mainLog.Load().Error().Err(err).Msg("failed to create cacher, caching is disabled")
} else {
p.cache = cacher
p.cacheFlushDomainsMap = make(map[string]struct{}, 256)
for _, domain := range p.cfg.Service.CacheFlushDomains {
p.cacheFlushDomainsMap[canonicalName(domain)] = struct{}{}
}
}
}
@@ -477,7 +494,7 @@ func (p *prog) setDNS() {
}
if allIfaces {
withEachPhysicalInterfaces(netIface.Name, "set DNS", func(i *net.Interface) error {
return setDNS(i, nameservers)
return setDnsIgnoreUnusableInterface(i, nameservers)
})
}
}
@@ -509,7 +526,7 @@ func (p *prog) resetDNS() {
}
logger.Debug().Msg("Restoring DNS successfully")
if allIfaces {
withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDNS)
withEachPhysicalInterfaces(netIface.Name, "reset DNS", resetDnsIgnoreUnusableInterface)
}
}

View File

@@ -1,6 +1,8 @@
package cli
import (
"os"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld/internal/dns"
@@ -10,6 +12,10 @@ func init() {
if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo"); err == nil {
useSystemdResolved = r.Mode() == "systemd-resolved"
}
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911
if os.Getenv("QUIC_GO_DISABLE_ECN") == "" {
os.Setenv("QUIC_GO_DISABLE_ECN", "true")
}
}
func setDependencies(svc *service.Config) {

View File

@@ -20,7 +20,7 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
return nil, err
}
switch {
case router.IsOldOpenwrt():
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
return &procd{&sysV{s}}, nil
case router.IsGLiNet():
return &sysV{s}, nil

View File

@@ -61,8 +61,13 @@ func mapCallback(callback AppCallback) cli.AppCallback {
}
}
func (c *Controller) Stop(Pin int64) int {
errorCode := cli.CheckDeactivationPin(Pin)
func (c *Controller) Stop(restart bool, pin int64) int {
var errorCode = 0
// Force disconnect without checking pin.
// In iOS restart is required if vpn detects no connectivity after network change.
if !restart {
errorCode = cli.CheckDeactivationPin(pin, c.stopCh)
}
if errorCode == 0 && c.stopCh != nil {
close(c.stopCh)
c.stopCh = nil

110
config.go
View File

@@ -46,6 +46,15 @@ const (
// depending on the record type of the DNS query.
IpStackSplit = "split"
// FreeDnsDomain is the domain name of free ControlD service.
FreeDnsDomain = "freedns.controld.com"
// FreeDNSBoostrapIP is the IP address of freedns.controld.com.
FreeDNSBoostrapIP = "76.76.2.11"
// PremiumDnsDomain is the domain name of premium ControlD service.
PremiumDnsDomain = "dns.controld.com"
// PremiumDNSBoostrapIP is the IP address of dns.controld.com.
PremiumDNSBoostrapIP = "76.76.2.22"
controlDComDomain = "controld.com"
controlDNetDomain = "controld.net"
controlDDevDomain = "controld.dev"
@@ -104,14 +113,14 @@ func InitConfig(v *viper.Viper, name string) {
})
v.SetDefault("upstream", map[string]*UpstreamConfig{
"0": {
BootstrapIP: "76.76.2.11",
BootstrapIP: FreeDNSBoostrapIP,
Name: "Control D - Anti-Malware",
Type: ResolverTypeDOH,
Endpoint: "https://freedns.controld.com/p1",
Timeout: 5000,
},
"1": {
BootstrapIP: "76.76.2.11",
BootstrapIP: FreeDNSBoostrapIP,
Name: "Control D - No Ads",
Type: ResolverTypeDOQ,
Endpoint: "p2.freedns.controld.com",
@@ -179,26 +188,27 @@ func (c *Config) FirstUpstream() *UpstreamConfig {
// ServiceConfig specifies the general ctrld config.
type ServiceConfig struct {
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
Daemon bool `mapstructure:"-" toml:"-"`
AllocateIP bool `mapstructure:"-" toml:"-"`
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
Daemon bool `mapstructure:"-" toml:"-"`
AllocateIP bool `mapstructure:"-" toml:"-"`
}
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
@@ -285,6 +295,7 @@ type Rule map[string][]string
// Init initialized necessary values for an UpstreamConfig.
func (uc *UpstreamConfig) Init() {
uc.initDoHScheme()
uc.uid = upstreamUID()
if u, err := url.Parse(uc.Endpoint); err == nil {
uc.Domain = u.Host
@@ -510,35 +521,55 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
// Ping warms up the connection to DoH/DoH3 upstream.
func (uc *UpstreamConfig) Ping() {
_ = uc.ping()
}
// ErrorPing is like Ping, but return an error if any.
func (uc *UpstreamConfig) ErrorPing() error {
return uc.ping()
}
func (uc *UpstreamConfig) ping() error {
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
default:
return
return nil
}
ping := func(t http.RoundTripper) {
ping := func(t http.RoundTripper) error {
if t == nil {
return
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
req, _ := http.NewRequestWithContext(ctx, "HEAD", uc.Endpoint, nil)
resp, _ := t.RoundTrip(req)
if resp == nil {
return
req, err := http.NewRequestWithContext(ctx, "HEAD", uc.Endpoint, nil)
if err != nil {
return err
}
resp, err := t.RoundTrip(req)
if err != nil {
return err
}
defer resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body)
return nil
}
for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} {
switch uc.Type {
case ResolverTypeDOH:
ping(uc.dohTransport(typ))
if err := ping(uc.dohTransport(typ)); err != nil {
return err
}
case ResolverTypeDOH3:
ping(uc.doh3Transport(typ))
if err := ping(uc.doh3Transport(typ)); err != nil {
return err
}
}
}
return nil
}
func (uc *UpstreamConfig) isControlD() bool {
@@ -631,6 +662,18 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
return "tcp-tls", "udp"
}
// initDoHScheme initializes the endpoint scheme for DoH/DoH3 upstream if not present.
func (uc *UpstreamConfig) initDoHScheme() {
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
default:
return
}
if !strings.HasPrefix(uc.Endpoint, "https://") {
uc.Endpoint = "https://" + uc.Endpoint
}
}
// Init initialized necessary values for an ListenerConfig.
func (lc *ListenerConfig) Init() {
if lc.Policy != nil {
@@ -683,6 +726,7 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
return
}
uc.initDoHScheme()
// DoH/DoH3 requires endpoint is an HTTP url.
if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 {
u, err := url.Parse(uc.Endpoint)
@@ -690,10 +734,6 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
return
}
if u.Scheme != "http" && u.Scheme != "https" {
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
return
}
}
}

View File

@@ -1,6 +1,7 @@
package ctrld_test
import (
"fmt"
"os"
"strings"
"testing"
@@ -102,6 +103,8 @@ func TestConfigValidation(t *testing.T) {
{"invalid lease file format", configWithInvalidLeaseFileFormat(t), true},
{"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true},
{"invalid client id pref", configWithInvalidClientIDPref(t), true},
{"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false},
{"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true},
}
for _, tc := range tests {
@@ -167,6 +170,12 @@ func invalidUpstreamType(t *testing.T) *ctrld.Config {
return cfg
}
func dohUpstreamEndpointWithoutScheme(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Upstream["0"].Endpoint = "freedns.controld.com/p1"
return cfg
}
func invalidUpstreamTimeout(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Upstream["0"].Timeout = -1
@@ -258,7 +267,7 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Upstream["0"].Endpoint = "1.1.1.1"
cfg.Upstream["0"].Endpoint = "/1.1.1.1"
cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH
return cfg
}
@@ -268,3 +277,12 @@ func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config {
cfg.Service.ClientIDPref = "foo"
return cfg
}
func configWithInvalidFlushCacheDomain(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Service.CacheFlushDomains = make([]string, 257)
for i := range cfg.Service.CacheFlushDomains {
cfg.Service.CacheFlushDomains[i] = fmt.Sprintf("%d.com", i)
}
return cfg
}

View File

@@ -157,6 +157,13 @@ stale cached records (regardless of their TTLs) until upstream comes online.
- Required: no
- Default: false
### cache_flush_domains
When `ctrld` receives query with domain name in `cache_flush_domains`, the local cache will be discarded
before serving the query.
- Type: array of strings
- Required: no
### max_concurrent_requests
The number of concurrent requests that will be handled, must be a non-negative integer.
Tweaking this value depends on the capacity of your system.
@@ -220,7 +227,7 @@ DHCP leases file format.
- Type: string
- Required: no
- Valid values: `dnsmasq`, `isc-dhcp`
- Valid values: `dnsmasq`, `isc-dhcp`, `kea-dhcp4`
- Default: ""
### client_id_preference
@@ -531,7 +538,7 @@ And within each policy, the rules are processed from top to bottom.
### failover_rcodes
For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`.
- Type: array of string
- Type: array of strings
- Required: no
- Default: []
-

18
doh.go
View File

@@ -60,17 +60,10 @@ func init() {
}
}
// TODO: use sync.OnceValue when upgrading to go1.21
var xCdOsValueOnce sync.Once
var xCdOsValue string
func dohOsHeaderValue() string {
xCdOsValueOnce.Do(func() {
oi := osinfo.New()
xCdOsValue = strings.Join([]string{EncodeOsNameMap[runtime.GOOS], EncodeArchNameMap[runtime.GOARCH], oi.Dist}, "-")
})
return xCdOsValue
}
var dohOsHeaderValue = sync.OnceValue(func() string {
oi := osinfo.New()
return strings.Join([]string{EncodeOsNameMap[runtime.GOOS], EncodeArchNameMap[runtime.GOARCH], oi.Dist}, "-")
})()
func newDohResolver(uc *UpstreamConfig) *dohResolver {
r := &dohResolver{
@@ -172,7 +165,6 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
// newControlDHeaders returns DoH/Doh3 HTTP request headers for ControlD upstream.
func newControlDHeaders(ci *ClientInfo) http.Header {
header := make(http.Header)
header.Set(dohOsHeader, dohOsHeaderValue())
if ci.Mac != "" {
header.Set(dohMacHeader, ci.Mac)
}
@@ -183,7 +175,7 @@ func newControlDHeaders(ci *ClientInfo) http.Header {
header.Set(dohHostHeader, ci.Hostname)
}
if ci.Self {
header.Set(dohOsHeader, dohOsHeaderValue())
header.Set(dohOsHeader, dohOsHeaderValue)
}
switch ci.ClientIDPref {
case "mac":

View File

@@ -6,7 +6,7 @@ import (
)
func Test_dohOsHeaderValue(t *testing.T) {
val := dohOsHeaderValue()
val := dohOsHeaderValue
if val == "" {
t.Fatalf("empty %s", dohOsHeader)
}

17
go.mod
View File

@@ -3,6 +3,7 @@ module github.com/Control-D-Inc/ctrld
go 1.21
require (
github.com/Masterminds/semver v1.5.0
github.com/coreos/go-systemd/v22 v22.5.0
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf
github.com/frankban/quicktest v1.14.5
@@ -17,26 +18,28 @@ require (
github.com/kardianos/service v1.2.1
github.com/mdlayher/ndp v1.0.1
github.com/miekg/dns v1.1.55
github.com/minio/selfupdate v0.6.0
github.com/olekukonko/tablewriter v0.0.5
github.com/pelletier/go-toml/v2 v2.0.8
github.com/prometheus/client_golang v1.15.1
github.com/prometheus/client_model v0.4.0
github.com/prometheus/prom2json v1.3.3
github.com/quic-go/quic-go v0.41.0
github.com/quic-go/quic-go v0.42.0
github.com/rs/zerolog v1.28.0
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.16.0
github.com/stretchr/testify v1.8.3
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/net v0.17.0
golang.org/x/net v0.23.0
golang.org/x/sync v0.2.0
golang.org/x/sys v0.13.0
golang.org/x/sys v0.18.0
golang.zx2c4.com/wireguard/windows v0.5.3
tailscale.com v1.44.0
)
require (
aead.dev/minisign v0.2.0 // indirect
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
@@ -77,14 +80,14 @@ require (
github.com/subosito/gotenv v1.4.2 // indirect
github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
go.uber.org/mock v0.3.0 // indirect
go.uber.org/mock v0.4.0 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.9.1 // indirect
google.golang.org/protobuf v1.30.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

41
go.sum
View File

@@ -1,3 +1,5 @@
aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk=
aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
@@ -38,6 +40,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik=
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w=
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 h1:Kk6a4nehpJ3UuJRqlA3JxYxBZEqCeOmATOvrbT4p9RA=
@@ -222,6 +226,8 @@ github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU=
github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
@@ -253,8 +259,8 @@ github.com/prometheus/prom2json v1.3.3 h1:IYfSMiZ7sSOfliBoo89PcufjWO4eAR0gznGcET
github.com/prometheus/prom2json v1.3.3/go.mod h1:Pv4yIPktEkK7btWsrUTWDDDrnpUrAELaOCj+oFwlgmc=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k=
github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA=
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
@@ -310,8 +316,8 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8=
go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -319,11 +325,13 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -394,8 +402,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -429,6 +437,7 @@ golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -454,6 +463,7 @@ golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210225134936-a50acf3fe073/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -465,8 +475,9 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -476,11 +487,13 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
@@ -626,8 +639,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

View File

@@ -14,6 +14,11 @@ import (
"github.com/Control-D-Inc/ctrld/internal/controld"
)
const (
ipV4Loopback = "127.0.0.1"
ipv6Loopback = "::1"
)
// IpResolver is the interface for retrieving IP from Mac.
type IpResolver interface {
fmt.Stringer
@@ -224,6 +229,7 @@ func (t *Table) init() {
cancel()
}()
go t.ndp.listen(ctx)
go t.ndp.subscribe(ctx)
}
// PTR lookup.
if t.discoverPTR() {
@@ -321,6 +327,16 @@ func (t *Table) LookupRFC1918IPv4(mac string) string {
return ""
}
// LocalHostname returns the localhost hostname associated with loopback IP.
func (t *Table) LocalHostname() string {
for _, ip := range []string{ipV4Loopback, ipv6Loopback} {
if name := t.LookupHostname(ip, ""); name != "" {
return name
}
}
return ""
}
type macEntry struct {
mac string
src string

View File

@@ -353,8 +353,8 @@ func (d *dhcp) addSelf() {
return
}
hostname = normalizeHostname(hostname)
d.ip2name.Store("127.0.0.1", hostname)
d.ip2name.Store("::1", hostname)
d.ip2name.Store(ipV4Loopback, hostname)
d.ip2name.Store(ipv6Loopback, hostname)
found := false
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
mac := i.HardwareAddr.String()
@@ -375,15 +375,17 @@ func (d *dhcp) addSelf() {
d.mac.Store(ip.String(), mac)
d.ip.Store(mac, ip.String())
if ip.To4() != nil {
d.mac.Store("127.0.0.1", mac)
d.mac.Store(ipV4Loopback, mac)
} else {
d.mac.Store("::1", mac)
d.mac.Store(ipv6Loopback, mac)
}
d.mac2name.Store(mac, hostname)
d.ip2name.Store(ip.String(), hostname)
// If we have self IP set, and this IP is it, use this IP only.
if ip.String() == d.selfIP {
found = true
d.mac.Store(ipV4Loopback, mac)
d.mac.Store(ipv6Loopback, mac)
}
}
})

View File

@@ -95,7 +95,7 @@ func (hf *hostsFile) LookupHostnameByIP(ip string) string {
hf.mu.Lock()
defer hf.mu.Unlock()
if names := hf.m[ip]; len(names) > 0 {
isLoopback := ip == "127.0.0.1" || ip == "::1"
isLoopback := ip == ipV4Loopback || ip == ipv6Loopback
for _, hostname := range names {
name := normalizeHostname(hostname)
// Ignoring ipv4/ipv6 loopback entry.

View File

@@ -15,6 +15,7 @@ import (
"github.com/mdlayher/ndp"
"github.com/Control-D-Inc/ctrld"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
// ndpDiscover provides client discovery functionality using NDP protocol.
@@ -69,15 +70,45 @@ func (nd *ndpDiscover) List() []string {
return ips
}
// saveInfo saves ip and mac info to mapping table.
func (nd *ndpDiscover) saveInfo(ip, mac string) {
ip = normalizeIP(ip)
// Store ip => map mapping,
nd.mac.Store(ip, mac)
// Do not store mac => ip mapping if new ip is a link local unicast.
if ctrldnet.IsLinkLocalUnicastIPv6(ip) {
return
}
// If there is old ip => mac mapping, delete it.
if old, existed := nd.ip.Load(mac); existed {
oldIP := old.(string)
if oldIP != ip {
nd.mac.Delete(oldIP)
}
}
// Store mac => ip mapping.
nd.ip.Store(mac, ip)
}
// listen listens on ipv6 link local for Neighbor Solicitation message
// to update new neighbors information to ndp table.
func (nd *ndpDiscover) listen(ctx context.Context) {
ifi, err := firstInterfaceWithV6LinkLocal()
ifis, err := allInterfacesWithV6LinkLocal()
if err != nil {
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("failed to find valid ipv6")
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("failed to find valid ipv6 interfaces")
return
}
c, ip, err := ndp.Listen(ifi, ndp.LinkLocal)
for _, ifi := range ifis {
go func(ifi *net.Interface) {
nd.listenOnInterface(ctx, ifi)
}(ifi)
}
}
func (nd *ndpDiscover) listenOnInterface(ctx context.Context, ifi *net.Interface) {
c, ip, err := ndp.Listen(ifi, ndp.Unspecified)
if err != nil {
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("ndp listen failed")
return
@@ -111,8 +142,7 @@ func (nd *ndpDiscover) listen(ctx context.Context) {
for _, opt := range am.Options {
if lla, ok := opt.(*ndp.LinkLayerAddress); ok {
mac := lla.Addr.String()
nd.mac.Store(fromIP, mac)
nd.ip.Store(mac, fromIP)
nd.saveInfo(fromIP, mac)
}
}
}
@@ -127,8 +157,7 @@ func (nd *ndpDiscover) scanWindows(r io.Reader) {
continue
}
if mac := parseMAC(fields[1]); mac != "" {
nd.mac.Store(fields[0], mac)
nd.ip.Store(mac, fields[0])
nd.saveInfo(fields[0], mac)
}
}
}
@@ -147,8 +176,7 @@ func (nd *ndpDiscover) scanUnix(r io.Reader) {
if idx := strings.IndexByte(ip, '%'); idx != -1 {
ip = ip[:idx]
}
nd.mac.Store(ip, mac)
nd.ip.Store(mac, ip)
nd.saveInfo(ip, mac)
}
}
}
@@ -183,14 +211,15 @@ func parseMAC(mac string) string {
return hw.String()
}
// firstInterfaceWithV6LinkLocal returns the first interface which is capable of using NDP.
func firstInterfaceWithV6LinkLocal() (*net.Interface, error) {
// allInterfacesWithV6LinkLocal returns all interfaces which is capable of using NDP.
func allInterfacesWithV6LinkLocal() ([]*net.Interface, error) {
ifis, err := net.Interfaces()
if err != nil {
return nil, err
}
res := make([]*net.Interface, 0, len(ifis))
for _, ifi := range ifis {
ifi := ifi
// Skip if iface is down/loopback/non-multicast.
if ifi.Flags&net.FlagUp == 0 || ifi.Flags&net.FlagLoopback != 0 || ifi.Flags&net.FlagMulticast == 0 {
continue
@@ -211,9 +240,10 @@ func firstInterfaceWithV6LinkLocal() (*net.Interface, error) {
return nil, fmt.Errorf("invalid ip address: %s", ipNet.String())
}
if ip.Is6() && !ip.Is4In6() {
return &ifi, nil
res = append(res, &ifi)
break
}
}
}
return nil, errors.New("no interface can be used")
return res, nil
}

View File

@@ -1,7 +1,10 @@
package clientinfo
import (
"context"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"github.com/Control-D-Inc/ctrld"
)
@@ -15,10 +18,47 @@ func (nd *ndpDiscover) scan() {
}
for _, n := range neighs {
// Skipping non-reachable neighbors.
if n.State&netlink.NUD_REACHABLE == 0 {
continue
}
ip := n.IP.String()
mac := n.HardwareAddr.String()
nd.mac.Store(ip, mac)
nd.ip.Store(mac, ip)
nd.saveInfo(ip, mac)
}
}
// subscribe watches NDP table changes and update new information to local table.
func (nd *ndpDiscover) subscribe(ctx context.Context) {
ch := make(chan netlink.NeighUpdate)
done := make(chan struct{})
defer close(done)
if err := netlink.NeighSubscribe(ch, done); err != nil {
ctrld.ProxyLogger.Load().Err(err).Msg("could not perform neighbor subscribing")
return
}
for {
select {
case <-ctx.Done():
return
case nu := <-ch:
if nu.Family != netlink.FAMILY_V6 {
continue
}
ip := normalizeIP(nu.IP.String())
if nu.Type == unix.RTM_DELNEIGH {
ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor: %s", ip)
nd.mac.Delete(ip)
continue
}
mac := nu.HardwareAddr.String()
switch nu.State {
case netlink.NUD_REACHABLE:
nd.saveInfo(ip, mac)
case netlink.NUD_FAILED:
ctrld.ProxyLogger.Load().Debug().Msgf("removing NDP neighbor with failed state: %s", ip)
nd.mac.Delete(ip)
}
}
}
}

View File

@@ -4,6 +4,7 @@ package clientinfo
import (
"bytes"
"context"
"os/exec"
"runtime"
@@ -29,3 +30,7 @@ func (nd *ndpDiscover) scan() {
nd.scanUnix(bytes.NewReader(data))
}
}
// subscribe watches NDP table changes and update new information to local table.
// This is a stub method, and only works on Linux at this moment.
func (nd *ndpDiscover) subscribe(ctx context.Context) {}

View File

@@ -45,20 +45,22 @@ ff02::c 33-33-00-00-00-0c Permanent
nd.scanWindows(r)
count := 0
expectedCount := 6
nd.mac.Range(func(key, value any) bool {
count++
return true
})
if count != 6 {
t.Errorf("unexpected count, want 6, got: %d", count)
if count != expectedCount {
t.Errorf("unexpected count, want %d, got: %d", expectedCount, count)
}
count = 0
expectedCount = 4
nd.ip.Range(func(key, value any) bool {
count++
return true
})
if count != 5 {
t.Errorf("unexpected count, want 5, got: %d", count)
if count != expectedCount {
t.Errorf("unexpected count, want %d, got: %d", expectedCount, count)
}
}

View File

@@ -12,6 +12,7 @@ import (
type Cacher interface {
Get(Key) *Value
Add(Key, *Value)
Purge()
}
// Key is the caching key for DNS message.
@@ -34,15 +35,22 @@ type LRUCache struct {
cacher *lru.ARCCache[Key, *Value]
}
// Get looks up key's value from cache.
func (l *LRUCache) Get(key Key) *Value {
v, _ := l.cacher.Get(key)
return v
}
// Add adds a value to cache.
func (l *LRUCache) Add(key Key, value *Value) {
l.cacher.Add(key, value)
}
// Purge clears the cache.
func (l *LRUCache) Purge() {
l.cacher.Purge()
}
// NewLRUCache creates a new LRUCache instance with given size.
func NewLRUCache(size int) (*LRUCache, error) {
cacher, err := lru.NewARC[Key, *Value](size)

View File

@@ -115,6 +115,15 @@ func IsIPv6(ip string) bool {
return parsedIP != nil && parsedIP.To4() == nil && parsedIP.To16() != nil
}
// IsLinkLocalUnicastIPv6 checks if the provided IP is a link local unicast v6 address.
func IsLinkLocalUnicastIPv6(ip string) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil || parsedIP.To4() != nil || parsedIP.To16() == nil {
return false
}
return parsedIP.To16().IsLinkLocalUnicast()
}
type parallelDialerResult struct {
conn net.Conn
err error

View File

@@ -10,6 +10,8 @@ import (
"github.com/Control-D-Inc/ctrld"
)
const CtrldMarker = `# GENERATED BY ctrld - DO NOT MODIFY`
const ConfigContentTmpl = `# GENERATED BY ctrld - DO NOT MODIFY
no-resolv
{{- range .Upstreams}}

View File

@@ -0,0 +1,22 @@
package netgear
const openWrtScript = `#!/bin/sh /etc/rc.common
USE_PROCD=1
# After dnsmasq starts
START=61
# Before network stops
STOP=89
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
name="{{.Name}}"
pid_file="/var/run/${name}.pid"
start_service() {
echo "Starting ${name}"
procd_open_instance
procd_set_param command ${cmd}
procd_set_param respawn # respawn automatically if something died
procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop
procd_close_instance
echo "${name} has been started"
}
`

View File

@@ -0,0 +1,220 @@
package netgear
import (
"bufio"
"bytes"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
"github.com/Control-D-Inc/ctrld/internal/router/nvram"
)
const (
Name = "netgear_orbi_voxel"
netgearOrbiVoxelDNSMasqConfigPath = "/etc/dnsmasq.conf"
netgearOrbiVoxelHomedir = "/mnt/bitdefender"
netgearOrbiVoxelStartupScript = "/mnt/bitdefender/rc.user"
netgearOrbiVoxelStartupScriptBackup = "/mnt/bitdefender/rc.user.bak"
netgearOrbiVoxelStartupScriptMarker = "\n# GENERATED BY ctrld"
)
var nvramKvMap = map[string]string{
"dns_hijack": "0", // Disable dns hijacking
}
type NetgearOrbiVoxel struct {
cfg *ctrld.Config
}
// New returns a router.Router for configuring/setup/run ctrld on ddwrt routers.
func New(cfg *ctrld.Config) *NetgearOrbiVoxel {
return &NetgearOrbiVoxel{cfg: cfg}
}
func (d *NetgearOrbiVoxel) ConfigureService(svc *service.Config) error {
if err := d.checkInstalledDir(); err != nil {
return err
}
svc.Option["SysvScript"] = openWrtScript
return nil
}
func (d *NetgearOrbiVoxel) Install(_ *service.Config) error {
// Ignoring error here at this moment is ok, since everything will be wiped out on reboot.
_ = exec.Command("/etc/init.d/ctrld", "enable").Run()
if err := d.checkInstalledDir(); err != nil {
return err
}
if err := backupVoxelStartupScript(); err != nil {
return fmt.Errorf("backup startup script: %w", err)
}
if err := writeVoxelStartupScript(); err != nil {
return fmt.Errorf("writing startup script: %w", err)
}
return nil
}
func (d *NetgearOrbiVoxel) Uninstall(_ *service.Config) error {
if err := os.Remove(netgearOrbiVoxelStartupScript); err != nil && !os.IsNotExist(err) {
return err
}
err := os.Rename(netgearOrbiVoxelStartupScriptBackup, netgearOrbiVoxelStartupScript)
if err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
func (d *NetgearOrbiVoxel) PreRun() error {
return nil
}
func (d *NetgearOrbiVoxel) Setup() error {
if d.cfg.FirstListener().IsDirectDnsListener() {
return nil
}
// Already setup.
if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" {
return nil
}
data, err := dnsmasq.ConfTmplWithCacheDisabled(dnsmasq.ConfigContentTmpl, d.cfg, false)
if err != nil {
return err
}
currentConfig, _ := os.ReadFile(netgearOrbiVoxelDNSMasqConfigPath)
configContent := append(currentConfig, data...)
if err := os.WriteFile(netgearOrbiVoxelDNSMasqConfigPath, configContent, 0600); err != nil {
return err
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
return err
}
if err := nvram.SetKV(nvramKvMap, nvram.CtrldSetupKey); err != nil {
return err
}
return nil
}
func (d *NetgearOrbiVoxel) Cleanup() error {
if d.cfg.FirstListener().IsDirectDnsListener() {
return nil
}
if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val != "1" {
return nil // was restored, nothing to do.
}
// Restore old configs.
if err := nvram.Restore(nvramKvMap, nvram.CtrldSetupKey); err != nil {
return err
}
// Restore dnsmasq config.
if err := restoreDnsmasqConf(); err != nil {
return err
}
// Restart dnsmasq service.
if err := restartDNSMasq(); err != nil {
return err
}
return nil
}
// checkInstalledDir checks that ctrld binary was installed in the correct directory.
func (d *NetgearOrbiVoxel) checkInstalledDir() error {
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("checkHomeDir: failed to get binary path %w", err)
}
if !strings.HasSuffix(filepath.Dir(exePath), netgearOrbiVoxelHomedir) {
return fmt.Errorf("checkHomeDir: could not install service outside %s", netgearOrbiVoxelHomedir)
}
return nil
}
// backupVoxelStartupScript creates a backup of original startup script if existed.
func backupVoxelStartupScript() error {
// Do nothing if the startup script was modified by ctrld.
script, _ := os.ReadFile(netgearOrbiVoxelStartupScript)
if bytes.Contains(script, []byte(netgearOrbiVoxelStartupScriptMarker)) {
return nil
}
err := os.Rename(netgearOrbiVoxelStartupScript, netgearOrbiVoxelStartupScriptBackup)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("backupVoxelStartupScript: %w", err)
}
return nil
}
// writeVoxelStartupScript writes startup script to re-install ctrld upon reboot.
// See: https://github.com/SVoxel/ORBI-RBK50/pull/7
func writeVoxelStartupScript() error {
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("configure service: failed to get binary path %w", err)
}
// This is called when "ctrld start ..." runs, so recording
// the same command line arguments to use in startup script.
argStr := strings.Join(os.Args[1:], " ")
script, _ := os.ReadFile(netgearOrbiVoxelStartupScriptBackup)
script = append(script, fmt.Sprintf("%s\n%q %s\n", netgearOrbiVoxelStartupScriptMarker, exe, argStr)...)
f, err := os.Create(netgearOrbiVoxelStartupScript)
if err != nil {
return fmt.Errorf("failed to create startup script: %w", err)
}
defer f.Close()
if _, err := f.Write(script); err != nil {
return fmt.Errorf("failed to write startup script: %w", err)
}
if err := f.Close(); err != nil {
return fmt.Errorf("failed to save startup script: %w", err)
}
return nil
}
// restoreDnsmasqConf restores original dnsmasq configuration.
func restoreDnsmasqConf() error {
f, err := os.Open(netgearOrbiVoxelDNSMasqConfigPath)
if err != nil {
return err
}
defer f.Close()
var bs []byte
buf := bytes.NewBuffer(bs)
removed := false
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if line == dnsmasq.CtrldMarker {
removed = true
}
if !removed {
_, err := buf.WriteString(line + "\n")
if err != nil {
return err
}
}
}
return os.WriteFile(netgearOrbiVoxelDNSMasqConfigPath, buf.Bytes(), 0644)
}
func restartDNSMasq() error {
if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil {
return fmt.Errorf("restartDNSMasq: %s, %w", string(out), err)
}
return nil
}

View File

@@ -0,0 +1,40 @@
package router
import (
"encoding/xml"
"os"
)
// Config represents /conf/config.xml file found on pfsense/opnsense.
type Config struct {
PfsenseUnbound *string `xml:"unbound>enable,omitempty"`
OPNsenseUnbound *string `xml:"OPNsense>unboundplus>general>enabled,omitempty"`
Dnsmasq *string `xml:"dnsmasq>enable,omitempty"`
}
// DnsmasqEnabled reports whether dnsmasq is enabled.
func (c *Config) DnsmasqEnabled() bool {
if isPfsense() { // pfsense only set the attribute if dnsmasq is enabled.
return c.Dnsmasq != nil
}
return c.Dnsmasq != nil && *c.Dnsmasq == "1"
}
// UnboundEnabled reports whether unbound is enabled.
func (c *Config) UnboundEnabled() bool {
if isPfsense() { // pfsense only set the attribute if unbound is enabled.
return c.PfsenseUnbound != nil
}
return c.OPNsenseUnbound != nil && *c.OPNsenseUnbound == "1"
}
// currentConfig does unmarshalling /conf/config.xml file,
// return the corresponding *Config represent it.
func currentConfig() (*Config, error) {
buf, _ := os.ReadFile("/conf/config.xml")
c := Config{}
if err := xml.Unmarshal(buf, &c); err != nil {
return nil, err
}
return &c, nil
}

View File

@@ -111,8 +111,16 @@ func (or *osRouter) Setup() error {
func (or *osRouter) Cleanup() error {
if or.cdMode {
_ = exec.Command(unboundRcPath, "onerestart").Run()
_ = exec.Command(dnsmasqRcPath, "onerestart").Run()
c, err := currentConfig()
if err != nil {
return err
}
if c.UnboundEnabled() {
_ = exec.Command(unboundRcPath, "onerestart").Run()
}
if c.DnsmasqEnabled() {
_ = exec.Command(dnsmasqRcPath, "onerestart").Run()
}
}
return nil
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/Control-D-Inc/ctrld/internal/router/edgeos"
"github.com/Control-D-Inc/ctrld/internal/router/firewalla"
"github.com/Control-D-Inc/ctrld/internal/router/merlin"
netgear "github.com/Control-D-Inc/ctrld/internal/router/netgear_orbi_voxel"
"github.com/Control-D-Inc/ctrld/internal/router/openwrt"
"github.com/Control-D-Inc/ctrld/internal/router/synology"
"github.com/Control-D-Inc/ctrld/internal/router/tomato"
@@ -66,10 +67,17 @@ func New(cfg *ctrld.Config, cdMode bool) Router {
return tomato.New(cfg)
case firewalla.Name:
return firewalla.New(cfg)
case netgear.Name:
return netgear.New(cfg)
}
return newOsRouter(cfg, cdMode)
}
// IsNetGearOrbi reports whether the router is a Netgear Orbi router.
func IsNetGearOrbi() bool {
return Name() == netgear.Name
}
// IsGLiNet reports whether the router is an GL.iNet router.
func IsGLiNet() bool {
if Name() != openwrt.Name {
@@ -145,7 +153,7 @@ func LocalResolverIP() string {
// HomeDir returns the home directory of ctrld on current router.
func HomeDir() (string, error) {
switch Name() {
case ddwrt.Name, merlin.Name, tomato.Name:
case ddwrt.Name, firewalla.Name, merlin.Name, netgear.Name, tomato.Name:
exe, err := os.Executable()
if err != nil {
return "", err
@@ -198,6 +206,9 @@ func distroName() string {
case bytes.HasPrefix(unameO(), []byte("ASUSWRT-Merlin")):
return merlin.Name
case haveFile("/etc/openwrt_version"):
if haveFile("/bin/config") { // TODO: is there any more reliable way?
return netgear.Name
}
return openwrt.Name
case isUbios():
return ubios.Name