mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Instead of re-query DNS record for upstream when re-bootstrapping, just query all records on startup, then selecting the next bootstrap ip depends on the current network stack.
316 lines
11 KiB
Go
316 lines
11 KiB
Go
package ctrld
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/go-playground/validator/v10"
|
|
"github.com/miekg/dns"
|
|
"github.com/spf13/viper"
|
|
"golang.org/x/sync/singleflight"
|
|
|
|
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
|
|
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
|
)
|
|
|
|
// SetConfigName set the config name that ctrld will look for.
|
|
func SetConfigName(v *viper.Viper, name string) {
|
|
v.SetConfigName(name)
|
|
|
|
configPath := "$HOME"
|
|
// viper has its own way to get user home directory: https://github.com/spf13/viper/blob/v1.14.0/util.go#L134
|
|
// To be consistent, we prefer os.UserHomeDir instead.
|
|
if homeDir, err := os.UserHomeDir(); err == nil {
|
|
configPath = homeDir
|
|
}
|
|
v.AddConfigPath(configPath)
|
|
v.AddConfigPath(".")
|
|
}
|
|
|
|
// InitConfig initializes default config values for given *viper.Viper instance.
|
|
func InitConfig(v *viper.Viper, name string) {
|
|
SetConfigName(v, name)
|
|
|
|
v.SetDefault("listener", map[string]*ListenerConfig{
|
|
"0": {
|
|
IP: "127.0.0.1",
|
|
Port: 53,
|
|
},
|
|
})
|
|
v.SetDefault("network", map[string]*NetworkConfig{
|
|
"0": {
|
|
Name: "Network 0",
|
|
Cidrs: []string{"0.0.0.0/0"},
|
|
},
|
|
})
|
|
v.SetDefault("upstream", map[string]*UpstreamConfig{
|
|
"0": {
|
|
BootstrapIP: "76.76.2.11",
|
|
Name: "Control D - Anti-Malware",
|
|
Type: ResolverTypeDOH,
|
|
Endpoint: "https://freedns.controld.com/p1",
|
|
Timeout: 5000,
|
|
},
|
|
"1": {
|
|
BootstrapIP: "76.76.2.11",
|
|
Name: "Control D - No Ads",
|
|
Type: ResolverTypeDOQ,
|
|
Endpoint: "p2.freedns.controld.com",
|
|
Timeout: 3000,
|
|
},
|
|
})
|
|
}
|
|
|
|
// Config represents ctrld supported configuration.
|
|
type Config struct {
|
|
Service ServiceConfig `mapstructure:"service" toml:"service,omitempty"`
|
|
Listener map[string]*ListenerConfig `mapstructure:"listener" toml:"listener" validate:"min=1,dive"`
|
|
Network map[string]*NetworkConfig `mapstructure:"network" toml:"network" validate:"min=1,dive"`
|
|
Upstream map[string]*UpstreamConfig `mapstructure:"upstream" toml:"upstream" validate:"min=1,dive"`
|
|
}
|
|
|
|
// ServiceConfig specifies the general ctrld config.
|
|
type ServiceConfig struct {
|
|
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
|
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
|
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
|
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
|
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
|
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
|
Daemon bool `mapstructure:"-" toml:"-"`
|
|
AllocateIP bool `mapstructure:"-" toml:"-"`
|
|
}
|
|
|
|
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
|
|
type NetworkConfig struct {
|
|
Name string `mapstructure:"name" toml:"name,omitempty"`
|
|
Cidrs []string `mapstructure:"cidrs" toml:"cidrs,omitempty" validate:"dive,cidr"`
|
|
IPNets []*net.IPNet `mapstructure:"-" toml:"-"`
|
|
}
|
|
|
|
// UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to.
|
|
type UpstreamConfig struct {
|
|
Name string `mapstructure:"name" toml:"name,omitempty"`
|
|
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
|
|
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
|
|
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
|
|
Domain string `mapstructure:"-" toml:"-"`
|
|
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
|
|
transport *http.Transport `mapstructure:"-" toml:"-"`
|
|
http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"`
|
|
|
|
g singleflight.Group
|
|
bootstrapIPs []string
|
|
nextBootstrapIP atomic.Uint32
|
|
}
|
|
|
|
// ListenerConfig specifies the networks configuration that ctrld will run on.
|
|
type ListenerConfig struct {
|
|
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"ip"`
|
|
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gt=0"`
|
|
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
|
|
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
|
|
}
|
|
|
|
// ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests.
|
|
type ListenerPolicyConfig struct {
|
|
Name string `mapstructure:"name" toml:"name,omitempty"`
|
|
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
|
|
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
|
|
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
|
|
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
|
|
}
|
|
|
|
// Rule is a map from source to list of upstreams.
|
|
// ctrld uses rule to perform requests matching and forward
|
|
// the request to corresponding upstreams if it's matched.
|
|
type Rule map[string][]string
|
|
|
|
// Init initialized necessary values for an UpstreamConfig.
|
|
func (uc *UpstreamConfig) Init() {
|
|
if u, err := url.Parse(uc.Endpoint); err == nil {
|
|
uc.Domain = u.Host
|
|
}
|
|
if uc.Domain != "" {
|
|
return
|
|
}
|
|
|
|
if !strings.Contains(uc.Endpoint, ":") {
|
|
uc.Domain = uc.Endpoint
|
|
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
|
|
}
|
|
host, _, _ := net.SplitHostPort(uc.Endpoint)
|
|
uc.Domain = host
|
|
if net.ParseIP(uc.Domain) != nil {
|
|
uc.BootstrapIP = uc.Domain
|
|
}
|
|
}
|
|
|
|
// SetupBootstrapIP manually find all available IPs of the upstream.
|
|
// The first usable IP will be used as bootstrap IP of the upstream.
|
|
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(uc.Timeout)*time.Millisecond)
|
|
defer cancel()
|
|
|
|
c := new(dns.Client)
|
|
bootstrapIP := func(record dns.RR) string {
|
|
switch ar := record.(type) {
|
|
case *dns.A:
|
|
return ar.A.String()
|
|
case *dns.AAAA:
|
|
return ar.AAAA.String()
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// Find all A, AAAA records of the upstream.
|
|
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
|
|
m := new(dns.Msg)
|
|
m.SetQuestion(uc.Domain+".", dnsType)
|
|
m.RecursionDesired = true
|
|
r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(bootstrapDNS, "53"))
|
|
if err != nil {
|
|
ProxyLog.Error().Err(err).Str("type", dns.TypeToString[dnsType]).Msgf("could not resolve domain %s for upstream", uc.Domain)
|
|
continue
|
|
}
|
|
if r.Rcode != dns.RcodeSuccess {
|
|
ProxyLog.Error().Msgf("could not resolve domain return code: %d, upstream", r.Rcode)
|
|
continue
|
|
}
|
|
if len(r.Answer) == 0 {
|
|
ProxyLog.Error().Msg("no answer from bootstrap DNS server")
|
|
continue
|
|
}
|
|
for _, a := range r.Answer {
|
|
ip := bootstrapIP(a)
|
|
if ip == "" {
|
|
continue
|
|
}
|
|
|
|
// Storing the ip to uc.bootstrapIPs list, so it can be selected later
|
|
// when retrying failed request due to network stack changed.
|
|
uc.bootstrapIPs = append(uc.bootstrapIPs, ip)
|
|
if uc.BootstrapIP == "" {
|
|
// Remember what's the current IP in bootstrap IPs list,
|
|
// so we can select next one upon re-bootstrapping.
|
|
uc.nextBootstrapIP.Add(1)
|
|
|
|
// If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip.
|
|
if !ctrldnet.IPv6Available() && ctrldnet.IsIPv6(ip) {
|
|
continue
|
|
}
|
|
uc.BootstrapIP = ip
|
|
}
|
|
}
|
|
}
|
|
ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs)
|
|
}
|
|
|
|
// ReBootstrap re-setup the bootstrap IP and the transport.
|
|
func (uc *UpstreamConfig) ReBootstrap() {
|
|
_, _, _ = uc.g.Do("rebootstrap", func() (any, error) {
|
|
ProxyLog.Debug().Msg("re-bootstrapping upstream ip")
|
|
n := uint32(len(uc.bootstrapIPs))
|
|
// Only attempt n times, because if there's no usable ip,
|
|
// the bootstrap ip will be kept as-is.
|
|
for i := uint32(0); i < n; i++ {
|
|
// Select the next ip in bootstrap ip list.
|
|
next := uc.nextBootstrapIP.Add(1)
|
|
ip := uc.bootstrapIPs[(next-1)%n]
|
|
if !ctrldnet.IPv6Available() && ctrldnet.IsIPv6(ip) {
|
|
continue
|
|
}
|
|
uc.BootstrapIP = ip
|
|
break
|
|
}
|
|
uc.SetupTransport()
|
|
return true, nil
|
|
})
|
|
}
|
|
|
|
// SetupTransport initializes the network transport used to connect to upstream server.
|
|
// For now, only DoH upstream is supported.
|
|
func (uc *UpstreamConfig) SetupTransport() {
|
|
switch uc.Type {
|
|
case ResolverTypeDOH:
|
|
uc.setupDOHTransport()
|
|
case ResolverTypeDOH3:
|
|
uc.setupDOH3Transport()
|
|
}
|
|
}
|
|
|
|
func (uc *UpstreamConfig) setupDOHTransport() {
|
|
uc.transport = http.DefaultTransport.(*http.Transport).Clone()
|
|
uc.transport.IdleConnTimeout = 5 * time.Second
|
|
uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
dialer := &net.Dialer{
|
|
Timeout: 2 * time.Second,
|
|
KeepAlive: 2 * time.Second,
|
|
}
|
|
Log(ctx, ProxyLog.Debug(), "debug dial context %s - %s - %s", addr, network, bootstrapDNS)
|
|
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
|
if uc.BootstrapIP != "" {
|
|
if _, port, _ := net.SplitHostPort(addr); port != "" {
|
|
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
|
}
|
|
}
|
|
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr)
|
|
return dialer.DialContext(ctx, network, addr)
|
|
}
|
|
|
|
uc.pingUpstream()
|
|
}
|
|
|
|
func (uc *UpstreamConfig) pingUpstream() {
|
|
// Warming up the transport by querying a test packet.
|
|
dnsResolver, err := NewResolver(uc)
|
|
if err != nil {
|
|
ProxyLog.Error().Err(err).Msgf("failed to create resolver for upstream: %s", uc.Name)
|
|
return
|
|
}
|
|
msg := new(dns.Msg)
|
|
msg.SetQuestion(".", dns.TypeNS)
|
|
msg.MsgHdr.RecursionDesired = true
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_, _ = dnsResolver.Resolve(ctx, msg)
|
|
}
|
|
|
|
// Init initialized necessary values for an ListenerConfig.
|
|
func (lc *ListenerConfig) Init() {
|
|
if lc.Policy != nil {
|
|
lc.Policy.FailoverRcodeNumbers = make([]int, len(lc.Policy.FailoverRcodes))
|
|
for i, rcode := range lc.Policy.FailoverRcodes {
|
|
lc.Policy.FailoverRcodeNumbers[i] = dnsrcode.FromString(rcode)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ValidateConfig validates the given config.
|
|
func ValidateConfig(validate *validator.Validate, cfg *Config) error {
|
|
_ = validate.RegisterValidation("dnsrcode", validateDnsRcode)
|
|
return validate.Struct(cfg)
|
|
}
|
|
|
|
func validateDnsRcode(fl validator.FieldLevel) bool {
|
|
return dnsrcode.FromString(fl.Field().String()) != -1
|
|
}
|
|
|
|
func defaultPortFor(typ string) string {
|
|
switch typ {
|
|
case ResolverTypeDOH, ResolverTypeDOH3:
|
|
return "443"
|
|
case ResolverTypeDOQ, ResolverTypeDOT:
|
|
return "853"
|
|
case ResolverTypeLegacy:
|
|
return "53"
|
|
}
|
|
return "53"
|
|
}
|