Import code, preparing for release

This commit is contained in:
Cuong Manh Le
2022-12-12 21:24:20 +07:00
committed by Cuong Manh Le
parent cef3cc497e
commit 91d60d2a64
22 changed files with 2545 additions and 1 deletions

116
cmd/ctrld/cli.go Normal file
View File

@@ -0,0 +1,116 @@
package main
import (
"fmt"
"log"
"os"
"os/exec"
"runtime"
"github.com/kardianos/service"
"github.com/pelletier/go-toml"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var (
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
defaultConfigWritten = false
)
func initCLI() {
// Enable opening via explorer.exe on Windows.
// See: https://github.com/spf13/cobra/issues/844.
cobra.MousetrapHelpText = ""
rootCmd := &cobra.Command{
Use: "ctrld",
Short: "Running Control-D DNS proxy server",
Version: "1.0.0",
}
rootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "verbose log output")
runCmd := &cobra.Command{
Use: "run",
Short: "Run the DNS proxy server",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
if daemon && runtime.GOOS == "windows" {
log.Fatal("Cannot run in daemon mode. Please install a Windows service.")
}
if configPath != "" {
v.SetConfigFile(configPath)
}
if err := v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
writeConfigFile()
defaultConfigWritten = true
} else {
log.Fatalf("failed to decode config file: %v", err)
}
}
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
}
initLogging()
if daemon {
exe, err := os.Executable()
if err != nil {
mainLog.Error().Err(err).Msg("failed to find the binary")
os.Exit(1)
}
curDir, err := os.Getwd()
if err != nil {
mainLog.Error().Err(err).Msg("failed to get current working directory")
os.Exit(1)
}
// If running as daemon, re-run the command in background, with daemon off.
cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...)
cmd.Dir = curDir
if err := cmd.Start(); err != nil {
mainLog.Error().Err(err).Msg("failed to start process as daemon")
os.Exit(1)
}
mainLog.Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started")
os.Exit(0)
}
s, err := service.New(&prog{}, svcConfig)
if err != nil {
mainLog.Fatal().Err(err).Msg("failed create new service")
}
serviceLogger, err := s.Logger(nil)
if err != nil {
mainLog.Error().Err(err).Msg("failed to get service logger")
return
}
if err := s.Run(); err != nil {
if sErr := serviceLogger.Error(err); sErr != nil {
mainLog.Error().Err(sErr).Msg("failed to write service log")
}
mainLog.Error().Err(err).Msg("failed to start service")
}
},
}
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
rootCmd.AddCommand(runCmd)
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}
func writeConfigFile() {
c := v.AllSettings()
bs, err := toml.Marshal(c)
if err != nil {
log.Fatalf("unable to marshal config to toml: %v", err)
}
if err := os.WriteFile("config.toml", bs, 0600); err != nil {
log.Printf("failed to write config file: %v\n", err)
}
}

215
cmd/ctrld/dns_proxy.go Normal file
View File

@@ -0,0 +1,215 @@
package main
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/miekg/dns"
"github.com/Control-D-Inc/ctrld"
)
func (p *prog) serveUDP(listenerNum string) error {
listenerConfig := p.cfg.Listener[listenerNum]
// make sure ip is allocated
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
mainLog.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
return allocErr
}
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
domain := canonicalName(m.Question[0].Name)
reqId := requestID()
fmtSrcToDest := fmtRemoteToLocal(listenerNum, w.RemoteAddr().String(), w.LocalAddr().String())
t := time.Now()
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
ctrld.Log(ctx, proxyLog.Debug(), "%s received query: %s", fmtSrcToDest, domain)
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, w.RemoteAddr(), domain)
var answer *dns.Msg
if !matched && listenerConfig.Restricted {
answer = new(dns.Msg)
answer.SetRcode(m, dns.RcodeRefused)
} else {
answer = p.proxy(ctx, upstreams, m)
rtt := time.Since(t)
ctrld.Log(ctx, proxyLog.Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
}
if err := w.WriteMsg(answer); err != nil {
ctrld.Log(ctx, mainLog.Error().Err(err), "serveUDP: failed to send DNS response to client")
}
})
s := &dns.Server{
Addr: net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)),
Net: "udp",
Handler: handler,
}
return s.ListenAndServe()
}
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
upstreams := []string{"upstream." + defaultUpstreamNum}
matchedPolicy := "no policy"
matchedNetwork := "no network"
matchedRule := "no rule"
matched := false
defer func() {
if !matched && lc.Restricted {
ctrld.Log(ctx, proxyLog.Info(), "query refused, %s does not match any network policy", addr.String())
return
}
ctrld.Log(ctx, proxyLog.Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams)
}()
if lc.Policy == nil {
return upstreams, false
}
do := func(policyUpstreams []string) {
upstreams = append([]string(nil), policyUpstreams...)
}
for _, rule := range lc.Policy.Rules {
// There's only one entry per rule, config validation ensures this.
for source, targets := range rule {
if source == domain || wildcardMatches(source, domain) {
matchedPolicy = lc.Policy.Name
matchedRule = source
do(targets)
matched = true
return upstreams, matched
}
}
}
var sourceIP net.IP
switch addr := addr.(type) {
case *net.UDPAddr:
sourceIP = addr.IP
case *net.TCPAddr:
sourceIP = addr.IP
}
for _, rule := range lc.Policy.Networks {
for source, targets := range rule {
networkNum := strings.TrimPrefix(source, "network.")
nc := p.cfg.Network[networkNum]
if nc == nil {
continue
}
for _, ipNet := range nc.IPNets {
if ipNet.Contains(sourceIP) {
matchedPolicy = lc.Policy.Name
matchedNetwork = source
do(targets)
matched = true
return upstreams, matched
}
}
}
}
return upstreams, matched
}
func (p *prog) proxy(ctx context.Context, upstreams []string, msg *dns.Msg) *dns.Msg {
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
if err != nil {
ctrld.Log(ctx, proxyLog.Error().Err(err), "failed to create resolver")
return nil
}
if upstreamConfig.Timeout > 0 {
timeoutCtx, cancel := context.WithTimeout(ctx, time.Millisecond*time.Duration(upstreamConfig.Timeout))
defer cancel()
ctx = timeoutCtx
}
answer, err := dnsResolver.Resolve(ctx, msg)
if err != nil {
ctrld.Log(ctx, proxyLog.Error().Err(err), "failed to resolve query")
return nil
}
return answer
}
for n, upstreamConfig := range upstreamConfigs {
if answer := resolve(n, upstreamConfig, msg); answer != nil {
return answer
}
}
ctrld.Log(ctx, proxyLog.Error(), "all upstreams failed")
answer := new(dns.Msg)
answer.SetRcode(msg, dns.RcodeServerFailure)
return answer
}
// canonicalName returns canonical name from FQDN with "." trimmed.
func canonicalName(fqdn string) string {
q := strings.TrimSpace(fqdn)
q = strings.TrimSuffix(q, ".")
// https://datatracker.ietf.org/doc/html/rfc4343
q = strings.ToLower(q)
return q
}
func wildcardMatches(wildcard, domain string) bool {
// Wildcard match.
wildCardParts := strings.Split(wildcard, "*")
if len(wildCardParts) != 2 {
return false
}
switch {
case len(wildCardParts[0]) > 0 && len(wildCardParts[1]) > 0:
// Domain must match both prefix and suffix.
return strings.HasPrefix(domain, wildCardParts[0]) && strings.HasSuffix(domain, wildCardParts[1])
case len(wildCardParts[1]) > 0:
// Only suffix must match.
return strings.HasSuffix(domain, wildCardParts[1])
case len(wildCardParts[0]) > 0:
// Only prefix must match.
return strings.HasPrefix(domain, wildCardParts[0])
}
return false
}
func fmtRemoteToLocal(listenerNum, remote, local string) string {
return fmt.Sprintf("%s -> listener.%s: %s:", remote, listenerNum, local)
}
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
for _, upstream := range upstreams {
upstreamNum := strings.TrimPrefix(upstream, "upstream.")
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
}
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
}
return upstreamConfigs
}
func requestID() string {
b := make([]byte, 3) // 6 chars
if _, err := rand.Read(b); err != nil {
panic(err)
}
return hex.EncodeToString(b)
}
var osUpstreamConfig = &ctrld.UpstreamConfig{
Name: "OS resolver",
Type: "os",
}

117
cmd/ctrld/dns_proxy_test.go Normal file
View File

@@ -0,0 +1,117 @@
package main
import (
"context"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/testhelper"
)
func Test_wildcardMatches(t *testing.T) {
tests := []struct {
name string
wildcard string
domain string
match bool
}{
{"prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
{"prefix", "*.windscribe.com", "anything.windscribe.com", true},
{"prefix not match other domain", "*.windscribe.com", "example.com", false},
{"prefix not match domain in name", "*.windscribe.com", "wwindscribe.com", false},
{"suffix", "suffix.*", "suffix.windscribe.com", true},
{"suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
{"both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
{"both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
}
})
}
}
func Test_canonicalName(t *testing.T) {
tests := []struct {
name string
domain string
canonical string
}{
{"fqdn to canonical", "windscribe.com.", "windscribe.com"},
{"already canonical", "windscribe.com", "windscribe.com"},
{"case insensitive", "Windscribe.Com.", "windscribe.com"},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := canonicalName(tc.domain); got != tc.canonical {
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
}
})
}
}
func Test_prog_upstreamFor(t *testing.T) {
cfg := testhelper.SampleConfig(t)
prog := &prog{cfg: cfg}
for _, nc := range prog.cfg.Network {
for _, cidr := range nc.Cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
t.Fatal(err)
}
nc.IPNets = append(nc.IPNets, ipNet)
}
}
tests := []struct {
name string
ip string
defaultUpstreamNum string
lc *ctrld.ListenerConfig
domain string
upstreams []string
matched bool
}{
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true},
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true},
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true},
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
for _, network := range []string{"udp", "tcp"} {
var (
addr net.Addr
err error
)
switch network {
case "udp":
addr, err = net.ResolveUDPAddr(network, tc.ip)
case "tcp":
addr, err = net.ResolveTCPAddr(network, tc.ip)
}
require.NoError(t, err)
require.NotNil(t, addr)
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
assert.Equal(t, tc.matched, matched)
assert.Equal(t, tc.upstreams, upstreams)
}
})
}
}

65
cmd/ctrld/main.go Normal file
View File

@@ -0,0 +1,65 @@
package main
import (
"fmt"
"io"
"os"
"time"
"github.com/rs/zerolog"
"github.com/Control-D-Inc/ctrld"
)
var (
configPath string
daemon bool
cfg ctrld.Config
verbose bool
bootstrapDNS = "76.76.2.0"
rootLogger = zerolog.New(io.Discard)
mainLog = rootLogger
proxyLog = rootLogger
)
func main() {
ctrld.InitConfig(v, "config")
initCLI()
}
func initLogging() {
writers := []io.Writer{io.Discard}
isLog := cfg.Service.LogLevel != ""
if logPath := cfg.Service.LogPath; logPath != "" {
logFile, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to creating log file: %v", err)
os.Exit(1)
}
isLog = true
writers = append(writers, logFile)
}
zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs
if verbose || isLog {
consoleWriter := zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
w.TimeFormat = time.StampMilli
})
writers = append(writers, consoleWriter)
multi := zerolog.MultiLevelWriter(writers...)
mainLog = mainLog.Output(multi).With().Timestamp().Str("prefix", "main").Logger()
proxyLog = proxyLog.Output(multi).With().Timestamp().Logger()
// TODO: find a better way.
ctrld.ProxyLog = proxyLog
}
if cfg.Service.LogLevel == "" {
return
}
level, err := zerolog.ParseLevel(cfg.Service.LogLevel)
if err != nil {
mainLog.Warn().Err(err).Msg("could not set log level")
return
}
zerolog.SetGlobalLevel(level)
}

25
cmd/ctrld/os_linux.go Normal file
View File

@@ -0,0 +1,25 @@
package main
import (
"os/exec"
)
// allocate loopback ip
// sudo ip a add 127.0.0.2/24 dev lo
func allocateIP(ip string) error {
cmd := exec.Command("ip", "a", "add", ip+"/24", "dev", "lo")
if err := cmd.Run(); err != nil {
mainLog.Error().Err(err).Msg("allocateIP failed")
return err
}
return nil
}
func deAllocateIP(ip string) error {
cmd := exec.Command("ip", "a", "del", ip+"/24", "dev", "lo")
if err := cmd.Run(); err != nil {
mainLog.Error().Err(err).Msg("deAllocateIP failed")
return err
}
return nil
}

28
cmd/ctrld/os_mac.go Normal file
View File

@@ -0,0 +1,28 @@
//go:build darwin
// +build darwin
package main
import (
"os/exec"
)
// allocate loopback ip
// sudo ifconfig lo0 alias 127.0.0.2 up
func allocateIP(ip string) error {
cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up")
if err := cmd.Run(); err != nil {
mainLog.Error().Err(err).Msg("allocateIP failed")
return err
}
return nil
}
func deAllocateIP(ip string) error {
cmd := exec.Command("ifconfig", "lo0", "-alias", ip)
if err := cmd.Run(); err != nil {
mainLog.Error().Err(err).Msg("deAllocateIP failed")
return err
}
return nil
}

14
cmd/ctrld/os_windows.go Normal file
View File

@@ -0,0 +1,14 @@
//go:build windows
// +build windows
package main
// TODO(cuonglm): implement.
func allocateIP(ip string) error {
return nil
}
// TODO(cuonglm): implement.
func deAllocateIP(ip string) error {
return nil
}

159
cmd/ctrld/prog.go Normal file
View File

@@ -0,0 +1,159 @@
package main
import (
"errors"
"net"
"os"
"strconv"
"sync"
"syscall"
"github.com/kardianos/service"
"github.com/miekg/dns"
"github.com/Control-D-Inc/ctrld"
)
var errWindowsAddrInUse = syscall.Errno(0x2740)
var svcConfig = &service.Config{
Name: "ctrld",
DisplayName: "Control-D Helper Service",
}
type prog struct {
cfg *ctrld.Config
}
func (p *prog) Start(s service.Service) error {
p.cfg = &cfg
go p.run()
return nil
}
func (p *prog) run() {
var wg sync.WaitGroup
wg.Add(len(p.cfg.Listener))
for _, nc := range p.cfg.Network {
for _, cidr := range nc.Cidrs {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
proxyLog.Error().Err(err).Str("network", nc.Name).Str("cidr", cidr).Msg("invalid cidr")
continue
}
nc.IPNets = append(nc.IPNets, ipNet)
}
}
for n := range p.cfg.Upstream {
uc := p.cfg.Upstream[n]
uc.Init()
if uc.BootstrapIP == "" {
// resolve it manually and set the bootstrap ip
c := new(dns.Client)
m := new(dns.Msg)
m.SetQuestion(uc.Domain+".", dns.TypeA)
m.RecursionDesired = true
r, _, err := c.Exchange(m, net.JoinHostPort(bootstrapDNS, "53"))
if err != nil {
proxyLog.Error().Err(err).Msgf("could not resolve domain %s for upstream.%s", uc.Domain, n)
} else {
if r.Rcode != dns.RcodeSuccess {
proxyLog.Error().Msgf("could not resolve domain return code: %d, upstream.%s", r.Rcode, n)
} else {
for _, a := range r.Answer {
if ar, ok := a.(*dns.A); ok {
uc.BootstrapIP = ar.A.String()
proxyLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n)
}
}
}
}
}
}
for listenerNum := range p.cfg.Listener {
go func(listenerNum string) {
defer wg.Done()
listenerConfig := p.cfg.Listener[listenerNum]
upstreamConfig := p.cfg.Upstream[listenerNum]
if upstreamConfig == nil {
proxyLog.Error().Msgf("missing upstream config for: [listener.%s]", listenerNum)
return
}
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
proxyLog.Info().Msgf("Starting DNS server on listener.%s: %s", listenerNum, addr)
err := p.serveUDP(listenerNum)
if err != nil && !defaultConfigWritten {
proxyLog.Error().Err(err).Msgf("Unable to start dns proxy on listener.%s", listenerNum)
return
}
if opErr, ok := err.(*net.OpError); ok {
if sErr, ok := opErr.Err.(*os.SyscallError); ok && errors.Is(opErr.Err, syscall.EADDRINUSE) || errors.Is(sErr.Err, errWindowsAddrInUse) {
proxyLog.Warn().Msgf("Address %s already in used, pick a random one", addr)
pc, err := net.ListenPacket("udp", net.JoinHostPort(listenerConfig.IP, "0"))
if err != nil {
proxyLog.Error().Err(err).Msg("failed to listen packet")
return
}
_, portStr, _ := net.SplitHostPort(pc.LocalAddr().String())
port, err := strconv.Atoi(portStr)
if err != nil {
proxyLog.Error().Err(err).Msg("malformed port")
return
}
listenerConfig.Port = port
v.Set("listener", map[string]*ctrld.ListenerConfig{
"0": {
IP: "127.0.0.1",
Port: port,
},
})
writeConfigFile()
proxyLog.Info().Msgf("Starting DNS server on listener.%s: %s", listenerNum, pc.LocalAddr())
// There can be a race between closing the listener and start our own UDP server, but it's
// rare, and we only do this once, so let conservative here.
if err := pc.Close(); err != nil {
proxyLog.Error().Err(err).Msg("failed to close packet conn")
return
}
if err := p.serveUDP(listenerNum); err != nil {
proxyLog.Error().Err(err).Msgf("Unable to start dns proxy on listener.%s", listenerNum)
return
}
}
}
}(listenerNum)
}
wg.Wait()
}
func (p *prog) Stop(s service.Service) error {
if err := p.deAllocateIP(); err != nil {
mainLog.Error().Err(err).Msg("de-allocate ip failed")
return err
}
return nil
}
func (p *prog) allocateIP(ip string) error {
if !p.cfg.Service.AllocateIP {
return nil
}
return allocateIP(ip)
}
func (p *prog) deAllocateIP() error {
if !p.cfg.Service.AllocateIP {
return nil
}
for _, lc := range p.cfg.Listener {
if err := deAllocateIP(lc.IP); err != nil {
return err
}
}
return nil
}