Merge pull request #41 from Control-D-Inc/release-branch-v1.2.0

Release branch v1.2.0
This commit is contained in:
Yegor S
2023-05-15 21:48:05 -04:00
committed by GitHub
56 changed files with 7129 additions and 655 deletions

4
.gitignore vendored
View File

@@ -1,3 +1,5 @@
dist/
gon.hcl
/Build
.DS_Store

View File

@@ -28,6 +28,7 @@ All DNS protocols are supported, including:
- Windows (386, amd64, arm)
- Mac (amd64, arm64)
- Linux (386, amd64, arm, mips)
- Common routers (See Router Mode below)
## Download
Download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section.
@@ -138,6 +139,19 @@ Once you run the above command, the following things will happen:
- Your default network interface will be updated to use the listener started by the service
- All OS DNS queries will be sent to the listener
### Router Mode
You can run `ctrld` on any supported router, which will function similarly to the Service Mode mentioned above. The list of supported routers and firmware includes:
- OpenWRT
- DD-WRT
- Asus Merlin
- GL.iNet
- Ubiquiti
In order to start `ctrld` as a DNS provider, simply run `./ctrld setup auto` command. You can optionally supply the `--cd` flag on order to configure a specific Control D device on the router.
In this mode, and when Control D upstreams are used, the router will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
## Configuration
See [Configuration Docs](docs/config.md).

11
client_info.go Normal file
View File

@@ -0,0 +1,11 @@
package ctrld
// ClientInfoCtxKey is the context key to store client info.
type ClientInfoCtxKey struct{}
// ClientInfo represents ctrld's clients information.
type ClientInfo struct {
Mac string
IP string
Hostname string
}

View File

@@ -3,10 +3,10 @@ package main
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"log"
"net"
"net/netip"
"os"
@@ -18,9 +18,8 @@ import (
"sync"
"time"
"github.com/fsnotify/fsnotify"
"github.com/cuonglm/osinfo"
"github.com/fsnotify/fsnotify"
"github.com/go-playground/validator/v10"
"github.com/kardianos/service"
"github.com/miekg/dns"
@@ -31,8 +30,10 @@ import (
"tailscale.com/net/interfaces"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/certs"
"github.com/Control-D-Inc/ctrld/internal/controld"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
"github.com/Control-D-Inc/ctrld/internal/router"
)
const selfCheckFQDN = "verify.controld.com"
@@ -46,6 +47,7 @@ var (
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
defaultConfigWritten = false
defaultConfigFile = "ctrld.toml"
rootCertPool *x509.CertPool
)
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"}
@@ -72,10 +74,13 @@ var rootCmd = &cobra.Command{
Use: "ctrld",
Short: strings.TrimLeft(rootShortDesc, "\n"),
Version: curVersion(),
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
}
func curVersion() string {
if version != "dev" {
if version != "dev" && !strings.HasPrefix(version, "v") {
version = "v" + version
}
if len(commit) > 7 {
@@ -96,6 +101,13 @@ func initCLI() {
"v",
`verbose log output, "-v" basic logging, "-vv" debug level logging`,
)
rootCmd.PersistentFlags().BoolVarP(
&silent,
"silent",
"s",
false,
`do not write any log output`,
)
rootCmd.SetHelpCommand(&cobra.Command{Hidden: true})
rootCmd.CompletionOptions.HiddenDefaultCmd = true
@@ -103,9 +115,12 @@ func initCLI() {
Use: "run",
Short: "Run the DNS proxy server",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: func(cmd *cobra.Command, args []string) {
if daemon && runtime.GOOS == "windows" {
log.Fatal("Cannot run in daemon mode. Please install a Windows service.")
mainLog.Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
}
waitCh := make(chan struct{})
@@ -122,6 +137,7 @@ func initCLI() {
if err != nil {
mainLog.Fatal().Err(err).Msg("failed create new service")
}
s = newService(s)
serviceLogger, err := s.Logger(nil)
if err != nil {
mainLog.Error().Err(err).Msg("failed to get service logger")
@@ -138,43 +154,36 @@ func initCLI() {
}
noConfigStart := isNoConfigStart(cmd)
writeDefaultConfig := !noConfigStart && configBase64 == ""
configs := []struct {
name string
written bool
}{
// For compatibility, we check for config.toml first, but only read it if exists.
{"config", false},
{"ctrld", writeDefaultConfig},
}
for _, config := range configs {
ctrld.SetConfigName(v, config.name)
v.SetConfigFile(configPath)
if readConfigFile(config.written) {
break
}
}
tryReadingConfig(writeDefaultConfig)
readBase64Config()
readBase64Config(configBase64)
processNoConfigFlags(noConfigStart)
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
}
log.Printf("starting ctrld %s\n", curVersion())
mainLog.Info().Msgf("starting ctrld %s", curVersion())
oi := osinfo.New()
log.Printf("os: %s\n", oi.String())
mainLog.Info().Msgf("os: %s", oi.String())
// Wait for network up.
if !ctrldnet.Up() {
log.Fatal("network is not up yet")
mainLog.Fatal().Msg("network is not up yet")
}
processLogAndCacheFlags()
// Log config do not have thing to validate, so it's safe to init log here,
// so it's able to log information in processCDFlags.
initLogging()
if setupRouter {
if err := router.PreStart(); err != nil {
mainLog.Fatal().Err(err).Msg("failed to perform router pre-start check")
}
}
processCDFlags()
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
log.Fatalf("invalid config: %v", err)
mainLog.Fatal().Msgf("invalid config: %v", err)
}
initCache()
@@ -200,6 +209,24 @@ func initCLI() {
os.Exit(0)
}
if setupRouter {
switch platform := router.Name(); {
case platform == router.DDWrt:
rootCertPool = certs.CACertPool()
fallthrough
case platform != "":
mainLog.Debug().Msg("Router setup")
err := router.Configure(&cfg)
if errors.Is(err, router.ErrNotSupported) {
unsupportedPlatformHelp(cmd)
os.Exit(1)
}
if err != nil {
mainLog.Fatal().Err(err).Msg("failed to configure router")
}
}
}
close(waitCh)
<-stopCh
},
@@ -218,14 +245,19 @@ func initCLI() {
_ = runCmd.Flags().MarkHidden("homedir")
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
_ = runCmd.Flags().MarkHidden("iface")
runCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
_ = runCmd.Flags().MarkHidden("router")
rootCmd.AddCommand(runCmd)
startCmd := &cobra.Command{
PreRun: checkHasElevatedPrivilege,
Use: "start",
Short: "Install and start the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "start",
Short: "Install and start the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
sc := &service.Config{}
*sc = *svcConfig
@@ -235,6 +267,9 @@ func initCLI() {
}
setDependencies(sc)
sc.Arguments = append([]string{"run"}, osArgs...)
if err := router.ConfigureService(sc); err != nil {
mainLog.Fatal().Err(err).Msg("failed to configure service on router")
}
// No config path, generating config in HOME directory.
noConfigStart := isNoConfigStart(cmd)
@@ -242,18 +277,18 @@ func initCLI() {
if configPath != "" {
v.SetConfigFile(configPath)
}
if dir, err := os.UserHomeDir(); err == nil {
if dir, err := userHomeDir(); err == nil {
setWorkingDirectory(sc, dir)
if configPath == "" && writeDefaultConfig {
defaultConfigFile = filepath.Join(dir, defaultConfigFile)
v.SetConfigFile(defaultConfigFile)
}
sc.Arguments = append(sc.Arguments, "--homedir="+dir)
}
readConfigFile(writeDefaultConfig && cdUID == "")
tryReadingConfig(writeDefaultConfig)
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
}
logPath := cfg.Service.LogPath
@@ -262,20 +297,21 @@ func initCLI() {
cfg.Service.LogPath = logPath
processCDFlags()
// On Windows, the service will be run as SYSTEM, so if ctrld start as Admin,
// the user home dir is different, so pass specific arguments that relevant here.
if runtime.GOOS == "windows" {
if configPath == "" {
sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile)
}
// Explicitly passing config, so on system where home directory could not be obtained,
// or sub-process env is different with the parent, we still behave correctly and use
// the expected config file.
if configPath == "" {
sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile)
}
prog := &prog{}
s, err := service.New(prog, sc)
if err != nil {
stderrMsg(err.Error())
mainLog.Error().Msg(err.Error())
return
}
s = newService(s)
tasks := []task{
{s.Stop, false},
{s.Uninstall, false},
@@ -283,7 +319,11 @@ func initCLI() {
{s.Start, true},
}
if doTasks(tasks) {
status, err := s.Status()
if err := router.PostInstall(); err != nil {
mainLog.Warn().Err(err).Msg("post installation failed, please check system/service log for details error")
return
}
status, err := serviceStatus(s)
if err != nil {
mainLog.Warn().Err(err).Msg("could not get service status")
return
@@ -292,7 +332,7 @@ func initCLI() {
status = selfCheckStatus(status)
switch status {
case service.StatusRunning:
mainLog.Info().Msg("Service started")
mainLog.Notice().Msg("Service started")
default:
mainLog.Error().Msg("Service did not start, please check system/service log for details error")
if runtime.GOOS == "linux" {
@@ -315,42 +355,52 @@ func initCLI() {
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
_ = startCmd.Flags().MarkHidden("router")
stopCmd := &cobra.Command{
PreRun: checkHasElevatedPrivilege,
Use: "stop",
Short: "Stop the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "stop",
Short: "Stop the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
prog := &prog{}
s, err := service.New(prog, svcConfig)
if err != nil {
stderrMsg(err.Error())
mainLog.Error().Msg(err.Error())
return
}
s = newService(s)
initLogging()
if doTasks([]task{{s.Stop, true}}) {
prog.resetDNS()
mainLog.Info().Msg("Service stopped")
mainLog.Notice().Msg("Service stopped")
}
},
}
stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
restartCmd := &cobra.Command{
PreRun: checkHasElevatedPrivilege,
Use: "restart",
Short: "Restart the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "restart",
Short: "Restart the ctrld service",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
s, err := service.New(&prog{}, svcConfig)
if err != nil {
stderrMsg(err.Error())
mainLog.Error().Msg(err.Error())
return
}
s = newService(s)
initLogging()
if doTasks([]task{{s.Restart, true}}) {
stdoutMsg("Service restarted")
mainLog.Notice().Msg("Service restarted")
}
},
}
@@ -359,39 +409,49 @@ func initCLI() {
Use: "status",
Short: "Show status of the ctrld service",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: func(cmd *cobra.Command, args []string) {
s, err := service.New(&prog{}, svcConfig)
if err != nil {
stderrMsg(err.Error())
mainLog.Error().Msg(err.Error())
return
}
status, err := s.Status()
s = newService(s)
status, err := serviceStatus(s)
if err != nil {
stderrMsg(err.Error())
mainLog.Error().Msg(err.Error())
os.Exit(1)
}
switch status {
case service.StatusUnknown:
stdoutMsg("Unknown status")
mainLog.Notice().Msg("Unknown status")
os.Exit(2)
case service.StatusRunning:
stdoutMsg("Service is running")
mainLog.Notice().Msg("Service is running")
os.Exit(0)
case service.StatusStopped:
stdoutMsg("Service is stopped")
mainLog.Notice().Msg("Service is stopped")
os.Exit(1)
}
},
}
if runtime.GOOS == "darwin" {
// On darwin, running status command without privileges may return wrong information.
statusCmd.PreRun = checkHasElevatedPrivilege
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
}
}
uninstallCmd := &cobra.Command{
PreRun: checkHasElevatedPrivilege,
Use: "uninstall",
Short: "Stop and uninstall the ctrld service",
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "uninstall",
Short: "Stop and uninstall the ctrld service",
Long: `Stop and uninstall the ctrld service.
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
@@ -400,7 +460,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
prog := &prog{}
s, err := service.New(prog, svcConfig)
if err != nil {
stderrMsg(err.Error())
mainLog.Error().Msg(err.Error())
return
}
tasks := []task{
@@ -413,7 +473,11 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
iface = "auto"
}
prog.resetDNS()
mainLog.Info().Msg("Service uninstalled")
mainLog.Debug().Msg("Router cleanup")
if err := router.Cleanup(); err != nil {
mainLog.Warn().Err(err).Msg("could not cleanup router")
}
mainLog.Notice().Msg("Service uninstalled")
return
}
},
@@ -424,6 +488,9 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
Use: "list",
Short: "List network interfaces of the host",
Args: cobra.NoArgs,
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: func(cmd *cobra.Command, args []string) {
err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
fmt.Printf("Index : %d\n", i.Index)
@@ -446,7 +513,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
println()
})
if err != nil {
stderrMsg(err.Error())
mainLog.Error().Msg(err.Error())
}
},
}
@@ -481,9 +548,12 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
serviceCmd.AddCommand(interfacesCmd)
rootCmd.AddCommand(serviceCmd)
startCmdAlias := &cobra.Command{
PreRun: checkHasElevatedPrivilege,
Use: "start",
Short: "Quick start service and configure DNS on interface",
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "start",
Short: "Quick start service and configure DNS on interface",
Run: func(cmd *cobra.Command, args []string) {
if !cmd.Flags().Changed("iface") {
os.Args = append(os.Args, "--iface="+ifaceStartStop)
@@ -496,9 +566,12 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
rootCmd.AddCommand(startCmdAlias)
stopCmdAlias := &cobra.Command{
PreRun: checkHasElevatedPrivilege,
Use: "stop",
Short: "Quick stop service and remove DNS from interface",
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
checkHasElevatedPrivilege()
},
Use: "stop",
Short: "Quick stop service and remove DNS from interface",
Run: func(cmd *cobra.Command, args []string) {
if !cmd.Flags().Changed("iface") {
os.Args = append(os.Args, "--iface="+ifaceStartStop)
@@ -510,11 +583,6 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`,
stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
rootCmd.AddCommand(stopCmdAlias)
if err := rootCmd.Execute(); err != nil {
stderrMsg(err.Error())
os.Exit(1)
}
}
func writeConfigFile() error {
@@ -547,7 +615,7 @@ func readConfigFile(writeDefaultConfig bool) bool {
// If err == nil, there's a config supplied via `--config`, no default config written.
err := v.ReadInConfig()
if err == nil {
log.Println("loading config file from:", v.ConfigFileUsed())
mainLog.Info().Msg("loading config file from: " + v.ConfigFileUsed())
defaultConfigFile = v.ConfigFileUsed()
return true
}
@@ -558,29 +626,36 @@ func readConfigFile(writeDefaultConfig bool) bool {
// If error is viper.ConfigFileNotFoundError, write default config.
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Fatal().Msgf("failed to unmarshal default config: %v", err)
}
if err := writeConfigFile(); err != nil {
log.Fatalf("failed to write default config file: %v", err)
mainLog.Fatal().Msgf("failed to write default config file: %v", err)
} else {
log.Println("writing default config file to: " + defaultConfigFile)
fp, err := filepath.Abs(defaultConfigFile)
if err != nil {
mainLog.Fatal().Msgf("failed to get default config file path: %v", err)
}
mainLog.Info().Msg("writing default config file to: " + fp)
}
defaultConfigWritten = true
return false
}
// Otherwise, report fatal error and exit.
log.Fatalf("failed to decode config file: %v", err)
mainLog.Fatal().Msgf("failed to decode config file: %v", err)
return false
}
func readBase64Config() {
func readBase64Config(configBase64 string) {
if configBase64 == "" {
return
}
configStr, err := base64.StdEncoding.DecodeString(configBase64)
if err != nil {
log.Fatalf("invalid base64 config: %v", err)
mainLog.Fatal().Msgf("invalid base64 config: %v", err)
}
if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil {
log.Fatalf("failed to read base64 config: %v", err)
mainLog.Fatal().Msgf("failed to read base64 config: %v", err)
}
}
@@ -589,7 +664,7 @@ func processNoConfigFlags(noConfigStart bool) {
return
}
if listenAddress == "" || primaryUpstream == "" {
log.Fatal(`"listen" and "primary_upstream" flags must be set in no config mode`)
mainLog.Fatal().Msg(`"listen" and "primary_upstream" flags must be set in no config mode`)
}
processListenFlag()
@@ -633,7 +708,7 @@ func processCDFlags() {
}
logger := mainLog.With().Str("mode", "cd").Logger()
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
resolverConfig, err := controld.FetchResolverConfig(cdUID)
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version)
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
s, err := service.New(&prog{}, svcConfig)
if err != nil {
@@ -665,34 +740,57 @@ func processCDFlags() {
return
}
logger.Info().Msg("generating ctrld config from Controld-D configuration")
cfg = ctrld.Config{}
cfg.Network = make(map[string]*ctrld.NetworkConfig)
cfg.Network["0"] = &ctrld.NetworkConfig{
Name: "Network 0",
Cidrs: []string{"0.0.0.0/0"},
}
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
Endpoint: resolverConfig.DOH,
Type: ctrld.ResolverTypeDOH,
Timeout: 5000,
}
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
cfg.Listener["0"] = &ctrld.ListenerConfig{
IP: "127.0.0.1",
Port: 53,
Policy: &ctrld.ListenerPolicyConfig{
Name: "My Policy",
Rules: rules,
},
logger.Info().Msg("generating ctrld config from Control-D configuration")
if resolverConfig.Ctrld.CustomConfig != "" {
logger.Info().Msg("using defined custom config of Control-D resolver")
readBase64Config(resolverConfig.Ctrld.CustomConfig)
if err := v.Unmarshal(&cfg); err != nil {
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
}
for _, listener := range cfg.Listener {
if listener.IP == "" {
listener.IP = randomLocalIP()
}
if listener.Port == 0 {
listener.Port = 53
}
}
// On router, we want to keep the listener address point to dnsmasq listener, aka 127.0.0.1:53.
if router.Name() != "" {
if lc := cfg.Listener["0"]; lc != nil {
lc.IP = "127.0.0.1"
lc.Port = 53
}
}
} else {
cfg = ctrld.Config{}
cfg.Network = make(map[string]*ctrld.NetworkConfig)
cfg.Network["0"] = &ctrld.NetworkConfig{
Name: "Network 0",
Cidrs: []string{"0.0.0.0/0"},
}
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
Endpoint: resolverConfig.DOH,
Type: ctrld.ResolverTypeDOH,
Timeout: 5000,
}
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
cfg.Listener["0"] = &ctrld.ListenerConfig{
IP: "127.0.0.1",
Port: 53,
Policy: &ctrld.ListenerPolicyConfig{
Name: "My Policy",
Rules: rules,
},
}
processLogAndCacheFlags()
}
processLogAndCacheFlags()
if err := writeConfigFile(); err != nil {
logger.Fatal().Err(err).Msg("failed to write config file")
} else {
@@ -706,11 +804,11 @@ func processListenFlag() {
}
host, portStr, err := net.SplitHostPort(listenAddress)
if err != nil {
log.Fatalf("invalid listener address: %v", err)
mainLog.Fatal().Msgf("invalid listener address: %v", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Fatalf("invalid port number: %v", err)
mainLog.Fatal().Msgf("invalid port number: %v", err)
}
lc := &ctrld.ListenerConfig{
IP: host,
@@ -777,7 +875,7 @@ func selfCheckStatus(status service.Status) service.Status {
mu.Lock()
defer mu.Unlock()
if err := v.UnmarshalKey("listener", &lcChanged); err != nil {
log.Printf("failed to unmarshal listener config: %v", err)
mainLog.Error().Msgf("failed to unmarshal listener config: %v", err)
return
}
})
@@ -802,3 +900,50 @@ func selfCheckStatus(status service.Status) service.Status {
mainLog.Debug().Msgf("self-check against %q failed", selfCheckFQDN)
return service.StatusUnknown
}
func unsupportedPlatformHelp(cmd *cobra.Command) {
mainLog.Error().Msg("Unsupported or incorrectly chosen router platform. Please open an issue and provide all relevant information: https://github.com/Control-D-Inc/ctrld/issues/new")
}
func userHomeDir() (string, error) {
switch router.Name() {
case router.DDWrt, router.Merlin:
exe, err := os.Executable()
if err != nil {
return "", err
}
return filepath.Dir(exe), nil
}
// viper will expand for us.
if runtime.GOOS == "windows" {
return os.UserHomeDir()
}
dir := "/etc/controld"
if err := os.MkdirAll(dir, 0750); err != nil {
return "", err
}
return dir, nil
}
func tryReadingConfig(writeDefaultConfig bool) {
configs := []struct {
name string
written bool
}{
// For compatibility, we check for config.toml first, but only read it if exists.
{"config", false},
{"ctrld", writeDefaultConfig},
}
dir, err := userHomeDir()
if err != nil {
mainLog.Fatal().Msgf("failed to get config dir: %v", err)
}
for _, config := range configs {
ctrld.SetConfigNameWithPath(v, config.name, dir)
v.SetConfigFile(configPath)
if readConfigFile(config.written) {
break
}
}
}

View File

@@ -0,0 +1,97 @@
package main
import (
"os"
"os/exec"
"strings"
"github.com/spf13/cobra"
"github.com/Control-D-Inc/ctrld/internal/router"
)
func initRouterCLI() {
validArgs := append(router.SupportedPlatforms(), "auto")
var b strings.Builder
b.WriteString("Auto-setup Control D on a router.\n\nSupported platforms:\n\n")
for _, arg := range validArgs {
b.WriteString(" ₒ ")
b.WriteString(arg)
if arg == "auto" {
b.WriteString(" - detect the platform you are running on")
}
b.WriteString("\n")
}
routerCmd := &cobra.Command{
Use: "setup",
Short: b.String(),
PreRun: func(cmd *cobra.Command, args []string) {
initConsoleLogging()
},
Run: func(cmd *cobra.Command, args []string) {
if len(args) == 0 {
_ = cmd.Help()
return
}
if len(args) != 1 {
_ = cmd.Help()
return
}
platform := args[0]
if platform == "auto" {
platform = router.Name()
}
switch platform {
case router.DDWrt, router.Merlin, router.OpenWrt, router.Ubios:
default:
unsupportedPlatformHelp(cmd)
os.Exit(1)
}
exe, err := os.Executable()
if err != nil {
mainLog.Fatal().Msgf("could not find executable path: %v", err)
os.Exit(1)
}
cmdArgs := []string{"start"}
cmdArgs = append(cmdArgs, osArgs(platform)...)
cmdArgs = append(cmdArgs, "--router")
command := exec.Command(exe, cmdArgs...)
command.Stdout = os.Stdout
command.Stderr = os.Stderr
command.Stdin = os.Stdin
if err := command.Run(); err != nil {
mainLog.Fatal().Msg(err.Error())
}
},
}
// Keep these flags in sync with startCmd, except for "--router".
routerCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
routerCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
routerCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
routerCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
routerCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
routerCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
routerCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
routerCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
routerCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
routerCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
tmpl := routerCmd.UsageTemplate()
tmpl = strings.Replace(tmpl, "{{.UseLine}}", "{{.UseLine}} [platform]", 1)
routerCmd.SetUsageTemplate(tmpl)
rootCmd.AddCommand(routerCmd)
}
func osArgs(platform string) []string {
args := os.Args[2:]
n := 0
for _, x := range args {
if x != platform && x != "auto" {
args[n] = x
n++
}
}
return args[:n]
}

View File

@@ -0,0 +1,5 @@
//go:build !linux
package main
func initRouterCLI() {}

View File

@@ -9,6 +9,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
@@ -17,9 +18,22 @@ import (
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
"github.com/Control-D-Inc/ctrld/internal/router"
)
const staleTTL = 60 * time.Second
const (
staleTTL = 60 * time.Second
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
EDNS0_OPTION_MAC = 0xFDE9
)
var osUpstreamConfig = &ctrld.UpstreamConfig{
Name: "OS resolver",
Type: ctrld.ResolverTypeOS,
Timeout: 2000,
}
func (p *prog) serveDNS(listenerNum string) error {
listenerConfig := p.cfg.Listener[listenerNum]
@@ -59,45 +73,49 @@ func (p *prog) serveDNS(listenerNum string) error {
g, ctx := errgroup.WithContext(context.Background())
for _, proto := range []string{"udp", "tcp"} {
proto := proto
// On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can
// listen on ::1, then spawn a listener for receiving DNS requests.
if runtime.GOOS == "windows" && ctrldnet.SupportsIPv6ListenLocal() {
if needLocalIPv6Listener() {
g.Go(func() error {
s := &dns.Server{
Addr: net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)),
Net: proto,
Handler: handler,
}
go func() {
<-ctx.Done()
_ = s.Shutdown()
}()
if err := s.ListenAndServe(); err != nil {
mainLog.Error().Err(err).Msg("could not serving on ::1")
s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler)
defer s.Shutdown()
select {
case <-ctx.Done():
case err := <-errCh:
// Local ipv6 listener should not terminate ctrld.
// It's a workaround for a quirk on Windows.
mainLog.Warn().Err(err).Msg("local ipv6 listener failed")
}
return nil
})
}
g.Go(func() error {
s := &dns.Server{
Addr: net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)),
Net: proto,
Handler: handler,
s, errCh := runDNSServer(dnsListenAddress(listenerNum, listenerConfig), proto, handler)
defer s.Shutdown()
if listenerConfig.Port == 0 {
switch s.Net {
case "udp":
mainLog.Info().Msgf("Random port chosen for udp listener.%s: %s", listenerNum, s.PacketConn.LocalAddr())
case "tcp":
mainLog.Info().Msgf("Random port chosen for tcp listener.%s: %s", listenerNum, s.Listener.Addr())
}
}
go func() {
<-ctx.Done()
_ = s.Shutdown()
}()
if err := s.ListenAndServe(); err != nil {
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
select {
case <-ctx.Done():
return nil
case err := <-errCh:
return err
}
return nil
})
}
return g.Wait()
}
// upstreamFor returns the list of upstreams for resolving the given domain,
// matching by policies defined in the listener config. The second return value
// reports whether the domain matches the policy.
//
// Though domain policy has higher priority than network policy, it is still
// processed later, because policy logging want to know whether a network rule
// is disregarded in favor of the domain level rule.
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
upstreams := []string{"upstream." + defaultUpstreamNum}
matchedPolicy := "no policy"
@@ -121,11 +139,43 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
upstreams = append([]string(nil), policyUpstreams...)
}
var networkTargets []string
var sourceIP net.IP
switch addr := addr.(type) {
case *net.UDPAddr:
sourceIP = addr.IP
case *net.TCPAddr:
sourceIP = addr.IP
}
networkRules:
for _, rule := range lc.Policy.Networks {
for source, targets := range rule {
networkNum := strings.TrimPrefix(source, "network.")
nc := p.cfg.Network[networkNum]
if nc == nil {
continue
}
for _, ipNet := range nc.IPNets {
if ipNet.Contains(sourceIP) {
matchedPolicy = lc.Policy.Name
matchedNetwork = source
networkTargets = targets
matched = true
break networkRules
}
}
}
}
for _, rule := range lc.Policy.Rules {
// There's only one entry per rule, config validation ensures this.
for source, targets := range rule {
if source == domain || wildcardMatches(source, domain) {
matchedPolicy = lc.Policy.Name
if len(networkTargets) > 0 {
matchedNetwork += " (unenforced)"
}
matchedRule = source
do(targets)
matched = true
@@ -134,31 +184,8 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
}
}
var sourceIP net.IP
switch addr := addr.(type) {
case *net.UDPAddr:
sourceIP = addr.IP
case *net.TCPAddr:
sourceIP = addr.IP
}
for _, rule := range lc.Policy.Networks {
for source, targets := range rule {
networkNum := strings.TrimPrefix(source, "network.")
nc := p.cfg.Network[networkNum]
if nc == nil {
continue
}
for _, ipNet := range nc.IPNets {
if ipNet.Contains(sourceIP) {
matchedPolicy = lc.Policy.Name
matchedNetwork = source
do(targets)
matched = true
return upstreams, matched
}
}
}
if matched {
do(networkTargets)
}
return upstreams, matched
@@ -207,8 +234,16 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
return dnsResolver.Resolve(resolveCtx, msg)
}
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
if upstreamConfig.UpstreamSendClientInfo() {
ci := router.GetClientInfoByMac(macFromMsg(msg))
if ci != nil {
ctrld.Log(ctx, mainLog.Debug(), "including client info with the request")
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci)
}
}
answer, err := resolve1(n, upstreamConfig, msg)
if err != nil {
// Only do re-bootstrapping if bootstrap ip is not explicitly set by user.
if err != nil && upstreamConfig.BootstrapIP == "" {
ctrld.Log(ctx, mainLog.Debug().Err(err), "could not resolve query on first attempt, retrying...")
// If any error occurred, re-bootstrap transport/ip, retry the request.
upstreamConfig.ReBootstrap()
@@ -222,6 +257,9 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
return answer
}
for n, upstreamConfig := range upstreamConfigs {
if upstreamConfig == nil {
continue
}
answer := resolve(n, upstreamConfig, msg)
if answer == nil {
if serveStaleCache && staleAnswer != nil {
@@ -236,6 +274,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
ctrld.Log(ctx, mainLog.Debug(), "failover rcode matched, process to next upstream")
continue
}
// set compression, as it is not set by default when unpacking
answer.Compress = true
if p.cache != nil {
ttl := ttlFromMsg(answer)
now := time.Now()
@@ -349,8 +391,58 @@ func ttlFromMsg(msg *dns.Msg) uint32 {
return 0
}
var osUpstreamConfig = &ctrld.UpstreamConfig{
Name: "OS resolver",
Type: ctrld.ResolverTypeOS,
Timeout: 2000,
func needLocalIPv6Listener() bool {
// On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can
// listen on ::1, then spawn a listener for receiving DNS requests.
return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
}
func dnsListenAddress(lcNum string, lc *ctrld.ListenerConfig) string {
if addr := router.ListenAddress(); setupRouter && addr != "" && lcNum == "0" {
return addr
}
return net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
}
func macFromMsg(msg *dns.Msg) string {
if opt := msg.IsEdns0(); opt != nil {
for _, s := range opt.Option {
switch e := s.(type) {
case *dns.EDNS0_LOCAL:
if e.Code == EDNS0_OPTION_MAC {
return net.HardwareAddr(e.Data).String()
}
}
}
}
return ""
}
// runDNSServer starts a DNS server for given address and network,
// with the given handler. It ensures the server has started listening.
// Any error will be reported to the caller via returned channel.
//
// It's the caller responsibility to call Shutdown to close the server.
func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-chan error) {
s := &dns.Server{
Addr: addr,
Net: network,
Handler: handler,
}
waitLock := sync.Mutex{}
waitLock.Lock()
s.NotifyStartedFunc = waitLock.Unlock
errCh := make(chan error)
go func() {
defer close(errCh)
if err := s.ListenAndServe(); err != nil {
waitLock.Unlock()
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
errCh <- err
}
}()
waitLock.Lock()
return s, errCh
}

View File

@@ -86,17 +86,17 @@ func Test_prog_upstreamFor(t *testing.T) {
domain string
upstreams []string
matched bool
testLogMsg string
}{
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true},
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true},
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true},
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false},
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
for _, network := range []string{"udp", "tcp"} {
var (
addr net.Addr
@@ -114,6 +114,9 @@ func Test_prog_upstreamFor(t *testing.T) {
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
assert.Equal(t, tc.matched, matched)
assert.Equal(t, tc.upstreams, upstreams)
if tc.testLogMsg != "" {
assert.Contains(t, logOutput.String(), tc.testLogMsg)
}
}
})
}
@@ -152,3 +155,39 @@ func TestCache(t *testing.T) {
assert.Equal(t, answer1.Rcode, got1.Rcode)
assert.Equal(t, answer2.Rcode, got2.Rcode)
}
func Test_macFromMsg(t *testing.T) {
tests := []struct {
name string
mac string
wantMac bool
}{
{"has mac", "4c:20:b8:ab:87:1b", true},
{"no mac", "4c:20:b8:ab:87:1b", false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
hw, err := net.ParseMAC(tc.mac)
if err != nil {
t.Fatal(err)
}
m := new(dns.Msg)
m.SetQuestion(selfCheckFQDN+".", dns.TypeA)
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
if tc.wantMac {
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
o.Option = append(o.Option, ec1)
}
m.Extra = append(m.Extra, o)
got := macFromMsg(m)
if tc.wantMac && got != tc.mac {
t.Errorf("mismatch, want: %q, got: %q", tc.mac, got)
}
if !tc.wantMac && got != "" {
t.Errorf("unexpected mac: %q", got)
}
})
}
}

View File

@@ -2,7 +2,6 @@ package main
import (
"io"
"log"
"os"
"path/filepath"
"time"
@@ -26,18 +25,24 @@ var (
cacheSize int
cfg ctrld.Config
verbose int
silent bool
cdUID string
iface string
ifaceStartStop string
setupRouter bool
rootLogger = zerolog.New(io.Discard)
mainLog = rootLogger
cdUID string
iface string
ifaceStartStop string
mainLog = zerolog.New(io.Discard)
consoleWriter zerolog.ConsoleWriter
)
func main() {
ctrld.InitConfig(v, "ctrld")
initCLI()
initRouterCLI()
if err := rootCmd.Execute(); err != nil {
mainLog.Error().Msg(err.Error())
os.Exit(1)
}
}
func normalizeLogFilePath(logFilePath string) string {
@@ -47,44 +52,65 @@ func normalizeLogFilePath(logFilePath string) string {
if homedir != "" {
return filepath.Join(homedir, logFilePath)
}
dir, _ := os.UserHomeDir()
dir, _ := userHomeDir()
if dir == "" {
return logFilePath
}
return filepath.Join(dir, logFilePath)
}
func initConsoleLogging() {
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
w.TimeFormat = time.StampMilli
})
multi := zerolog.MultiLevelWriter(consoleWriter)
mainLog = mainLog.Output(multi).With().Timestamp().Logger()
switch {
case silent:
zerolog.SetGlobalLevel(zerolog.NoLevel)
case verbose == 1:
zerolog.SetGlobalLevel(zerolog.InfoLevel)
case verbose > 1:
zerolog.SetGlobalLevel(zerolog.DebugLevel)
default:
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
}
}
func initLogging() {
writers := []io.Writer{io.Discard}
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
// Create parent directory if necessary.
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
log.Printf("failed to create log path: %v", err)
mainLog.Error().Msgf("failed to create log path: %v", err)
os.Exit(1)
}
// Backup old log file with .1 suffix.
if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) {
log.Printf("could not backup old log file: %v", err)
mainLog.Error().Msgf("could not backup old log file: %v", err)
}
logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_RDWR, os.FileMode(0o600))
if err != nil {
log.Printf("failed to create log file: %v", err)
mainLog.Error().Msgf("failed to create log file: %v", err)
os.Exit(1)
}
writers = append(writers, logFile)
}
consoleWriter := zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
w.TimeFormat = time.StampMilli
})
writers = append(writers, consoleWriter)
multi := zerolog.MultiLevelWriter(writers...)
mainLog = mainLog.Output(multi).With().Timestamp().Logger()
// TODO: find a better way.
ctrld.ProxyLog = mainLog
zerolog.SetGlobalLevel(zerolog.InfoLevel)
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
logLevel := cfg.Service.LogLevel
if verbose > 1 {
switch {
case silent:
zerolog.SetGlobalLevel(zerolog.NoLevel)
return
case verbose == 1:
logLevel = "info"
case verbose > 1:
logLevel = "debug"
}
if logLevel == "" {

16
cmd/ctrld/main_test.go Normal file
View File

@@ -0,0 +1,16 @@
package main
import (
"os"
"strings"
"testing"
"github.com/rs/zerolog"
)
var logOutput strings.Builder
func TestMain(m *testing.M) {
mainLog = zerolog.New(&logOutput)
os.Exit(m.Run())
}

View File

@@ -0,0 +1,27 @@
package main
import (
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
func (p *prog) watchLinkState() {
ch := make(chan netlink.LinkUpdate)
done := make(chan struct{})
defer close(done)
if err := netlink.LinkSubscribe(ch, done); err != nil {
mainLog.Warn().Err(err).Msg("could not subscribe link")
return
}
for lu := range ch {
if lu.Change == 0xFFFFFFFF {
continue
}
if lu.Change&unix.IFF_UP != 0 {
mainLog.Debug().Msgf("link state changed, re-bootstrapping")
for _, uc := range p.cfg.Upstream {
uc.ReBootstrap()
}
}
}
}

View File

@@ -0,0 +1,5 @@
//go:build !linux
package main
func (p *prog) watchLinkState() {}

View File

@@ -112,7 +112,7 @@ func resetDNS(iface *net.Interface) (err error) {
}
// TODO(cuonglm): handle DHCPv6 properly.
if ctrldnet.SupportsIPv6() {
if ctrldnet.IPv6Available(ctx) {
c := client6.NewClient()
conversation, err := c.Exchange(iface.Name)
if err != nil {

View File

@@ -14,6 +14,7 @@ import (
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/dnscache"
"github.com/Control-D-Inc/ctrld/internal/router"
)
var logf = func(format string, args ...any) {
@@ -25,6 +26,7 @@ var errWindowsAddrInUse = syscall.Errno(0x2740)
var svcConfig = &service.Config{
Name: "ctrld",
DisplayName: "Control-D Helper Service",
Option: service.KeyValue{},
}
type prog struct {
@@ -72,13 +74,16 @@ func (p *prog) run() {
uc.Init()
if uc.BootstrapIP == "" {
uc.SetupBootstrapIP()
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n)
mainLog.Info().Msgf("Bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
} else {
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap IP for upstream.%s", n)
}
uc.SetCertPool(rootCertPool)
uc.SetupTransport()
}
go p.watchLinkState()
for listenerNum := range p.cfg.Listener {
p.cfg.Listener[listenerNum].Init()
go func(listenerNum string) {
@@ -86,8 +91,7 @@ func (p *prog) run() {
listenerConfig := p.cfg.Listener[listenerNum]
upstreamConfig := p.cfg.Upstream[listenerNum]
if upstreamConfig == nil {
mainLog.Error().Msgf("missing upstream config for: [listener.%s]", listenerNum)
return
mainLog.Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
mainLog.Info().Msgf("Starting DNS server on listener.%s: %s", listenerNum, addr)
@@ -135,6 +139,10 @@ func (p *prog) Stop(s service.Service) error {
mainLog.Error().Err(err).Msg("de-allocate ip failed")
return err
}
p.preStop()
if err := router.Stop(); err != nil {
mainLog.Warn().Err(err).Msg("problem occurred while stopping router")
}
mainLog.Info().Msg("Service stopped")
close(p.stopCh)
return nil
@@ -164,6 +172,12 @@ func (p *prog) deAllocateIP() error {
}
func (p *prog) setDNS() {
switch router.Name() {
case router.DDWrt, router.OpenWrt, router.Ubios:
// On router, ctrld run as a DNS forwarder, it does not have to change system DNS.
// Except for Merlin, which has WAN DNS setup on boot for NTP.
return
}
if cfg.Listener == nil || cfg.Listener["0"] == nil {
return
}
@@ -192,6 +206,11 @@ func (p *prog) setDNS() {
}
func (p *prog) resetDNS() {
switch router.Name() {
case router.DDWrt, router.OpenWrt, router.Ubios:
// See comment in p.setDNS method.
return
}
if iface == "" {
return
}

23
cmd/ctrld/prog_darwin.go Normal file
View File

@@ -0,0 +1,23 @@
package main
import (
"github.com/kardianos/service"
)
func (p *prog) preRun() {
if !service.Interactive() {
p.setDNS()
}
}
func setDependencies(svc *service.Config) {}
func setWorkingDirectory(svc *service.Config, dir string) {
svc.WorkingDirectory = dir
}
func (p *prog) preStop() {
if !service.Interactive() {
p.resetDNS()
}
}

View File

@@ -18,3 +18,5 @@ func setDependencies(svc *service.Config) {
}
func setWorkingDirectory(svc *service.Config, dir string) {}
func (p *prog) preStop() {}

View File

@@ -22,3 +22,5 @@ func setDependencies(svc *service.Config) {
func setWorkingDirectory(svc *service.Config, dir string) {
svc.WorkingDirectory = dir
}
func (p *prog) preStop() {}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !freebsd
//go:build !linux && !freebsd && !darwin
package main
@@ -12,3 +12,5 @@ func setWorkingDirectory(svc *service.Config, dir string) {
// WorkingDirectory is not supported on Windows.
svc.WorkingDirectory = dir
}
func (p *prog) preStop() {}

View File

@@ -1,18 +1,44 @@
package main
import (
"fmt"
"bytes"
"errors"
"os"
"os/exec"
"github.com/spf13/cobra"
"github.com/kardianos/service"
"github.com/Control-D-Inc/ctrld/internal/router"
)
func stderrMsg(msg string) {
_, _ = fmt.Fprintln(os.Stderr, msg)
func newService(s service.Service) service.Service {
// TODO: unify for other SysV system.
if router.IsGLiNet() {
return &sysV{s}
}
return s
}
func stdoutMsg(msg string) {
_, _ = fmt.Fprintln(os.Stdout, msg)
// sysV wraps a service.Service, and provide start/stop/status command
// base on "/etc/init.d/<service_name>".
//
// Use this on system wherer "service" command is not available, like GL.iNET router.
type sysV struct {
service.Service
}
func (s *sysV) Start() error {
_, err := exec.Command("/etc/init.d/ctrld", "start").CombinedOutput()
return err
}
func (s *sysV) Stop() error {
_, err := exec.Command("/etc/init.d/ctrld", "stop").CombinedOutput()
return err
}
func (s *sysV) Status() (service.Status, error) {
return unixSystemVServiceStatus()
}
type task struct {
@@ -21,25 +47,48 @@ type task struct {
}
func doTasks(tasks []task) bool {
var prevErr error
for _, task := range tasks {
if err := task.f(); err != nil {
if task.abortOnError {
stderrMsg(err.Error())
mainLog.Error().Msg(errors.Join(prevErr, err).Error())
return false
}
prevErr = err
}
}
return true
}
func checkHasElevatedPrivilege(cmd *cobra.Command, args []string) {
func checkHasElevatedPrivilege() {
ok, err := hasElevatedPrivilege()
if err != nil {
fmt.Printf("could not detect user privilege: %v", err)
mainLog.Error().Msgf("could not detect user privilege: %v", err)
return
}
if !ok {
fmt.Println("Please relaunch process with admin/root privilege.")
mainLog.Error().Msg("Please relaunch process with admin/root privilege.")
os.Exit(1)
}
}
func serviceStatus(s service.Service) (service.Status, error) {
status, err := s.Status()
if err != nil && service.Platform() == "unix-systemv" {
return unixSystemVServiceStatus()
}
return status, err
}
func unixSystemVServiceStatus() (service.Status, error) {
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
if err != nil {
return service.StatusUnknown, nil
}
switch string(bytes.TrimSpace(out)) {
case "running":
return service.StatusRunning, nil
default:
return service.StatusStopped, nil
}
}

408
config.go
View File

@@ -2,41 +2,58 @@ package ctrld
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync/atomic"
"sync"
"time"
"github.com/go-playground/validator/v10"
"github.com/miekg/dns"
"github.com/spf13/viper"
"golang.org/x/sync/singleflight"
"tailscale.com/logtail/backoff"
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
// SetConfigName set the config name that ctrld will look for.
func SetConfigName(v *viper.Viper, name string) {
v.SetConfigName(name)
const (
IpStackBoth = "both"
IpStackV4 = "v4"
IpStackV6 = "v6"
IpStackSplit = "split"
)
var controldParentDomains = []string{"controld.com", "controld.net", "controld.dev"}
// SetConfigName set the config name that ctrld will look for.
// DEPRECATED: use SetConfigNameWithPath instead.
func SetConfigName(v *viper.Viper, name string) {
configPath := "$HOME"
// viper has its own way to get user home directory: https://github.com/spf13/viper/blob/v1.14.0/util.go#L134
// To be consistent, we prefer os.UserHomeDir instead.
if homeDir, err := os.UserHomeDir(); err == nil {
configPath = homeDir
}
SetConfigNameWithPath(v, name, configPath)
}
// SetConfigNameWithPath set the config path and name that ctrld will look for.
func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
v.SetConfigName(name)
v.AddConfigPath(configPath)
v.AddConfigPath(".")
}
// InitConfig initializes default config values for given *viper.Viper instance.
func InitConfig(v *viper.Viper, name string) {
SetConfigName(v, name)
v.SetDefault("listener", map[string]*ListenerConfig{
"0": {
IP: "127.0.0.1",
@@ -75,6 +92,17 @@ type Config struct {
Upstream map[string]*UpstreamConfig `mapstructure:"upstream" toml:"upstream" validate:"min=1,dive"`
}
// HasUpstreamSendClientInfo reports whether the config has any upstream
// is configured to send client info to Control D DNS server.
func (c *Config) HasUpstreamSendClientInfo() bool {
for _, uc := range c.Upstream {
if uc.UpstreamSendClientInfo() {
return true
}
}
return false
}
// ServiceConfig specifies the general ctrld config.
type ServiceConfig struct {
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
@@ -96,24 +124,36 @@ type NetworkConfig struct {
// UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to.
type UpstreamConfig struct {
Name string `mapstructure:"name" toml:"name,omitempty"`
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
Domain string `mapstructure:"-" toml:"-"`
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
transport *http.Transport `mapstructure:"-" toml:"-"`
http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"`
Name string `mapstructure:"name" toml:"name,omitempty"`
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
Domain string `mapstructure:"-" toml:"-"`
IPStack string `mapstructure:"ip_stack" toml:"ip_stack,omitempty" validate:"ipstack"`
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
// The caller should not access this field directly.
// Use UpstreamSendClientInfo instead.
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
g singleflight.Group
bootstrapIPs []string
nextBootstrapIP atomic.Uint32
g singleflight.Group
mu sync.Mutex
bootstrapIPs []string
bootstrapIPs4 []string
bootstrapIPs6 []string
transport *http.Transport
transport4 *http.Transport
transport6 *http.Transport
http3RoundTripper http.RoundTripper
http3RoundTripper4 http.RoundTripper
http3RoundTripper6 http.RoundTripper
certPool *x509.CertPool
u *url.URL
}
// ListenerConfig specifies the networks configuration that ctrld will run on.
type ListenerConfig struct {
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"ip"`
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gt=0"`
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
}
@@ -136,20 +176,61 @@ type Rule map[string][]string
func (uc *UpstreamConfig) Init() {
if u, err := url.Parse(uc.Endpoint); err == nil {
uc.Domain = u.Host
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
uc.u = u
}
}
if uc.Domain != "" {
return
if uc.Domain == "" {
if !strings.Contains(uc.Endpoint, ":") {
uc.Domain = uc.Endpoint
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
}
host, _, _ := net.SplitHostPort(uc.Endpoint)
uc.Domain = host
if net.ParseIP(uc.Domain) != nil {
uc.BootstrapIP = uc.Domain
}
}
if uc.IPStack == "" {
if uc.isControlD() {
uc.IPStack = IpStackSplit
} else {
uc.IPStack = IpStackBoth
}
}
}
if !strings.Contains(uc.Endpoint, ":") {
uc.Domain = uc.Endpoint
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
// UpstreamSendClientInfo reports whether the upstream is
// configured to send client info to Control D DNS server.
//
// Client info includes:
// - MAC
// - Lan IP
// - Hostname
func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
if uc.SendClientInfo != nil && !(*uc.SendClientInfo) {
return false
}
host, _, _ := net.SplitHostPort(uc.Endpoint)
uc.Domain = host
if net.ParseIP(uc.Domain) != nil {
uc.BootstrapIP = uc.Domain
if uc.SendClientInfo == nil {
return true
}
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
if uc.isControlD() {
return true
}
}
return false
}
func (uc *UpstreamConfig) BootstrapIPs() []string {
return uc.bootstrapIPs
}
// SetCertPool sets the system cert pool used for TLS connections.
func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) {
uc.certPool = cp
}
// SetupBootstrapIP manually find all available IPs of the upstream.
@@ -161,70 +242,21 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
// SetupBootstrapIP manually find all available IPs of the upstream.
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
bootstrapIP := func(record dns.RR) string {
switch ar := record.(type) {
case *dns.A:
return ar.A.String()
case *dns.AAAA:
return ar.AAAA.String()
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 2*time.Second)
for {
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
if len(uc.bootstrapIPs) > 0 {
break
}
return ""
ProxyLog.Warn().Msg("could not resolve bootstrap IPs, retrying...")
b.BackOff(context.Background(), errors.New("no bootstrap IPs"))
}
resolver := &osResolver{nameservers: availableNameservers()}
if withBootstrapDNS {
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
}
ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", uc.Domain, resolver.nameservers)
timeoutMs := 2000
if uc.Timeout > 0 && uc.Timeout < timeoutMs {
timeoutMs = uc.Timeout
}
do := func(dnsType uint16) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()
m := new(dns.Msg)
m.SetQuestion(uc.Domain+".", dnsType)
m.RecursionDesired = true
r, err := resolver.Resolve(ctx, m)
if err != nil {
ProxyLog.Error().Err(err).Str("type", dns.TypeToString[dnsType]).Msgf("could not resolve domain %s for upstream", uc.Domain)
return
for _, ip := range uc.bootstrapIPs {
if ctrldnet.IsIPv6(ip) {
uc.bootstrapIPs6 = append(uc.bootstrapIPs6, ip)
} else {
uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip)
}
if r.Rcode != dns.RcodeSuccess {
ProxyLog.Error().Msgf("could not resolve domain %q, return code: %s", uc.Domain, dns.RcodeToString[r.Rcode])
return
}
if len(r.Answer) == 0 {
ProxyLog.Error().Msg("no answer from bootstrap DNS server")
return
}
for _, a := range r.Answer {
ip := bootstrapIP(a)
if ip == "" {
continue
}
// Storing the ip to uc.bootstrapIPs list, so it can be selected later
// when retrying failed request due to network stack changed.
uc.bootstrapIPs = append(uc.bootstrapIPs, ip)
if uc.BootstrapIP == "" {
// Remember what's the current IP in bootstrap IPs list,
// so we can select next one upon re-bootstrapping.
uc.nextBootstrapIP.Add(1)
// If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip.
if !ctrldnet.SupportsIPv6() && ctrldnet.IsIPv6(ip) {
continue
}
uc.BootstrapIP = ip
}
}
}
// Find all A, AAAA records of the upstream.
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
do(dnsType)
}
ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs)
}
@@ -238,32 +270,6 @@ func (uc *UpstreamConfig) ReBootstrap() {
}
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
ProxyLog.Debug().Msg("re-bootstrapping upstream ip")
n := uint32(len(uc.bootstrapIPs))
if n == 0 {
uc.SetupBootstrapIP()
uc.setupTransportWithoutPingUpstream()
}
timeoutMs := 1000
if uc.Timeout > 0 && uc.Timeout < timeoutMs {
timeoutMs = uc.Timeout
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()
hasIPv6 := ctrldnet.IPv6Available(ctx)
// Only attempt n times, because if there's no usable ip,
// the bootstrap ip will be kept as-is.
for i := uint32(0); i < n; i++ {
// Select the next ip in bootstrap ip list.
next := uc.nextBootstrapIP.Add(1)
ip := uc.bootstrapIPs[(next-1)%n]
if !hasIPv6 && ctrldnet.IsIPv6(ip) {
continue
}
uc.BootstrapIP = ip
break
}
uc.setupTransportWithoutPingUpstream()
return true, nil
})
@@ -291,31 +297,65 @@ func (uc *UpstreamConfig) SetupTransport() {
func (uc *UpstreamConfig) setupDOHTransport() {
uc.setupDOHTransportWithoutPingUpstream()
uc.pingUpstream()
go uc.pingUpstream()
}
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
uc.transport = http.DefaultTransport.(*http.Transport).Clone()
uc.transport.IdleConnTimeout = 5 * time.Second
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.IdleConnTimeout = 5 * time.Second
transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
dialerTimeoutMs := 2000
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
dialerTimeoutMs = uc.Timeout
}
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: dialerTimeout,
KeepAlive: dialerTimeout,
}
// if we have a bootstrap ip set, use it to avoid DNS lookup
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, _ := net.SplitHostPort(addr)
if uc.BootstrapIP != "" {
if _, port, _ := net.SplitHostPort(addr); port != "" {
addr = net.JoinHostPort(uc.BootstrapIP, port)
}
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
addr := net.JoinHostPort(uc.BootstrapIP, port)
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr)
return dialer.DialContext(ctx, network, addr)
}
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr)
return dialer.DialContext(ctx, network, addr)
pd := &ctrldnet.ParallelDialer{}
pd.Timeout = dialerTimeout
pd.KeepAlive = dialerTimeout
dialAddrs := make([]string, len(addrs))
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
conn, err := pd.DialContext(ctx, network, dialAddrs)
if err != nil {
return nil, err
}
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", conn.RemoteAddr())
return conn, nil
}
return transport
}
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
switch uc.IPStack {
case IpStackBoth, "":
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
case IpStackV4:
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
case IpStackV6:
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
case IpStackSplit:
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if ctrldnet.IPv6Available(ctx) {
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
} else {
uc.transport6 = uc.transport4
}
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
}
}
@@ -334,6 +374,76 @@ func (uc *UpstreamConfig) pingUpstream() {
_, _ = dnsResolver.Resolve(ctx, msg)
}
func (uc *UpstreamConfig) isControlD() bool {
domain := uc.Domain
if domain == "" {
if u, err := url.Parse(uc.Endpoint); err == nil {
domain = u.Hostname()
}
}
for _, parent := range controldParentDomains {
if dns.IsSubDomain(parent, domain) {
return true
}
}
return false
}
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
uc.mu.Lock()
defer uc.mu.Unlock()
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.transport
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.transport4
default:
return uc.transport6
}
}
return uc.transport
}
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
switch uc.IPStack {
case IpStackBoth:
return pick(uc.bootstrapIPs)
case IpStackV4:
return pick(uc.bootstrapIPs4)
case IpStackV6:
return pick(uc.bootstrapIPs6)
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return pick(uc.bootstrapIPs4)
default:
return pick(uc.bootstrapIPs6)
}
}
return pick(uc.bootstrapIPs)
}
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
switch uc.IPStack {
case IpStackBoth:
return "tcp-tls", "udp"
case IpStackV4:
return "tcp4-tls", "udp4"
case IpStackV6:
return "tcp6-tls", "udp6"
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return "tcp4-tls", "udp4"
default:
return "tcp6-tls", "udp6"
}
}
return "tcp-tls", "udp"
}
// Init initialized necessary values for an ListenerConfig.
func (lc *ListenerConfig) Init() {
if lc.Policy != nil {
@@ -347,6 +457,8 @@ func (lc *ListenerConfig) Init() {
// ValidateConfig validates the given config.
func ValidateConfig(validate *validator.Validate, cfg *Config) error {
_ = validate.RegisterValidation("dnsrcode", validateDnsRcode)
_ = validate.RegisterValidation("ipstack", validateIpStack)
_ = validate.RegisterValidation("iporempty", validateIpOrEmpty)
return validate.Struct(cfg)
}
@@ -354,6 +466,23 @@ func validateDnsRcode(fl validator.FieldLevel) bool {
return dnsrcode.FromString(fl.Field().String()) != -1
}
func validateIpStack(fl validator.FieldLevel) bool {
switch fl.Field().String() {
case IpStackBoth, IpStackV4, IpStackV6, IpStackSplit, "":
return true
default:
return false
}
}
func validateIpOrEmpty(fl validator.FieldLevel) bool {
val := fl.Field().String()
if val == "" {
return true
}
return net.ParseIP(val) != nil
}
func defaultPortFor(typ string) string {
switch typ {
case ResolverTypeDOH, ResolverTypeDOH3:
@@ -366,21 +495,6 @@ func defaultPortFor(typ string) string {
return "53"
}
func availableNameservers() []string {
nss := nameservers()
n := 0
for _, ns := range nss {
ip, _, _ := net.SplitHostPort(ns)
// skipping invalid entry or ipv6 nameserver if ipv6 not available.
if ip == "" || (ctrldnet.IsIPv6(ip) && !ctrldnet.SupportsIPv6()) {
continue
}
nss[n] = ns
n++
}
return nss[:n]
}
// ResolverTypeFromEndpoint tries guessing the resolver type with a given endpoint
// using following rules:
//
@@ -404,3 +518,7 @@ func ResolverTypeFromEndpoint(endpoint string) string {
}
return ResolverTypeDOT
}
func pick(s []string) string {
return s[rand.Intn(len(s))]
}

View File

@@ -1,7 +1,10 @@
package ctrld
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
@@ -13,9 +16,180 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
}
uc.Init()
uc.setupBootstrapIP(false)
if uc.BootstrapIP == "" {
t.Log(availableNameservers())
if len(uc.bootstrapIPs) == 0 {
t.Log(nameservers())
t.Fatal("could not bootstrap ip without bootstrap DNS")
}
t.Log(uc)
}
func TestUpstreamConfig_Init(t *testing.T) {
u1, _ := url.Parse("https://example.com")
u2, _ := url.Parse("https://example.com?k=v")
tests := []struct {
name string
uc *UpstreamConfig
expected *UpstreamConfig
}{
{
"doh+doh3",
&UpstreamConfig{
Name: "doh",
Type: "doh",
Endpoint: "https://example.com",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&UpstreamConfig{
Name: "doh",
Type: "doh",
Endpoint: "https://example.com",
BootstrapIP: "",
Domain: "example.com",
Timeout: 0,
IPStack: IpStackBoth,
u: u1,
},
},
{
"doh+doh3 with query param",
&UpstreamConfig{
Name: "doh",
Type: "doh",
Endpoint: "https://example.com?k=v",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&UpstreamConfig{
Name: "doh",
Type: "doh",
Endpoint: "https://example.com?k=v",
BootstrapIP: "",
Domain: "example.com",
Timeout: 0,
IPStack: IpStackBoth,
u: u2,
},
},
{
"dot+doq",
&UpstreamConfig{
Name: "dot",
Type: "dot",
Endpoint: "freedns.controld.com:8853",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&UpstreamConfig{
Name: "dot",
Type: "dot",
Endpoint: "freedns.controld.com:8853",
BootstrapIP: "",
Domain: "freedns.controld.com",
Timeout: 0,
IPStack: IpStackSplit,
},
},
{
"dot+doq without port",
&UpstreamConfig{
Name: "dot",
Type: "dot",
Endpoint: "freedns.controld.com",
BootstrapIP: "",
Domain: "",
Timeout: 0,
IPStack: IpStackSplit,
},
&UpstreamConfig{
Name: "dot",
Type: "dot",
Endpoint: "freedns.controld.com:853",
BootstrapIP: "",
Domain: "freedns.controld.com",
Timeout: 0,
IPStack: IpStackSplit,
},
},
{
"legacy",
&UpstreamConfig{
Name: "legacy",
Type: "legacy",
Endpoint: "1.2.3.4:53",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&UpstreamConfig{
Name: "legacy",
Type: "legacy",
Endpoint: "1.2.3.4:53",
BootstrapIP: "1.2.3.4",
Domain: "1.2.3.4",
Timeout: 0,
IPStack: IpStackBoth,
},
},
{
"legacy without port",
&UpstreamConfig{
Name: "legacy",
Type: "legacy",
Endpoint: "1.2.3.4",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&UpstreamConfig{
Name: "legacy",
Type: "legacy",
Endpoint: "1.2.3.4:53",
BootstrapIP: "1.2.3.4",
Domain: "1.2.3.4",
Timeout: 0,
IPStack: IpStackBoth,
},
},
{
"doh+doh3 with send client info set",
&UpstreamConfig{
Name: "doh",
Type: "doh",
Endpoint: "https://example.com?k=v",
BootstrapIP: "",
Domain: "",
Timeout: 0,
SendClientInfo: ptrBool(false),
IPStack: IpStackBoth,
},
&UpstreamConfig{
Name: "doh",
Type: "doh",
Endpoint: "https://example.com?k=v",
BootstrapIP: "",
Domain: "example.com",
Timeout: 0,
SendClientInfo: ptrBool(false),
IPStack: IpStackBoth,
u: u2,
},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init()
assert.Equal(t, tc.expected, tc.uc)
})
}
}
func ptrBool(b bool) *bool {
return &b
}

View File

@@ -5,40 +5,152 @@ package ctrld
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"sync"
"time"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
func (uc *UpstreamConfig) setupDOH3Transport() {
uc.setupDOH3TransportWithoutPingUpstream()
uc.pingUpstream()
go uc.pingUpstream()
}
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
rt := &http3.RoundTripper{}
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
domain := addr
_, port, _ := net.SplitHostPort(addr)
// if we have a bootstrap ip set, use it to avoid DNS lookup
if uc.BootstrapIP != "" {
addr = net.JoinHostPort(uc.BootstrapIP, port)
ProxyLog.Debug().Msgf("sending doh3 request to: %s", addr)
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
}
dialAddrs := make([]string, len(addrs))
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
pd := &quicParallelDialer{}
conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg)
if err != nil {
return nil, err
}
ProxyLog.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
return conn, err
}
return rt
}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
rt := &http3.RoundTripper{}
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
host := addr
ProxyLog.Debug().Msgf("debug dial context D0H3 %s - %s", addr, bootstrapDNS)
// if we have a bootstrap ip set, use it to avoid DNS lookup
if uc.BootstrapIP != "" {
if _, port, _ := net.SplitHostPort(addr); port != "" {
addr = net.JoinHostPort(uc.BootstrapIP, port)
}
ProxyLog.Debug().Msgf("sending doh3 request to: %s", addr)
}
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
uc.mu.Lock()
defer uc.mu.Unlock()
switch uc.IPStack {
case IpStackBoth, "":
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
case IpStackV4:
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
case IpStackV6:
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
case IpStackSplit:
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if ctrldnet.IPv6Available(ctx) {
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
} else {
uc.http3RoundTripper6 = uc.http3RoundTripper4
}
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
}
}
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
uc.mu.Lock()
defer uc.mu.Unlock()
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.http3RoundTripper
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.http3RoundTripper4
default:
return uc.http3RoundTripper6
}
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg)
}
return uc.http3RoundTripper
}
// Putting the code for quic parallel dialer here:
//
// - quic dialer is different with net.Dialer
// - simplification for quic free version
type parallelDialerResult struct {
conn quic.EarlyConnection
err error
}
type quicParallelDialer struct{}
func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
if len(addrs) == 0 {
return nil, errors.New("empty addresses")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
ch := make(chan *parallelDialerResult, len(addrs))
var wg sync.WaitGroup
wg.Add(len(addrs))
go func() {
wg.Wait()
close(ch)
}()
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
uc.http3RoundTripper = rt
for _, addr := range addrs {
go func(addr string) {
defer wg.Done()
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
ch <- &parallelDialerResult{conn: nil, err: err}
return
}
conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
ch <- &parallelDialerResult{conn: conn, err: err}
}(addr)
}
errs := make([]error, 0, len(addrs))
for res := range ch {
if res.err == nil {
cancel()
return res.conn, res.err
}
errs = append(errs, res.err)
}
return nil, errors.Join(errs...)
}

View File

@@ -2,6 +2,9 @@
package ctrld
import "net/http"
func (uc *UpstreamConfig) setupDOH3Transport() {}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {}
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { return nil }

View File

@@ -24,10 +24,12 @@ func TestLoadConfig(t *testing.T) {
assert.Contains(t, cfg.Network, "0")
assert.Contains(t, cfg.Network, "1")
assert.Len(t, cfg.Upstream, 3)
assert.Len(t, cfg.Upstream, 4)
assert.Contains(t, cfg.Upstream, "0")
assert.Contains(t, cfg.Upstream, "1")
assert.Contains(t, cfg.Upstream, "2")
assert.Contains(t, cfg.Upstream, "3")
assert.NotNil(t, cfg.Upstream["3"].SendClientInfo)
assert.Len(t, cfg.Listener, 2)
assert.Contains(t, cfg.Listener, "0")
@@ -42,6 +44,8 @@ func TestLoadConfig(t *testing.T) {
assert.Len(t, cfg.Listener["0"].Policy.Rules, 2)
assert.Contains(t, cfg.Listener["0"].Policy.Rules[0], "*.ru")
assert.Contains(t, cfg.Listener["0"].Policy.Rules[1], "*.local.host")
assert.True(t, cfg.HasUpstreamSendClientInfo())
}
func TestLoadDefaultConfig(t *testing.T) {
@@ -61,6 +65,7 @@ func TestConfigValidation(t *testing.T) {
{"invalid Config", &ctrld.Config{}, true},
{"default Config", defaultConfig(t), false},
{"sample Config", testhelper.SampleConfig(t), false},
{"empty listener IP", emptyListenerIP(t), false},
{"invalid cidr", invalidNetworkConfig(t), true},
{"invalid upstream type", invalidUpstreamType(t), true},
{"invalid upstream timeout", invalidUpstreamTimeout(t), true},
@@ -130,9 +135,15 @@ func invalidListenerIP(t *testing.T) *ctrld.Config {
return cfg
}
func emptyListenerIP(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Listener["0"].IP = ""
return cfg
}
func invalidListenerPort(t *testing.T) *ctrld.Config {
cfg := defaultConfig(t)
cfg.Listener["0"].Port = 0
cfg.Listener["0"].Port = -1
return cfg
}
@@ -165,116 +176,3 @@ func configWithInvalidRcodes(t *testing.T) *ctrld.Config {
}
return cfg
}
func TestUpstreamConfig_Init(t *testing.T) {
tests := []struct {
name string
uc *ctrld.UpstreamConfig
expected *ctrld.UpstreamConfig
}{
{
"doh+doh3",
&ctrld.UpstreamConfig{
Name: "doh",
Type: "doh",
Endpoint: "https://example.com",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&ctrld.UpstreamConfig{
Name: "doh",
Type: "doh",
Endpoint: "https://example.com",
BootstrapIP: "",
Domain: "example.com",
Timeout: 0,
},
},
{
"dot+doq",
&ctrld.UpstreamConfig{
Name: "dot",
Type: "dot",
Endpoint: "freedns.controld.com:8853",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&ctrld.UpstreamConfig{
Name: "dot",
Type: "dot",
Endpoint: "freedns.controld.com:8853",
BootstrapIP: "",
Domain: "freedns.controld.com",
Timeout: 0,
},
},
{
"dot+doq without port",
&ctrld.UpstreamConfig{
Name: "dot",
Type: "dot",
Endpoint: "freedns.controld.com",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&ctrld.UpstreamConfig{
Name: "dot",
Type: "dot",
Endpoint: "freedns.controld.com:853",
BootstrapIP: "",
Domain: "freedns.controld.com",
Timeout: 0,
},
},
{
"legacy",
&ctrld.UpstreamConfig{
Name: "legacy",
Type: "legacy",
Endpoint: "1.2.3.4:53",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&ctrld.UpstreamConfig{
Name: "legacy",
Type: "legacy",
Endpoint: "1.2.3.4:53",
BootstrapIP: "1.2.3.4",
Domain: "1.2.3.4",
Timeout: 0,
},
},
{
"legacy without port",
&ctrld.UpstreamConfig{
Name: "legacy",
Type: "legacy",
Endpoint: "1.2.3.4",
BootstrapIP: "",
Domain: "",
Timeout: 0,
},
&ctrld.UpstreamConfig{
Name: "legacy",
Type: "legacy",
Endpoint: "1.2.3.4:53",
BootstrapIP: "1.2.3.4",
Domain: "1.2.3.4",
Timeout: 0,
},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tc.uc.Init()
assert.Equal(t, tc.expected, tc.uc)
})
}
}

View File

@@ -227,9 +227,27 @@ Value `0` means no timeout.
The protocol that `ctrld` will use to send DNS requests to upstream.
- Type: string
- required: yes
- Required: yes
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os`
### ip_stack
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.
- Type: string
- Required: no
- Valid values:
- `both`: using either ipv4 or ipv6.
- `v4`: only dial upstream via IPv4, never dial IPv6.
- `v6`: only dial upstream via IPv6, never dial IPv4.
- `split`:
- If `A` record is requested -> dial via ipv4.
- If `AAAA` or any other record is requested -> dial ipv6 (if available, otherwise ipv4)
If `ip_stack` is empty, or undefined:
- Default value is `both` for non-Control D resolvers.
- Default value is `split` for Control D resolvers.
## Network
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.
@@ -271,16 +289,14 @@ The `[listener]` section specifies the ip and port of the local DNS server. You
```
### ip
IP address that serves the incoming requests.
IP address that serves the incoming requests. If `ip` is empty, ctrld will listen on all available addresses.
- Type: string
- Required: yes
- Type: ip address
### port
Port number that the listener will listen on for incoming requests.
Port number that the listener will listen on for incoming requests. If `port` is `0`, a random available port will be chosen.
- Type: number
- Required: yes
### restricted
If set to `true` makes the listener `REFUSE` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.

69
doh.go
View File

@@ -7,25 +7,36 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"github.com/miekg/dns"
)
const (
DoHMacHeader = "x-cd-mac"
DoHIPHeader = "x-cd-ip"
DoHHostHeader = "x-cd-host"
headerApplicationDNS = "application/dns-message"
)
func newDohResolver(uc *UpstreamConfig) *dohResolver {
r := &dohResolver{
endpoint: uc.Endpoint,
endpoint: uc.u,
isDoH3: uc.Type == ResolverTypeDOH3,
transport: uc.transport,
http3RoundTripper: uc.http3RoundTripper,
sendClientInfo: uc.UpstreamSendClientInfo(),
uc: uc,
}
return r
}
type dohResolver struct {
endpoint string
uc *UpstreamConfig
endpoint *url.URL
isDoH3 bool
transport *http.Transport
http3RoundTripper http.RoundTripper
sendClientInfo bool
}
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
@@ -33,26 +44,34 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
if err != nil {
return nil, err
}
enc := base64.RawURLEncoding.EncodeToString(data)
url := fmt.Sprintf("%s?dns=%s", r.endpoint, enc)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
query := r.endpoint.Query()
query.Add("dns", enc)
endpoint := *r.endpoint
endpoint.RawQuery = query.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
if err != nil {
return nil, fmt.Errorf("could not create request: %w", err)
}
req.Header.Set("Content-Type", "application/dns-message")
req.Header.Set("Accept", "application/dns-message")
c := http.Client{Transport: r.transport}
addHeader(ctx, req, r.sendClientInfo)
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
c := http.Client{Transport: r.uc.dohTransport(dnsTyp)}
if r.isDoH3 {
if r.http3RoundTripper == nil {
transport := r.uc.doh3Transport(dnsTyp)
if transport == nil {
return nil, errors.New("DoH3 is not supported")
}
c.Transport = r.http3RoundTripper
c.Transport = transport
}
resp, err := c.Do(req)
if err != nil {
if r.isDoH3 {
if closer, ok := r.http3RoundTripper.(io.Closer); ok {
if closer, ok := c.Transport.(io.Closer); ok {
closer.Close()
}
}
@@ -70,5 +89,27 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
}
answer := new(dns.Msg)
return answer, answer.Unpack(buf)
if err := answer.Unpack(buf); err != nil {
return nil, fmt.Errorf("answer.Unpack: %w", err)
}
return answer, nil
}
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
req.Header.Set("Content-Type", headerApplicationDNS)
req.Header.Set("Accept", headerApplicationDNS)
if sendClientInfo {
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
if ci.Mac != "" {
req.Header.Set(DoHMacHeader, ci.Mac)
}
if ci.IP != "" {
req.Header.Set(DoHIPHeader, ci.IP)
}
if ci.Hostname != "" {
req.Header.Set(DoHHostHeader, ci.Hostname)
}
}
}
Log(ctx, ProxyLog.Debug().Interface("header", req.Header), "sending request header")
}

14
doq.go
View File

@@ -20,11 +20,17 @@ type doqResolver struct {
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
endpoint := r.uc.Endpoint
tlsConfig := &tls.Config{NextProtos: []string{"doq"}}
if r.uc.BootstrapIP != "" {
tlsConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
ip := r.uc.BootstrapIP
if ip == "" {
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
ip = r.uc.bootstrapIPForDNSType(dnsTyp)
}
tlsConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(ip, port)
return resolve(ctx, msg, endpoint, tlsConfig)
}

15
dot.go
View File

@@ -14,18 +14,25 @@ type dotResolver struct {
func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// The dialer is used to prevent bootstrapping cycle.
// If r.endpoing is set to dns.controld.dev, we need to resolve
// If r.endpoint is set to dns.controld.dev, we need to resolve
// dns.controld.dev first. By using a dialer with custom resolver,
// we ensure that we can always resolve the bootstrap domain
// regardless of the machine DNS status.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: "tcp-tls",
Dialer: dialer,
Net: tcpNet,
Dialer: dialer,
TLSConfig: &tls.Config{RootCAs: r.uc.certPool},
}
endpoint := r.uc.Endpoint
if r.uc.BootstrapIP != "" {
dnsClient.TLSConfig = &tls.Config{ServerName: r.uc.Domain}
dnsClient.TLSConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
}

View File

@@ -1,43 +0,0 @@
package ctrld
// TODO(cuonglm): use stdlib once we bump minimum version to 1.20
func joinErrors(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}
type joinError struct {
errs []error
}
func (e *joinError) Error() string {
var b []byte
for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}
return string(b)
}
func (e *joinError) Unwrap() []error {
return e.errs
}

8
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/Control-D-Inc/ctrld
go 1.20
require (
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534
github.com/coreos/go-systemd/v22 v22.5.0
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19
github.com/frankban/quicktest v1.14.3
github.com/fsnotify/fsnotify v1.6.0
@@ -21,6 +21,7 @@ require (
github.com/spf13/cobra v1.4.0
github.com/spf13/viper v1.14.0
github.com/stretchr/testify v1.8.1
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
golang.org/x/net v0.7.0
golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0
@@ -65,6 +66,7 @@ require (
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.4.1 // indirect
github.com/u-root/uio v0.0.0-20221213070652-c3537552635f // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect
golang.org/x/crypto v0.6.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
@@ -75,3 +77,7 @@ require (
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace github.com/mr-karan/doggo => github.com/Windscribe/doggo v0.0.0-20220919152748-2c118fc391f8
replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be

17
go.sum
View File

@@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik=
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w=
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 h1:Kk6a4nehpJ3UuJRqlA3JxYxBZEqCeOmATOvrbT4p9RA=
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
@@ -50,8 +52,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534 h1:rtAn27wIbmOGUs7RIbVgPEjb31ehTVniDwPGXyMxm5U=
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs=
@@ -247,9 +249,7 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rogpeppe/go-internal v1.8.1-0.20211023094830-115ce09fd6b4 h1:Ha8xCaq6ln1a+R91Km45Oq6lPXj2Mla6CRJYcuV2h1w=
github.com/rogpeppe/go-internal v1.8.1-0.20211023094830-115ce09fd6b4/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
@@ -282,6 +282,11 @@ github.com/subosito/gotenv v1.4.1 h1:jyEFiXpy21Wm81FBN71l9VoMMV8H8jG+qIK3GCpY6Qs
github.com/subosito/gotenv v1.4.1/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
github.com/u-root/uio v0.0.0-20221213070652-c3537552635f h1:dpx1PHxYqAnXzbryJrWP1NQLzEjwcVgFLhkknuFQ7ww=
github.com/u-root/uio v0.0.0-20221213070652-c3537552635f/go.mod h1:IogEAUBXDEwX7oR/BMmCctShYs80ql4hF0ySdzGxf7E=
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 h1:8mhqcHPqTMhSPoslhGYihEgSfc77+7La1P6kiB6+9So=
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -429,6 +434,7 @@ golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -437,6 +443,7 @@ golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

3372
internal/certs/cacert.pem Normal file

File diff suppressed because it is too large Load Diff

22
internal/certs/root_ca.go Normal file
View File

@@ -0,0 +1,22 @@
package certs
import (
"crypto/x509"
_ "embed"
"sync"
)
var (
//go:embed cacert.pem
caRoots []byte
caCertPoolOnce sync.Once
caCertPool *x509.CertPool
)
func CACertPool() *x509.CertPool {
caCertPoolOnce.Do(func() {
caCertPool = x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caRoots)
})
return caCertPool
}

View File

@@ -0,0 +1,27 @@
package certs
import (
"crypto/tls"
"net/http"
"testing"
"time"
)
func TestCACertPool(t *testing.T) {
c := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: CACertPool(),
},
},
Timeout: 2 * time.Second,
}
resp, err := c.Get("https://freedns.controld.com/p1")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if !resp.TLS.HandshakeComplete {
t.Error("TLS handshake is not complete")
}
}

View File

@@ -3,17 +3,17 @@ package controld
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/miekg/dns"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/certs"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
"github.com/Control-D-Inc/ctrld/internal/router"
)
const (
@@ -22,14 +22,12 @@ const (
InvalidConfigCode = 40401
)
var (
resolveAPIDomainOnce sync.Once
apiDomainIP string
)
// ResolverConfig represents Control D resolver data.
type ResolverConfig struct {
DOH string `json:"doh"`
DOH string `json:"doh"`
Ctrld struct {
CustomConfig string `json:"custom_config"`
} `json:"ctrld"`
Exclude []string `json:"exclude"`
}
@@ -56,7 +54,7 @@ type utilityRequest struct {
}
// FetchResolverConfig fetch Control D config for given uid.
func FetchResolverConfig(uid string) (*ResolverConfig, error) {
func FetchResolverConfig(uid, version string) (*ResolverConfig, error) {
body, _ := json.Marshal(utilityRequest{UID: uid})
req, err := http.NewRequest("POST", resolverDataURL, bytes.NewReader(body))
if err != nil {
@@ -64,55 +62,28 @@ func FetchResolverConfig(uid string) (*ResolverConfig, error) {
}
q := req.URL.Query()
q.Set("platform", "ctrld")
q.Set("version", version)
req.URL.RawQuery = q.Encode()
req.Header.Add("Content-Type", "application/json")
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
// We experiment hanging in TLS handshake when connecting to ControlD API
// with ipv6. So prefer ipv4 if available.
proto := "tcp6"
if ctrldnet.SupportsIPv4() {
proto = "tcp4"
ips := ctrld.LookupIP(apiDomain)
if len(ips) == 0 {
ctrld.ProxyLog.Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr)
return ctrldnet.Dialer.DialContext(ctx, network, addr)
}
resolveAPIDomainOnce.Do(func() {
r, err := ctrld.NewResolver(&ctrld.UpstreamConfig{Type: ctrld.ResolverTypeOS})
if err != nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ctrld.ProxyLog.Debug().Msgf("API IPs: %v", ips)
_, port, _ := net.SplitHostPort(addr)
addrs := make([]string, len(ips))
for i := range ips {
addrs[i] = net.JoinHostPort(ips[i], port)
}
d := &ctrldnet.ParallelDialer{}
return d.DialContext(ctx, network, addrs)
}
msg := new(dns.Msg)
dnsType := dns.TypeAAAA
if proto == "tcp4" {
dnsType = dns.TypeA
}
msg.SetQuestion(apiDomain+".", dnsType)
msg.RecursionDesired = true
answer, err := r.Resolve(ctx, msg)
if err != nil {
return
}
if answer.Rcode != dns.RcodeSuccess || len(answer.Answer) == 0 {
return
}
for _, record := range answer.Answer {
switch ar := record.(type) {
case *dns.A:
apiDomainIP = ar.A.String()
return
case *dns.AAAA:
apiDomainIP = ar.AAAA.String()
return
}
}
})
if apiDomainIP != "" {
if _, port, _ := net.SplitHostPort(addr); port != "" {
return ctrldnet.Dialer.DialContext(ctx, proto, net.JoinHostPort(apiDomainIP, port))
}
}
return ctrldnet.Dialer.DialContext(ctx, proto, addr)
if router.Name() == router.DDWrt {
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
}
client := http.Client{
Timeout: 10 * time.Second,

View File

@@ -9,8 +9,6 @@ import (
"github.com/stretchr/testify/require"
)
const utilityURL = "https://api.controld.com/utility"
func TestFetchResolverConfig(t *testing.T) {
tests := []struct {
name string
@@ -24,7 +22,7 @@ func TestFetchResolverConfig(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := FetchResolverConfig(tc.uid)
got, err := FetchResolverConfig(tc.uid, "dev-test")
require.False(t, (err != nil) != tc.wantErr, err)
if !tc.wantErr {
assert.NotEmpty(t, got.DOH)

View File

@@ -2,6 +2,7 @@ package net
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
@@ -12,7 +13,6 @@ import (
const (
controldIPv6Test = "ipv6.controld.io"
controldIPv4Test = "ipv4.controld.io"
bootstrapDNS = "76.76.2.0:53"
)
@@ -37,8 +37,6 @@ var probeStackDialer = &net.Dialer{
var (
stackOnce atomic.Pointer[sync.Once]
ipv4Enabled bool
ipv6Enabled bool
canListenIPv6Local bool
hasNetworkUp bool
)
@@ -47,13 +45,8 @@ func init() {
stackOnce.Store(new(sync.Once))
}
func supportIPv4() bool {
_, err := probeStackDialer.Dial("tcp4", net.JoinHostPort(controldIPv4Test, "80"))
return err == nil
}
func supportIPv6(ctx context.Context) bool {
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "80"))
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "443"))
return err == nil
}
@@ -75,8 +68,6 @@ func probeStack() {
b.BackOff(context.Background(), err)
}
}
ipv4Enabled = supportIPv4()
ipv6Enabled = supportIPv6(context.Background())
canListenIPv6Local = supportListenIPv6Local()
}
@@ -85,16 +76,6 @@ func Up() bool {
return hasNetworkUp
}
func SupportsIPv4() bool {
stackOnce.Load().Do(probeStack)
return ipv4Enabled
}
func SupportsIPv6() bool {
stackOnce.Load().Do(probeStack)
return ipv6Enabled
}
func SupportsIPv6ListenLocal() bool {
stackOnce.Load().Do(probeStack)
return canListenIPv6Local
@@ -112,3 +93,47 @@ func IsIPv6(ip string) bool {
parsedIP := net.ParseIP(ip)
return parsedIP != nil && parsedIP.To4() == nil && parsedIP.To16() != nil
}
type parallelDialerResult struct {
conn net.Conn
err error
}
type ParallelDialer struct {
net.Dialer
}
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) {
if len(addrs) == 0 {
return nil, errors.New("empty addresses")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
ch := make(chan *parallelDialerResult, len(addrs))
var wg sync.WaitGroup
wg.Add(len(addrs))
go func() {
wg.Wait()
close(ch)
}()
for _, addr := range addrs {
go func(addr string) {
defer wg.Done()
conn, err := d.Dialer.DialContext(ctx, network, addr)
ch <- &parallelDialerResult{conn: conn, err: err}
}(addr)
}
errs := make([]error, 0, len(addrs))
for res := range ch {
if res.err == nil {
cancel()
return res.conn, res.err
}
errs = append(errs, res.err)
}
return nil, errors.Join(errs...)
}

View File

@@ -0,0 +1,121 @@
package router
import (
"bytes"
"io"
"log"
"net"
"os"
"strings"
"time"
"github.com/fsnotify/fsnotify"
"tailscale.com/util/lineread"
"github.com/Control-D-Inc/ctrld"
)
var clientInfoFiles = []string{
"/tmp/dnsmasq.leases", // ddwrt
"/tmp/dhcp.leases", // openwrt
"/var/lib/misc/dnsmasq.leases", // merlin
"/mnt/data/udapi-config/dnsmasq.lease", // UDM Pro
"/data/udapi-config/dnsmasq.lease", // UDR
}
func (r *router) watchClientInfoTable() {
if r.watcher == nil {
return
}
timer := time.NewTicker(time.Minute * 5)
for {
select {
case <-timer.C:
for _, name := range r.watcher.WatchList() {
_ = readClientInfoFile(name)
}
case event, ok := <-r.watcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) {
if err := readClientInfoFile(event.Name); err != nil && !os.IsNotExist(err) {
log.Println("could not read client info file:", err)
}
}
case err, ok := <-r.watcher.Errors:
if !ok {
return
}
log.Println("error:", err)
}
}
}
func Stop() error {
if Name() == "" {
return nil
}
r := routerPlatform.Load()
if r.watcher != nil {
if err := r.watcher.Close(); err != nil {
return err
}
}
return nil
}
func GetClientInfoByMac(mac string) *ctrld.ClientInfo {
if mac == "" {
return nil
}
_ = Name()
r := routerPlatform.Load()
val, ok := r.mac.Load(mac)
if !ok {
return nil
}
return val.(*ctrld.ClientInfo)
}
func readClientInfoFile(name string) error {
f, err := os.Open(name)
if err != nil {
return err
}
defer f.Close()
return readClientInfoReader(f)
}
func readClientInfoReader(reader io.Reader) error {
r := routerPlatform.Load()
return lineread.Reader(reader, func(line []byte) error {
fields := bytes.Fields(line)
if len(fields) != 5 {
return nil
}
mac := string(fields[1])
if _, err := net.ParseMAC(mac); err != nil {
// The second field is not a mac, skip.
return nil
}
ip := normalizeIP(string(fields[2]))
if net.ParseIP(ip) == nil {
log.Printf("invalid ip address entry: %q", ip)
ip = ""
}
hostname := string(fields[3])
r.mac.Store(mac, &ctrld.ClientInfo{Mac: mac, IP: ip, Hostname: hostname})
return nil
})
}
func normalizeIP(in string) string {
// dnsmasq may put ip with interface index in lease file, strip it here.
ip, _, found := strings.Cut(in, "%")
if found {
return ip
}
return in
}

View File

@@ -0,0 +1,70 @@
package router
import (
"strings"
"testing"
"github.com/Control-D-Inc/ctrld"
)
func Test_normalizeIP(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{"v4", "127.0.0.1", "127.0.0.1"},
{"v4 with index", "127.0.0.1%lo", "127.0.0.1"},
{"v6", "fe80::1", "fe80::1"},
{"v6 with index", "fe80::1%22002", "fe80::1"},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := normalizeIP(tc.in); got != tc.want {
t.Errorf("normalizeIP() = %v, want %v", got, tc.want)
}
})
}
}
func Test_readClientInfoReader(t *testing.T) {
tests := []struct {
name string
in string
mac string
}{
{
"good",
`1683329857 e6:20:59:b8:c1:6d 192.168.1.186 * 01:e6:20:59:b8:c1:6d
`,
"e6:20:59:b8:c1:6d",
},
{
"bad seen on UDMdream machine",
`1683329857 e6:20:59:b8:c1:6e 192.168.1.111 * 01:e6:20:59:b8:c1:6e
duid 00:01:00:01:2b:e4:2e:2c:52:52:14:26:dc:1c
1683322985 117442354 2600:4040:b0e6:b700::111 ASDASD 00:01:00:01:2a:d0:b9:81:00:07:32:4c:1c:07
`,
"e6:20:59:b8:c1:6e",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := routerPlatform.Load()
r.mac.Delete(tc.mac)
if err := readClientInfoReader(strings.NewReader(tc.in)); err != nil {
t.Errorf("readClientInfoReader() error = %v", err)
}
info, existed := r.mac.Load(tc.mac)
if !existed {
t.Error("client info missing")
}
if ci, ok := info.(*ctrld.ClientInfo); ok && existed && ci.Mac != tc.mac {
t.Errorf("mac mismatched, got: %q, want: %q", ci.Mac, tc.mac)
}
})
}
}

71
internal/router/ddwrt.go Normal file
View File

@@ -0,0 +1,71 @@
package router
import (
"errors"
"fmt"
"os/exec"
)
const (
nvramCtrldKeyPrefix = "ctrld_"
nvramCtrldSetupKey = "ctrld_setup"
nvramRCStartupKey = "rc_startup"
)
//lint:ignore ST1005 This error is for human.
var errDdwrtJffs2NotEnabled = errors.New(`could not install service without jffs, follow this guide to enable:
https://wiki.dd-wrt.com/wiki/index.php/Journalling_Flash_File_System
`)
func setupDDWrt() error {
// Already setup.
if val, _ := nvram("get", nvramCtrldSetupKey); val == "1" {
return nil
}
data, err := dnsMasqConf()
if err != nil {
return err
}
nvramKvMap := nvramKV()
nvramKvMap["dnsmasq_options"] = data
if err := nvramSetup(nvramKvMap); err != nil {
return err
}
// Restart dnsmasq service.
if err := ddwrtRestartDNSMasq(); err != nil {
return err
}
return nil
}
func cleanupDDWrt() error {
// Restore old configs.
if err := nvramRestore(nvramKV()); err != nil {
return err
}
// Restart dnsmasq service.
if err := ddwrtRestartDNSMasq(); err != nil {
return err
}
return nil
}
func postInstallDDWrt() error {
return nil
}
func ddwrtRestartDNSMasq() error {
if out, err := exec.Command("restart_dns").CombinedOutput(); err != nil {
return fmt.Errorf("restart_dns: %s, %w", string(out), err)
}
return nil
}
func ddwrtJff2Enabled() bool {
out, _ := nvram("get", "enable_jffs2")
return out == "1"
}

View File

@@ -0,0 +1,67 @@
package router
import (
"strings"
"text/template"
)
const dnsMasqConfigContentTmpl = `# GENERATED BY ctrld - DO NOT MODIFY
no-resolv
server=127.0.0.1#5354
{{- if .SendClientInfo}}
add-mac
{{- end}}
`
const merlinDNSMasqPostConfPath = "/jffs/scripts/dnsmasq.postconf"
const merlinDNSMasqPostConfMarker = `# GENERATED BY ctrld - EOF`
const merlinDNSMasqPostConfTmpl = `# GENERATED BY ctrld - DO NOT MODIFY
#!/bin/sh
config_file="$1"
. /usr/sbin/helper.sh
pid=$(cat /tmp/ctrld.pid 2>/dev/null)
if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then
pc_delete "servers-file" "$config_file" # no WAN DNS settings
pc_append "no-resolv" "$config_file" # do not read /etc/resolv.conf
pc_append "server=127.0.0.1#5354" "$config_file" # use ctrld as upstream
{{- if .SendClientInfo}}
pc_append "add-mac" "$config_file" # add client mac
{{- end}}
pc_delete "dnssec" "$config_file" # disable DNSSEC
pc_delete "trust-anchor=" "$config_file" # disable DNSSEC
# For John fork
pc_delete "resolv-file" "$config_file" # no WAN DNS settings
# Change /etc/resolv.conf, which may be changed by WAN DNS setup
pc_delete "nameserver" /etc/resolv.conf
pc_append "nameserver 127.0.0.1" /etc/resolv.conf
exit 0
fi
`
func dnsMasqConf() (string, error) {
var sb strings.Builder
var tmplText string
switch Name() {
case DDWrt, OpenWrt, Ubios:
tmplText = dnsMasqConfigContentTmpl
case Merlin:
tmplText = merlinDNSMasqPostConfTmpl
}
tmpl := template.Must(template.New("").Parse(tmplText))
var to = &struct {
SendClientInfo bool
}{
routerPlatform.Load().sendClientInfo,
}
if err := tmpl.Execute(&sb, to); err != nil {
return "", err
}
return sb.String(), nil
}

89
internal/router/merlin.go Normal file
View File

@@ -0,0 +1,89 @@
package router
import (
"bytes"
"fmt"
"os"
"os/exec"
"strings"
"unicode"
)
func setupMerlin() error {
buf, err := os.ReadFile(merlinDNSMasqPostConfPath)
// Already setup.
if bytes.Contains(buf, []byte(merlinDNSMasqPostConfMarker)) {
return nil
}
if err != nil && !os.IsNotExist(err) {
return err
}
merlinDNSMasqPostConf, err := dnsMasqConf()
if err != nil {
return err
}
data := strings.Join([]string{
merlinDNSMasqPostConf,
"\n",
merlinDNSMasqPostConfMarker,
"\n",
string(buf),
}, "\n")
// Write dnsmasq post conf file.
if err := os.WriteFile(merlinDNSMasqPostConfPath, []byte(data), 0750); err != nil {
return err
}
// Restart dnsmasq service.
if err := merlinRestartDNSMasq(); err != nil {
return err
}
if err := nvramSetup(nvramKV()); err != nil {
return err
}
return nil
}
func cleanupMerlin() error {
// Restore old configs.
if err := nvramRestore(nvramKV()); err != nil {
return err
}
buf, err := os.ReadFile(merlinDNSMasqPostConfPath)
if err != nil && !os.IsNotExist(err) {
return err
}
// Restore dnsmasq post conf file.
if err := os.WriteFile(merlinDNSMasqPostConfPath, merlinParsePostConf(buf), 0750); err != nil {
return err
}
// Restart dnsmasq service.
if err := merlinRestartDNSMasq(); err != nil {
return err
}
return nil
}
func postInstallMerlin() error {
return nil
}
func merlinRestartDNSMasq() error {
if out, err := exec.Command("service", "restart_dnsmasq").CombinedOutput(); err != nil {
return fmt.Errorf("restart_dnsmasq: %s, %w", string(out), err)
}
return nil
}
func merlinParsePostConf(buf []byte) []byte {
if len(buf) == 0 {
return nil
}
parts := bytes.Split(buf, []byte(merlinDNSMasqPostConfMarker))
if len(parts) != 1 {
return bytes.TrimLeftFunc(parts[1], unicode.IsSpace)
}
return buf
}

View File

@@ -0,0 +1,38 @@
package router
import (
"bytes"
"strings"
"testing"
)
func Test_merlinParsePostConf(t *testing.T) {
origContent := "# foo"
data := strings.Join([]string{
merlinDNSMasqPostConfTmpl,
"\n",
merlinDNSMasqPostConfMarker,
"\n",
}, "\n")
tests := []struct {
name string
data string
expected string
}{
{"empty", "", ""},
{"no ctrld", origContent, origContent},
{"ctrld with data", data + origContent, origContent},
{"ctrld without data", data, ""},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
//t.Parallel()
if got := merlinParsePostConf([]byte(tc.data)); !bytes.Equal(got, []byte(tc.expected)) {
t.Errorf("unexpected result, want: %q, got: %q", tc.expected, string(got))
}
})
}
}

93
internal/router/nvram.go Normal file
View File

@@ -0,0 +1,93 @@
package router
import (
"bytes"
"fmt"
"os/exec"
"strings"
)
func nvram(args ...string) (string, error) {
cmd := exec.Command("nvram", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("%s:%w", stderr.String(), err)
}
return strings.TrimSpace(stdout.String()), nil
}
/*
NOTE:
- For Openwrt, DNSSEC is not included in default dnsmasq (require dnsmasq-full).
- For Merlin, DNSSEC is configured during postconf script (see merlinDNSMasqPostConfTmpl).
- For Ubios UDM Pro/Dream Machine, DNSSEC is not included in their dnsmasq package:
+https://community.ui.com/questions/Implement-DNSSEC-into-UniFi/951c72b0-4d88-4c86-9174-45417bd2f9ca
+https://community.ui.com/questions/Enable-DNSSEC-for-Unifi-Dream-Machine-FW-updates/e68e367c-d09b-4459-9444-18908f7c1ea1
*/
func nvramKV() map[string]string {
switch Name() {
case DDWrt:
return map[string]string{
"dns_dnsmasq": "1", // Make dnsmasq running but disable DNS ability, ctrld will replace it.
"dnsmasq_options": "", // Configuration of dnsmasq set by ctrld, filled by setupDDWrt.
"dns_crypt": "0", // Disable DNSCrypt.
"dnssec": "0", // Disable DNSSEC.
}
case Merlin:
return map[string]string{
"dnspriv_enable": "0", // Ensure Merlin native DoT disabled.
}
}
return nil
}
func nvramSetup(m map[string]string) error {
// Backup current value, store ctrld's configs.
for key, value := range m {
old, err := nvram("get", key)
if err != nil {
return fmt.Errorf("%s: %w", old, err)
}
if out, err := nvram("set", nvramCtrldKeyPrefix+key+"="+old); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
if out, err := nvram("set", key+"="+value); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
}
if out, err := nvram("set", nvramCtrldSetupKey+"=1"); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
// Commit.
if out, err := nvram("commit"); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
return nil
}
func nvramRestore(m map[string]string) error {
// Restore old configs.
for key := range m {
ctrldKey := nvramCtrldKeyPrefix + key
old, err := nvram("get", ctrldKey)
if err != nil {
return fmt.Errorf("%s: %w", old, err)
}
_, _ = nvram("unset", ctrldKey)
if out, err := nvram("set", key+"="+old); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
}
if out, err := nvram("unset", "ctrld_setup"); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
// Commit.
if out, err := nvram("commit"); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
return nil
}

View File

@@ -0,0 +1,85 @@
package router
import (
"bytes"
"errors"
"fmt"
"os"
"os/exec"
"strings"
)
var errUCIEntryNotFound = errors.New("uci: Entry not found")
const openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
// IsGLiNet reports whether the router is an GL.iNet router.
func IsGLiNet() bool {
if Name() != OpenWrt {
return false
}
buf, _ := os.ReadFile("/proc/version")
// The output of /proc/version contains "(glinet@glinet)".
return bytes.Contains(buf, []byte(" (glinet"))
}
func setupOpenWrt() error {
// Delete dnsmasq port if set.
if _, err := uci("delete", "dhcp.@dnsmasq[0].port"); err != nil && !errors.Is(err, errUCIEntryNotFound) {
return err
}
// Disable dnsmasq as DNS server.
dnsMasqConfigContent, err := dnsMasqConf()
if err != nil {
return err
}
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(dnsMasqConfigContent), 0600); err != nil {
return err
}
// Commit.
if _, err := uci("commit"); err != nil {
return err
}
// Restart dnsmasq service.
if err := openwrtRestartDNSMasq(); err != nil {
return err
}
return nil
}
func cleanupOpenWrt() error {
// Remove the custom dnsmasq config
if err := os.Remove(openwrtDNSMasqConfigPath); err != nil {
return err
}
// Restart dnsmasq service.
if err := openwrtRestartDNSMasq(); err != nil {
return err
}
return nil
}
func postInstallOpenWrt() error {
return exec.Command("/etc/init.d/ctrld", "enable").Run()
}
func uci(args ...string) (string, error) {
cmd := exec.Command("uci", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if strings.HasPrefix(stderr.String(), errUCIEntryNotFound.Error()) {
return "", errUCIEntryNotFound
}
return "", fmt.Errorf("%s:%w", stderr.String(), err)
}
return strings.TrimSpace(stdout.String()), nil
}
func openwrtRestartDNSMasq() error {
if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil {
return fmt.Errorf("%s: %w", string(out), err)
}
return nil
}

24
internal/router/procd.go Normal file
View File

@@ -0,0 +1,24 @@
package router
const openWrtScript = `#!/bin/sh /etc/rc.common
USE_PROCD=1
# After network starts
START=21
# Before network stops
STOP=89
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
name="{{.Name}}"
pid_file="/var/run/${name}.pid"
start_service() {
echo "Starting ${name}"
procd_open_instance
procd_set_param command ${cmd}
procd_set_param respawn # respawn automatically if something died
procd_set_param stdout 1 # forward stdout of the command to logd
procd_set_param stderr 1 # same for stderr
procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop
procd_close_instance
echo "${name} has been started"
}
`

222
internal/router/router.go Normal file
View File

@@ -0,0 +1,222 @@
package router
import (
"bytes"
"context"
"errors"
"fmt"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"
"github.com/fsnotify/fsnotify"
"github.com/kardianos/service"
"tailscale.com/logtail/backoff"
"github.com/Control-D-Inc/ctrld"
)
const (
OpenWrt = "openwrt"
DDWrt = "ddwrt"
Merlin = "merlin"
Ubios = "ubios"
)
// ErrNotSupported reports the current router is not supported error.
var ErrNotSupported = errors.New("unsupported platform")
var routerPlatform atomic.Pointer[router]
type router struct {
name string
sendClientInfo bool
mac sync.Map
watcher *fsnotify.Watcher
}
// SupportedPlatforms return all platforms that can be configured to run with ctrld.
func SupportedPlatforms() []string {
return []string{DDWrt, Merlin, OpenWrt, Ubios}
}
var configureFunc = map[string]func() error{
DDWrt: setupDDWrt,
Merlin: setupMerlin,
OpenWrt: setupOpenWrt,
Ubios: setupUbiOS,
}
// Configure configures things for running ctrld on the router.
func Configure(c *ctrld.Config) error {
name := Name()
switch name {
case DDWrt, Merlin, OpenWrt, Ubios:
if c.HasUpstreamSendClientInfo() {
r := routerPlatform.Load()
r.sendClientInfo = true
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
r.watcher = watcher
go r.watchClientInfoTable()
for _, file := range clientInfoFiles {
_ = readClientInfoFile(file)
_ = r.watcher.Add(file)
}
}
configure := configureFunc[name]
if err := configure(); err != nil {
return err
}
return nil
default:
return ErrNotSupported
}
}
// ConfigureService performs necessary setup for running ctrld as a service on router.
func ConfigureService(sc *service.Config) error {
name := Name()
switch name {
case DDWrt:
if !ddwrtJff2Enabled() {
return errDdwrtJffs2NotEnabled
}
case OpenWrt:
sc.Option["SysvScript"] = openWrtScript
case Merlin, Ubios:
}
return nil
}
// PreStart blocks until the router is ready for running ctrld.
func PreStart() (err error) {
if Name() != DDWrt {
return nil
}
pidFile := "/tmp/ctrld.pid"
// On Merlin, NTP may out of sync, so waiting for it to be ready.
//
// Remove pid file and trigger dnsmasq restart, so NTP can resolve
// server name and perform time synchronization.
pid, err := os.ReadFile(pidFile)
if err != nil {
return fmt.Errorf("PreStart: os.Readfile: %w", err)
}
if err := os.Remove(pidFile); err != nil {
return fmt.Errorf("PreStart: os.Remove: %w", err)
}
defer func() {
if werr := os.WriteFile(pidFile, pid, 0600); werr != nil {
err = errors.Join(err, werr)
return
}
if rerr := merlinRestartDNSMasq(); rerr != nil {
err = errors.Join(err, rerr)
return
}
}()
if err := merlinRestartDNSMasq(); err != nil {
return fmt.Errorf("PreStart: merlinRestartDNSMasq: %w", err)
}
// Wait until `ntp_read=1` set.
b := backoff.NewBackoff("PreStart", func(format string, args ...any) {}, 10*time.Second)
for {
out, err := nvram("get", "ntp_ready")
if err != nil {
return fmt.Errorf("PreStart: nvram: %w", err)
}
if out == "1" {
return nil
}
b.BackOff(context.Background(), errors.New("ntp not ready"))
}
}
// PostInstall performs task after installing ctrld on router.
func PostInstall() error {
name := Name()
switch name {
case DDWrt:
return postInstallDDWrt()
case Merlin:
return postInstallMerlin()
case OpenWrt:
return postInstallOpenWrt()
case Ubios:
return postInstallUbiOS()
}
return nil
}
// Cleanup cleans ctrld setup on the router.
func Cleanup() error {
name := Name()
switch name {
case DDWrt:
return cleanupDDWrt()
case Merlin:
return cleanupMerlin()
case OpenWrt:
return cleanupOpenWrt()
case Ubios:
return cleanupUbiOS()
}
return nil
}
// ListenAddress returns the listener address of ctrld on router.
func ListenAddress() string {
name := Name()
switch name {
case DDWrt, Merlin, OpenWrt, Ubios:
return "127.0.0.1:5354"
}
return ""
}
// Name returns name of the router platform.
func Name() string {
if r := routerPlatform.Load(); r != nil {
return r.name
}
r := &router{}
r.name = distroName()
routerPlatform.Store(r)
return r.name
}
func distroName() string {
switch {
case bytes.HasPrefix(uname(), []byte("DD-WRT")):
return DDWrt
case bytes.HasPrefix(uname(), []byte("ASUSWRT-Merlin")):
return Merlin
case haveFile("/etc/openwrt_version"):
return OpenWrt
case haveDir("/data/unifi"):
return Ubios
}
return ""
}
func haveFile(file string) bool {
_, err := os.Stat(file)
return err == nil
}
func haveDir(dir string) bool {
fi, _ := os.Stat(dir)
return fi != nil && fi.IsDir()
}
func uname() []byte {
out, _ := exec.Command("uname", "-o").Output()
return out
}

View File

@@ -0,0 +1,82 @@
package router
import (
"bytes"
"os"
"os/exec"
"github.com/kardianos/service"
)
func init() {
systems := []service.System{
&linuxSystemService{
name: "ddwrt",
detect: func() bool { return Name() == DDWrt },
interactive: func() bool {
is, _ := isInteractive()
return is
},
new: newddwrtService,
},
&linuxSystemService{
name: "merlin",
detect: func() bool { return Name() == Merlin },
interactive: func() bool {
is, _ := isInteractive()
return is
},
new: newMerlinService,
},
&linuxSystemService{
name: "ubios",
detect: func() bool {
if Name() != Ubios {
return false
}
out, err := exec.Command("ubnt-device-info", "firmware").CombinedOutput()
if err == nil {
// For v2/v3, UbiOS use a Debian base with systemd, so it is not
// necessary to use custom implementation for supporting init system.
return bytes.HasPrefix(out, []byte("1."))
}
return true
},
interactive: func() bool {
is, _ := isInteractive()
return is
},
new: newUbiosService,
},
}
systems = append(systems, service.AvailableSystems()...)
service.ChooseSystem(systems...)
}
type linuxSystemService struct {
name string
detect func() bool
interactive func() bool
new func(i service.Interface, platform string, c *service.Config) (service.Service, error)
}
func (sc linuxSystemService) String() string {
return sc.name
}
func (sc linuxSystemService) Detect() bool {
return sc.detect()
}
func (sc linuxSystemService) Interactive() bool {
return sc.interactive()
}
func (sc linuxSystemService) New(i service.Interface, c *service.Config) (service.Service, error) {
return sc.new(i, sc.String(), c)
}
func isInteractive() (bool, error) {
ppid := os.Getppid()
if ppid == 1 {
return false, nil
}
return true, nil
}

View File

@@ -0,0 +1,292 @@
package router
import (
"bytes"
"errors"
"fmt"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
"text/template"
"github.com/kardianos/service"
)
type ddwrtSvc struct {
i service.Interface
platform string
*service.Config
rcStartup string
}
func newddwrtService(i service.Interface, platform string, c *service.Config) (service.Service, error) {
s := &ddwrtSvc{
i: i,
platform: platform,
Config: c,
}
if err := os.MkdirAll("/jffs/etc/config", 0644); err != nil {
return nil, err
}
return s, nil
}
func (s *ddwrtSvc) String() string {
if len(s.DisplayName) > 0 {
return s.DisplayName
}
return s.Name
}
func (s *ddwrtSvc) Platform() string {
return s.platform
}
func (s *ddwrtSvc) configPath() string {
return fmt.Sprintf("/jffs/etc/config/%s.startup", s.Config.Name)
}
func (s *ddwrtSvc) template() *template.Template {
return template.Must(template.New("").Parse(ddwrtSvcScript))
}
func (s *ddwrtSvc) Install() error {
confPath := s.configPath()
if _, err := os.Stat(confPath); err == nil {
return fmt.Errorf("already installed: %s", confPath)
}
path, err := os.Executable()
if err != nil {
return err
}
if !strings.HasPrefix(path, "/jffs/") {
return errors.New("could not install service outside /jffs")
}
var to = &struct {
*service.Config
Path string
}{
s.Config,
path,
}
f, err := os.Create(confPath)
if err != nil {
return err
}
defer f.Close()
if err := s.template().Execute(f, to); err != nil {
return err
}
if err = os.Chmod(confPath, 0755); err != nil {
return err
}
var sb strings.Builder
if err := template.Must(template.New("").Parse(ddwrtStartupCmd)).Execute(&sb, to); err != nil {
return err
}
s.rcStartup = sb.String()
curVal, err := nvram("get", nvramRCStartupKey)
if err != nil {
return err
}
if _, err := nvram("set", nvramCtrldKeyPrefix+nvramRCStartupKey+"="+curVal); err != nil {
return err
}
val := strings.Join([]string{curVal, s.rcStartup + " &", fmt.Sprintf(`echo $! > "/tmp/%s.pid"`, s.Config.Name)}, "\n")
if _, err := nvram("set", nvramRCStartupKey+"="+val); err != nil {
return err
}
if out, err := nvram("commit"); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
return nil
}
func (s *ddwrtSvc) Uninstall() error {
if err := os.Remove(s.configPath()); err != nil {
return err
}
ctrldStartupKey := nvramCtrldKeyPrefix + nvramRCStartupKey
rcStartup, err := nvram("get", ctrldStartupKey)
if err != nil {
return err
}
_, _ = nvram("unset", ctrldStartupKey)
if _, err := nvram("set", nvramRCStartupKey+"="+rcStartup); err != nil {
return err
}
if out, err := nvram("commit"); err != nil {
return fmt.Errorf("%s: %w", out, err)
}
return nil
}
func (s *ddwrtSvc) Logger(errs chan<- error) (service.Logger, error) {
if service.Interactive() {
return service.ConsoleLogger, nil
}
return s.SystemLogger(errs)
}
func (s *ddwrtSvc) SystemLogger(errs chan<- error) (service.Logger, error) {
// TODO(cuonglm): detect syslog enable and return proper logger?
// this at least works with default configuration.
if service.Interactive() {
return service.ConsoleLogger, nil
}
return &noopLogger{}, nil
}
func (s *ddwrtSvc) Run() (err error) {
err = s.i.Start(s)
if err != nil {
return err
}
if interactice, _ := isInteractive(); !interactice {
signal.Ignore(syscall.SIGHUP)
}
var sigChan = make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt)
<-sigChan
return s.i.Stop(s)
}
func (s *ddwrtSvc) Status() (service.Status, error) {
if _, err := os.Stat(s.configPath()); os.IsNotExist(err) {
return service.StatusUnknown, service.ErrNotInstalled
}
out, err := exec.Command(s.configPath(), "status").CombinedOutput()
if err != nil {
return service.StatusUnknown, err
}
switch string(bytes.TrimSpace(out)) {
case "running":
return service.StatusRunning, nil
default:
return service.StatusStopped, nil
}
}
func (s *ddwrtSvc) Start() error {
return exec.Command(s.configPath(), "start").Run()
}
func (s *ddwrtSvc) Stop() error {
return exec.Command(s.configPath(), "stop").Run()
}
func (s *ddwrtSvc) Restart() error {
err := s.Stop()
if err != nil {
return err
}
return s.Start()
}
type noopLogger struct {
}
func (c noopLogger) Error(v ...interface{}) error {
return nil
}
func (c noopLogger) Warning(v ...interface{}) error {
return nil
}
func (c noopLogger) Info(v ...interface{}) error {
return nil
}
func (c noopLogger) Errorf(format string, a ...interface{}) error {
return nil
}
func (c noopLogger) Warningf(format string, a ...interface{}) error {
return nil
}
func (c noopLogger) Infof(format string, a ...interface{}) error {
return nil
}
const ddwrtStartupCmd = `{{.Path}}{{range .Arguments}} {{.}}{{end}}`
const ddwrtSvcScript = `#!/bin/sh
name="{{.Name}}"
cmd="{{.Path}}{{range .Arguments}} {{.}}{{end}}"
pid_file="/tmp/$name.pid"
get_pid() {
cat "$pid_file"
}
is_running() {
[ -f "$pid_file" ] && ps | grep -q "^ *$(get_pid) "
}
case "$1" in
start)
if is_running; then
echo "Already started"
else
echo "Starting $name"
$cmd &
echo $! > "$pid_file"
chmod 600 "$pid_file"
if ! is_running; then
echo "Failed to start $name"
exit 1
fi
fi
;;
stop)
if is_running; then
echo -n "Stopping $name..."
kill "$(get_pid)"
for _ in 1 2 3 4 5; do
if ! is_running; then
echo "stopped"
if [ -f "$pid_file" ]; then
rm "$pid_file"
fi
exit 0
fi
printf "."
sleep 2
done
echo "failed to stop $name"
exit 1
fi
exit 1
;;
restart)
$0 stop
$0 start
;;
status)
if is_running; then
echo "running"
else
echo "stopped"
exit 1
fi
;;
*)
echo "Usage: $0 {start|stop|restart|status}"
exit 1
;;
esac
exit 0
`

View File

@@ -0,0 +1,354 @@
package router
import (
"bytes"
"errors"
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"syscall"
"text/template"
"github.com/kardianos/service"
)
const (
merlinJFFSScriptPath = "/jffs/scripts/services-start"
merlinJFFSServiceEventScriptPath = "/jffs/scripts/service-event"
)
type merlinSvc struct {
i service.Interface
platform string
*service.Config
}
func newMerlinService(i service.Interface, platform string, c *service.Config) (service.Service, error) {
s := &merlinSvc{
i: i,
platform: platform,
Config: c,
}
return s, nil
}
func (s *merlinSvc) String() string {
if len(s.DisplayName) > 0 {
return s.DisplayName
}
return s.Name
}
func (s *merlinSvc) Platform() string {
return s.platform
}
func (s *merlinSvc) configPath() string {
path, err := os.Executable()
if err != nil {
return ""
}
return path + ".startup"
}
func (s *merlinSvc) template() *template.Template {
return template.Must(template.New("").Parse(merlinSvcScript))
}
func (s *merlinSvc) Install() error {
exePath, err := os.Executable()
if err != nil {
return err
}
if !strings.HasPrefix(exePath, "/jffs/") {
return errors.New("could not install service outside /jffs")
}
if _, err := nvram("set", "jffs2_scripts=1"); err != nil {
return err
}
if _, err := nvram("commit"); err != nil {
return err
}
confPath := s.configPath()
if _, err := os.Stat(confPath); err == nil {
return fmt.Errorf("already installed: %s", confPath)
}
var to = &struct {
*service.Config
Path string
}{
s.Config,
exePath,
}
f, err := os.Create(confPath)
if err != nil {
return fmt.Errorf("os.Create: %w", err)
}
defer f.Close()
if err := s.template().Execute(f, to); err != nil {
return fmt.Errorf("s.template.Execute: %w", err)
}
if err = os.Chmod(confPath, 0755); err != nil {
return fmt.Errorf("os.Chmod: startup script: %w", err)
}
if err := os.MkdirAll(filepath.Dir(merlinJFFSScriptPath), 0755); err != nil {
return fmt.Errorf("os.MkdirAll: %w", err)
}
tmpScript, err := os.CreateTemp("", "ctrld_install")
if err != nil {
return fmt.Errorf("os.CreateTemp: %w", err)
}
defer os.Remove(tmpScript.Name())
defer tmpScript.Close()
if _, err := tmpScript.WriteString(merlinAddLineToScript); err != nil {
return fmt.Errorf("tmpScript.WriteString: %w", err)
}
if err := tmpScript.Close(); err != nil {
return fmt.Errorf("tmpScript.Close: %w", err)
}
addLineToScript := func(line, script string) error {
if _, err := os.Stat(script); os.IsNotExist(err) {
if err := os.WriteFile(script, []byte("#!/bin/sh\n"), 0755); err != nil {
return err
}
}
if err := os.Chmod(script, 0755); err != nil {
return fmt.Errorf("os.Chmod: jffs script: %w", err)
}
if err := exec.Command("sh", tmpScript.Name(), line, script).Run(); err != nil {
return fmt.Errorf("exec.Command: add startup script: %w", err)
}
return nil
}
for script, line := range map[string]string{
merlinJFFSScriptPath: s.configPath() + " start",
merlinJFFSServiceEventScriptPath: s.configPath() + ` service_event "$1" "$2"`,
} {
if err := addLineToScript(line, script); err != nil {
return err
}
}
return nil
}
func (s *merlinSvc) Uninstall() error {
if err := os.Remove(s.configPath()); err != nil {
return fmt.Errorf("os.Remove: %w", err)
}
tmpScript, err := os.CreateTemp("", "ctrld_uninstall")
if err != nil {
return fmt.Errorf("os.CreateTemp: %w", err)
}
defer os.Remove(tmpScript.Name())
defer tmpScript.Close()
if _, err := tmpScript.WriteString(merlinRemoveLineFromScript); err != nil {
return fmt.Errorf("tmpScript.WriteString: %w", err)
}
if err := tmpScript.Close(); err != nil {
return fmt.Errorf("tmpScript.Close: %w", err)
}
removeLineFromScript := func(line, script string) error {
if _, err := os.Stat(script); os.IsNotExist(err) {
if err := os.WriteFile(script, []byte("#!/bin/sh\n"), 0755); err != nil {
return err
}
}
if err := os.Chmod(script, 0755); err != nil {
return fmt.Errorf("os.Chmod: jffs script: %w", err)
}
if err := exec.Command("sh", tmpScript.Name(), line, script).Run(); err != nil {
return fmt.Errorf("exec.Command: add startup script: %w", err)
}
return nil
}
for script, line := range map[string]string{
merlinJFFSScriptPath: s.configPath() + " start",
merlinJFFSServiceEventScriptPath: s.configPath() + ` service_event "$1" "$2"`,
} {
if err := removeLineFromScript(line, script); err != nil {
return err
}
}
return nil
}
func (s *merlinSvc) Logger(errs chan<- error) (service.Logger, error) {
if service.Interactive() {
return service.ConsoleLogger, nil
}
return s.SystemLogger(errs)
}
func (s *merlinSvc) SystemLogger(errs chan<- error) (service.Logger, error) {
return newSysLogger(s.Name, errs)
}
func (s *merlinSvc) Run() (err error) {
err = s.i.Start(s)
if err != nil {
return err
}
if interactice, _ := isInteractive(); !interactice {
signal.Ignore(syscall.SIGHUP)
}
var sigChan = make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt)
<-sigChan
return s.i.Stop(s)
}
func (s *merlinSvc) Status() (service.Status, error) {
if _, err := os.Stat(s.configPath()); os.IsNotExist(err) {
return service.StatusUnknown, service.ErrNotInstalled
}
out, err := exec.Command(s.configPath(), "status").CombinedOutput()
if err != nil {
return service.StatusUnknown, err
}
switch string(bytes.TrimSpace(out)) {
case "running":
return service.StatusRunning, nil
default:
return service.StatusStopped, nil
}
}
func (s *merlinSvc) Start() error {
return exec.Command(s.configPath(), "start").Run()
}
func (s *merlinSvc) Stop() error {
return exec.Command(s.configPath(), "stop").Run()
}
func (s *merlinSvc) Restart() error {
err := s.Stop()
if err != nil {
return err
}
return s.Start()
}
const merlinSvcScript = `#!/bin/sh
name="{{.Name}}"
cmd="{{.Path}}{{range .Arguments}} {{.}}{{end}}"
pid_file="/tmp/$name.pid"
get_pid() {
cat "$pid_file"
}
is_running() {
[ -f "$pid_file" ] && ps | grep -q "^ *$(get_pid) "
}
case "$1" in
start)
if is_running; then
logger -c "Already started"
else
logger -c "Starting $name"
if [ -f /rom/ca-bundle.crt ]; then
# For Johns fork
export SSL_CERT_FILE=/rom/ca-bundle.crt
fi
$cmd &
echo $! > "$pid_file"
chmod 600 "$pid_file"
if ! is_running; then
logger -c "Failed to start $name"
exit 1
fi
fi
;;
stop)
if is_running; then
logger -c "Stopping $name..."
kill "$(get_pid)"
for _ in 1 2 3 4 5; do
if ! is_running; then
logger -c "stopped"
if [ -f "$pid_file" ]; then
rm "$pid_file"
fi
exit 0
fi
printf "."
sleep 2
done
logger -c "failed to stop $name"
exit 1
fi
exit 1
;;
restart)
$0 stop
$0 start
;;
status)
if is_running; then
echo "running"
else
echo "stopped"
exit 1
fi
;;
service_event)
event=$2
svc=$3
dnsmasq_pid_file=$(sed -n '/pid-file=/s///p' /etc/dnsmasq.conf)
if [ "$event" = "restart" ] && [ "$svc" = "diskmon" ]; then
kill "$(cat "$dnsmasq_pid_file")" >/dev/null 2>&1
fi
;;
*)
echo "Usage: $0 {start|stop|restart|status}"
exit 1
;;
esac
exit 0
`
const merlinAddLineToScript = `#!/bin/sh
line=$1
file=$2
. /usr/sbin/helper.sh
pc_append "$line" "$file"
`
const merlinRemoveLineFromScript = `#!/bin/sh
line=$1
file=$2
. /usr/sbin/helper.sh
pc_delete "$line" "$file"
`

View File

@@ -0,0 +1,336 @@
package router
import (
"bytes"
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"syscall"
"text/template"
"time"
"github.com/kardianos/service"
)
// This is a copy of https://github.com/kardianos/service/blob/v1.2.1/service_sysv_linux.go,
// with modification for supporting ubios v1 init system.
type ubiosSvc struct {
i service.Interface
platform string
*service.Config
}
func newUbiosService(i service.Interface, platform string, c *service.Config) (service.Service, error) {
s := &ubiosSvc{
i: i,
platform: platform,
Config: c,
}
return s, nil
}
func (s *ubiosSvc) String() string {
if len(s.DisplayName) > 0 {
return s.DisplayName
}
return s.Name
}
func (s *ubiosSvc) Platform() string {
return s.platform
}
func (s *ubiosSvc) configPath() string {
return "/etc/init.d/" + s.Config.Name
}
func (s *ubiosSvc) execPath() (string, error) {
if len(s.Executable) != 0 {
return filepath.Abs(s.Executable)
}
return os.Executable()
}
func (s *ubiosSvc) template() *template.Template {
return template.Must(template.New("").Funcs(tf).Parse(ubiosSvcScript))
}
func (s *ubiosSvc) Install() error {
confPath := s.configPath()
if _, err := os.Stat(confPath); err == nil {
return fmt.Errorf("init already exists: %s", confPath)
}
f, err := os.Create(confPath)
if err != nil {
return fmt.Errorf("failed to create config path: %w", err)
}
defer f.Close()
path, err := s.execPath()
if err != nil {
return fmt.Errorf("failed to get exec path: %w", err)
}
var to = &struct {
*service.Config
Path string
DnsMasqConfPath string
}{
s.Config,
path,
ubiosDNSMasqConfigPath,
}
if err := s.template().Execute(f, to); err != nil {
return fmt.Errorf("failed to create init script: %w", err)
}
if err := f.Close(); err != nil {
return fmt.Errorf("failed to save init script: %w", err)
}
if err = os.Chmod(confPath, 0755); err != nil {
return fmt.Errorf("failed to set init script executable: %w", err)
}
// Enable on boot
script, err := os.CreateTemp("", "ctrld_boot.service")
if err != nil {
return fmt.Errorf("failed to create boot service tmp file: %w", err)
}
defer script.Close()
svcConfig := *to.Config
svcConfig.Arguments = os.Args[1:]
to.Config = &svcConfig
if err := template.Must(template.New("").Funcs(tf).Parse(ubiosBootSystemdService)).Execute(script, &to); err != nil {
return fmt.Errorf("failed to create boot service file: %w", err)
}
if err := script.Close(); err != nil {
return fmt.Errorf("failed to save boot service file: %w", err)
}
// Copy the boot script to container and start.
cmd := exec.Command("podman", "cp", "--pause=false", script.Name(), "unifi-os:/lib/systemd/system/ctrld-boot.service")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to copy boot script, out: %s, err: %v", string(out), err)
}
cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "enable", "--now", "ctrld-boot.service")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to start ctrld boot script, out: %s, err: %v", string(out), err)
}
return nil
}
func (s *ubiosSvc) Uninstall() error {
if err := os.Remove(s.configPath()); err != nil {
return err
}
// Remove ctrld-boot service inside unifi-os container.
cmd := exec.Command("podman", "exec", "unifi-os", "systemctl", "disable", "ctrld-boot.service")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to disable ctrld-boot service, out: %s, err: %v", string(out), err)
}
cmd = exec.Command("podman", "exec", "unifi-os", "rm", "/lib/systemd/system/ctrld-boot.service")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to remove ctrld-boot service file, out: %s, err: %v", string(out), err)
}
cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "daemon-reload")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to reload systemd service, out: %s, err: %v", string(out), err)
}
cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "reset-failed")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("failed to reset-failed systemd service, out: %s, err: %v", string(out), err)
}
return nil
}
func (s *ubiosSvc) Logger(errs chan<- error) (service.Logger, error) {
if service.Interactive() {
return service.ConsoleLogger, nil
}
return s.SystemLogger(errs)
}
func (s *ubiosSvc) SystemLogger(errs chan<- error) (service.Logger, error) {
return newSysLogger(s.Name, errs)
}
func (s *ubiosSvc) Run() (err error) {
err = s.i.Start(s)
if err != nil {
return err
}
if interactice, _ := isInteractive(); !interactice {
signal.Ignore(syscall.SIGHUP)
}
var sigChan = make(chan os.Signal, 3)
signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt)
<-sigChan
return s.i.Stop(s)
}
func (s *ubiosSvc) Status() (service.Status, error) {
if _, err := os.Stat(s.configPath()); os.IsNotExist(err) {
return service.StatusUnknown, service.ErrNotInstalled
}
out, err := exec.Command(s.configPath(), "status").CombinedOutput()
if err != nil {
return service.StatusUnknown, err
}
switch string(bytes.TrimSpace(out)) {
case "Running":
return service.StatusRunning, nil
default:
return service.StatusStopped, nil
}
}
func (s *ubiosSvc) Start() error {
return exec.Command(s.configPath(), "start").Run()
}
func (s *ubiosSvc) Stop() error {
return exec.Command(s.configPath(), "stop").Run()
}
func (s *ubiosSvc) Restart() error {
err := s.Stop()
if err != nil {
return err
}
time.Sleep(50 * time.Millisecond)
return s.Start()
}
const ubiosBootSystemdService = `[Unit]
Description=Run ctrld On Startup UDM
Wants=network-online.target
After=network-online.target
StartLimitIntervalSec=500
StartLimitBurst=5
[Service]
Restart=on-failure
RestartSec=5s
ExecStart=/sbin/ssh-proxy '[ -f "{{.DnsMasqConfPath}}" ] || {{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}'
RemainAfterExit=true
[Install]
WantedBy=multi-user.target
`
const ubiosSvcScript = `#!/bin/sh
# For RedHat and cousins:
# chkconfig: - 99 01
# description: {{.Description}}
# processname: {{.Path}}
### BEGIN INIT INFO
# Provides: {{.Path}}
# Required-Start:
# Required-Stop:
# Default-Start: 2 3 4 5
# Default-Stop: 0 1 6
# Short-Description: {{.DisplayName}}
# Description: {{.Description}}
### END INIT INFO
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
name=$(basename $(readlink -f $0))
pid_file="/var/run/$name.pid"
stdout_log="/var/log/$name.log"
stderr_log="/var/log/$name.err"
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
get_pid() {
cat "$pid_file"
}
is_running() {
[ -f "$pid_file" ] && cat /proc/$(get_pid)/stat > /dev/null 2>&1
}
case "$1" in
start)
if is_running; then
echo "Already started"
else
echo "Starting $name"
{{if .WorkingDirectory}}cd '{{.WorkingDirectory}}'{{end}}
$cmd >> "$stdout_log" 2>> "$stderr_log" &
echo $! > "$pid_file"
if ! is_running; then
echo "Unable to start, see $stdout_log and $stderr_log"
exit 1
fi
fi
;;
stop)
if is_running; then
echo -n "Stopping $name.."
kill $(get_pid)
for i in $(seq 1 10)
do
if ! is_running; then
break
fi
echo -n "."
sleep 1
done
echo
if is_running; then
echo "Not stopped; may still be shutting down or shutdown may have failed"
exit 1
else
echo "Stopped"
if [ -f "$pid_file" ]; then
rm "$pid_file"
fi
fi
else
echo "Not running"
fi
;;
restart)
$0 stop
if is_running; then
echo "Unable to stop, will not attempt to start"
exit 1
fi
$0 start
;;
status)
if is_running; then
echo "Running"
else
echo "Stopped"
exit 1
fi
;;
*)
echo "Usage: $0 {start|stop|restart|status}"
exit 1
;;
esac
exit 0
`
var tf = map[string]interface{}{
"cmd": func(s string) string {
return `"` + strings.Replace(s, `"`, `\"`, -1) + `"`
},
"cmdEscape": func(s string) string {
return strings.Replace(s, " ", `\x20`, -1)
},
}

49
internal/router/syslog.go Normal file
View File

@@ -0,0 +1,49 @@
//go:build linux || darwin || freebsd
package router
import (
"fmt"
"log/syslog"
"github.com/kardianos/service"
)
func newSysLogger(name string, errs chan<- error) (service.Logger, error) {
w, err := syslog.New(syslog.LOG_INFO, name)
if err != nil {
return nil, err
}
return sysLogger{w, errs}, nil
}
type sysLogger struct {
*syslog.Writer
errs chan<- error
}
func (s sysLogger) send(err error) error {
if err != nil && s.errs != nil {
s.errs <- err
}
return err
}
func (s sysLogger) Error(v ...interface{}) error {
return s.send(s.Writer.Err(fmt.Sprint(v...)))
}
func (s sysLogger) Warning(v ...interface{}) error {
return s.send(s.Writer.Warning(fmt.Sprint(v...)))
}
func (s sysLogger) Info(v ...interface{}) error {
return s.send(s.Writer.Info(fmt.Sprint(v...)))
}
func (s sysLogger) Errorf(format string, a ...interface{}) error {
return s.send(s.Writer.Err(fmt.Sprintf(format, a...)))
}
func (s sysLogger) Warningf(format string, a ...interface{}) error {
return s.send(s.Writer.Warning(fmt.Sprintf(format, a...)))
}
func (s sysLogger) Infof(format string, a ...interface{}) error {
return s.send(s.Writer.Info(fmt.Sprintf(format, a...)))
}

View File

@@ -0,0 +1,7 @@
package router
import "github.com/kardianos/service"
func newSysLogger(name string, errs chan<- error) (service.Logger, error) {
return service.ConsoleLogger, nil
}

59
internal/router/ubios.go Normal file
View File

@@ -0,0 +1,59 @@
package router
import (
"bytes"
"os"
"strconv"
)
const (
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
)
func setupUbiOS() error {
// Disable dnsmasq as DNS server.
dnsMasqConfigContent, err := dnsMasqConf()
if err != nil {
return err
}
if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(dnsMasqConfigContent), 0600); err != nil {
return err
}
// Restart dnsmasq service.
if err := ubiosRestartDNSMasq(); err != nil {
return err
}
return nil
}
func cleanupUbiOS() error {
// Remove the custom dnsmasq config
if err := os.Remove(ubiosDNSMasqConfigPath); err != nil {
return err
}
// Restart dnsmasq service.
if err := ubiosRestartDNSMasq(); err != nil {
return err
}
return nil
}
func postInstallUbiOS() error {
return nil
}
func ubiosRestartDNSMasq() error {
buf, err := os.ReadFile("/run/dnsmasq.pid")
if err != nil {
return err
}
pid, err := strconv.ParseUint(string(bytes.TrimSpace(buf)), 10, 64)
if err != nil {
return err
}
proc, err := os.FindProcess(int(pid))
if err != nil {
return err
}
return proc.Kill()
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"sync"
"time"
"github.com/miekg/dns"
)
@@ -33,7 +34,7 @@ var errUnknownResolver = errors.New("unknown resolver")
// NewResolver creates a Resolver based on the given upstream config.
func NewResolver(uc *UpstreamConfig) (Resolver, error) {
typ, endpoint := uc.Type, uc.Endpoint
typ := uc.Type
switch typ {
case ResolverTypeDOH, ResolverTypeDOH3:
return newDohResolver(uc), nil
@@ -44,7 +45,7 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
case ResolverTypeOS:
return or, nil
case ResolverTypeLegacy:
return &legacyResolver{endpoint: endpoint}, nil
return &legacyResolver{uc: uc}, nil
}
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
}
@@ -79,7 +80,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
for _, server := range o.nameservers {
go func(server string) {
defer wg.Done()
answer, _, err := dnsClient.ExchangeContext(ctx, msg, server)
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
ch <- &osResolverResult{answer: answer, err: err}
}(server)
}
@@ -93,7 +94,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
errs = append(errs, res.err)
}
return nil, joinErrors(errs...)
return nil, errors.Join(errs...)
}
func newDialer(dnsAddress string) *net.Dialer {
@@ -109,16 +110,87 @@ func newDialer(dnsAddress string) *net.Dialer {
}
type legacyResolver struct {
endpoint string
uc *UpstreamConfig
}
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// See comment in (*dotResolver).resolve method.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dnsTyp := uint16(0)
if msg != nil && len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
_, udpNet := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: "udp",
Net: udpNet,
Dialer: dialer,
}
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint)
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.uc.Endpoint)
return answer, err
}
// LookupIP looks up host using OS resolver.
// It returns a slice of that host's IPv4 and IPv6 addresses.
func LookupIP(domain string) []string {
return lookupIP(domain, -1, true)
}
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
resolver := &osResolver{nameservers: nameservers()}
if withBootstrapDNS {
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
}
ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
timeoutMs := 2000
if timeout > 0 && timeout < timeoutMs {
timeoutMs = timeout
}
questionDomain := dns.Fqdn(domain)
ipFromRecord := func(record dns.RR) string {
switch ar := record.(type) {
case *dns.A:
if ar.Hdr.Name != questionDomain {
return ""
}
return ar.A.String()
case *dns.AAAA:
if ar.Hdr.Name != questionDomain {
return ""
}
return ar.AAAA.String()
}
return ""
}
lookup := func(dnsType uint16) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()
m := new(dns.Msg)
m.SetQuestion(questionDomain, dnsType)
m.RecursionDesired = true
r, err := resolver.Resolve(ctx, m)
if err != nil {
ProxyLog.Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain)
return
}
if r.Rcode != dns.RcodeSuccess {
ProxyLog.Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode])
return
}
if len(r.Answer) == 0 {
ProxyLog.Error().Msg("no answer from OS resolver")
return
}
for _, a := range r.Answer {
if ip := ipFromRecord(a); ip != "" {
ips = append(ips, ip)
}
}
}
// Find all A, AAAA records of the domain.
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
lookup(dnsType)
}
return ips
}

View File

@@ -50,6 +50,13 @@ type = "legacy"
endpoint = "8.8.8.8"
timeout = 5
[upstream.3]
name = "DOH with client info"
type = "doh"
endpoint = "https://dns.controld.com/client_info_upstream/main-device"
timeout = 5
send_client_info = false
[listener.0]
ip = "127.0.0.1"
port = 53