all: implement split upstreams

This commit introduces split upstreams feature, allowing to configure
what ip stack that ctrld will use to connect to upstream.
This commit is contained in:
Cuong Manh Le
2023-04-28 01:12:59 +07:00
committed by Cuong Manh Le
parent 5cad0d6be1
commit b267572b38
10 changed files with 286 additions and 68 deletions

View File

@@ -242,7 +242,8 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
}
}
answer, err := resolve1(n, upstreamConfig, msg)
if err != nil {
// Only do re-bootstrapping if bootstrap ip is not explicitly set by user.
if err != nil && upstreamConfig.BootstrapIP == "" {
ctrld.Log(ctx, mainLog.Debug().Err(err), "could not resolve query on first attempt, retrying...")
// If any error occurred, re-bootstrap transport/ip, retry the request.
upstreamConfig.ReBootstrap()

206
config.go
View File

@@ -4,13 +4,13 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/go-playground/validator/v10"
@@ -22,6 +22,15 @@ import (
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
const (
IpStackBoth = "both"
IpStackV4 = "v4"
IpStackV6 = "v6"
IpStackSplit = "split"
)
var controldParentDomains = []string{"controld.com", "controld.net", "controld.dev"}
// SetConfigName set the config name that ctrld will look for.
// DEPRECATED: use SetConfigNameWithPath instead.
func SetConfigName(v *viper.Viper, name string) {
@@ -118,19 +127,25 @@ type UpstreamConfig struct {
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
Domain string `mapstructure:"-" toml:"-"`
IPStack string `mapstructure:"ip_stack" toml:"ip_stack,omitempty" validate:"ipstack"`
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
// The caller should not access this field directly.
// Use UpstreamSendClientInfo instead.
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
transport *http.Transport `mapstructure:"-" toml:"-"`
http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"`
certPool *x509.CertPool `mapstructure:"-" toml:"-"`
u *url.URL `mapstructure:"-" toml:"-"`
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
g singleflight.Group
mu sync.Mutex
bootstrapIPs []string
nextBootstrapIP atomic.Uint32
g singleflight.Group
mu sync.Mutex
bootstrapIPs []string
bootstrapIPs4 []string
bootstrapIPs6 []string
transport *http.Transport
transport4 *http.Transport
transport6 *http.Transport
http3RoundTripper http.RoundTripper
http3RoundTripper4 http.RoundTripper
http3RoundTripper6 http.RoundTripper
certPool *x509.CertPool
u *url.URL
}
// ListenerConfig specifies the networks configuration that ctrld will run on.
@@ -164,18 +179,23 @@ func (uc *UpstreamConfig) Init() {
uc.u = u
}
}
if uc.Domain != "" {
return
if uc.Domain == "" {
if !strings.Contains(uc.Endpoint, ":") {
uc.Domain = uc.Endpoint
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
}
host, _, _ := net.SplitHostPort(uc.Endpoint)
uc.Domain = host
if net.ParseIP(uc.Domain) != nil {
uc.BootstrapIP = uc.Domain
}
}
if !strings.Contains(uc.Endpoint, ":") {
uc.Domain = uc.Endpoint
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
}
host, _, _ := net.SplitHostPort(uc.Endpoint)
uc.Domain = host
if net.ParseIP(uc.Domain) != nil {
uc.BootstrapIP = uc.Domain
if uc.IPStack == "" {
if uc.isControlD() {
uc.IPStack = IpStackSplit
} else {
uc.IPStack = IpStackBoth
}
}
}
@@ -195,13 +215,8 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
}
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
if u, err := url.Parse(uc.Endpoint); err == nil {
domain := u.Hostname()
for _, parent := range []string{"controld.com", "controld.net"} {
if dns.IsSubDomain(parent, domain) {
return true
}
}
if uc.isControlD() {
return true
}
}
return false
@@ -226,6 +241,13 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
for _, ip := range uc.bootstrapIPs {
if ctrldnet.IsIPv6(ip) {
uc.bootstrapIPs6 = append(uc.bootstrapIPs6, ip)
} else {
uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip)
}
}
ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs)
}
@@ -238,7 +260,6 @@ func (uc *UpstreamConfig) ReBootstrap() {
}
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
ProxyLog.Debug().Msg("re-bootstrapping upstream ip")
uc.BootstrapIP = ""
uc.setupTransportWithoutPingUpstream()
return true, nil
})
@@ -269,19 +290,17 @@ func (uc *UpstreamConfig) setupDOHTransport() {
uc.pingUpstream()
}
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
uc.transport = http.DefaultTransport.(*http.Transport).Clone()
uc.transport.IdleConnTimeout = 5 * time.Second
uc.transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.IdleConnTimeout = 5 * time.Second
transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
dialerTimeoutMs := 2000
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
dialerTimeoutMs = uc.Timeout
}
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, _ := net.SplitHostPort(addr)
if uc.BootstrapIP != "" {
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
@@ -292,17 +311,42 @@ func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
pd := &ctrldnet.ParallelDialer{}
pd.Timeout = dialerTimeout
pd.KeepAlive = dialerTimeout
addrs := make([]string, len(uc.bootstrapIPs))
for i := range uc.bootstrapIPs {
addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port)
dialAddrs := make([]string, len(addrs))
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
conn, err := pd.DialContext(ctx, network, addrs)
conn, err := pd.DialContext(ctx, network, dialAddrs)
if err != nil {
return nil, err
}
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", conn.RemoteAddr())
return conn, nil
}
return transport
}
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
switch uc.IPStack {
case IpStackBoth, "":
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
case IpStackV4:
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
case IpStackV6:
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
case IpStackSplit:
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if ctrldnet.IPv6Available(ctx) {
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
} else {
uc.transport6 = uc.transport4
}
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
}
}
func (uc *UpstreamConfig) pingUpstream() {
@@ -320,6 +364,74 @@ func (uc *UpstreamConfig) pingUpstream() {
_, _ = dnsResolver.Resolve(ctx, msg)
}
func (uc *UpstreamConfig) isControlD() bool {
domain := uc.Domain
if domain == "" {
if u, err := url.Parse(uc.Endpoint); err == nil {
domain = u.Hostname()
}
}
for _, parent := range controldParentDomains {
if dns.IsSubDomain(parent, domain) {
return true
}
}
return false
}
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.transport
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.transport4
default:
return uc.transport6
}
}
return uc.transport
}
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
switch uc.IPStack {
case IpStackBoth:
return pick(uc.bootstrapIPs)
case IpStackV4:
return pick(uc.bootstrapIPs4)
case IpStackV6:
return pick(uc.bootstrapIPs6)
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return pick(uc.bootstrapIPs4)
default:
return pick(uc.bootstrapIPs6)
}
}
return pick(uc.bootstrapIPs)
}
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
switch uc.IPStack {
case IpStackBoth:
return "tcp-tls", "udp"
case IpStackV4:
return "tcp4-tls", "udp4"
case IpStackV6:
return "tcp6-tls", "udp6"
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return "tcp4-tls", "udp4"
default:
return "tcp6-tls", "udp6"
}
}
return "tcp-tls", "udp"
}
// Init initialized necessary values for an ListenerConfig.
func (lc *ListenerConfig) Init() {
if lc.Policy != nil {
@@ -333,6 +445,7 @@ func (lc *ListenerConfig) Init() {
// ValidateConfig validates the given config.
func ValidateConfig(validate *validator.Validate, cfg *Config) error {
_ = validate.RegisterValidation("dnsrcode", validateDnsRcode)
_ = validate.RegisterValidation("ipstack", validateIpStack)
_ = validate.RegisterValidation("iporempty", validateIpOrEmpty)
return validate.Struct(cfg)
}
@@ -341,6 +454,15 @@ func validateDnsRcode(fl validator.FieldLevel) bool {
return dnsrcode.FromString(fl.Field().String()) != -1
}
func validateIpStack(fl validator.FieldLevel) bool {
switch fl.Field().String() {
case IpStackBoth, IpStackV4, IpStackV6, IpStackSplit, "":
return true
default:
return false
}
}
func validateIpOrEmpty(fl validator.FieldLevel) bool {
val := fl.Field().String()
if val == "" {
@@ -384,3 +506,7 @@ func ResolverTypeFromEndpoint(endpoint string) string {
}
return ResolverTypeDOT
}
func pick(s []string) string {
return s[rand.Intn(len(s))]
}

View File

@@ -48,6 +48,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "example.com",
Timeout: 0,
IPStack: IpStackBoth,
u: u1,
},
},
@@ -68,6 +69,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "example.com",
Timeout: 0,
IPStack: IpStackBoth,
u: u2,
},
},
@@ -88,6 +90,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "freedns.controld.com",
Timeout: 0,
IPStack: IpStackSplit,
},
},
{
@@ -99,6 +102,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "",
Timeout: 0,
IPStack: IpStackSplit,
},
&UpstreamConfig{
Name: "dot",
@@ -107,6 +111,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "freedns.controld.com",
Timeout: 0,
IPStack: IpStackSplit,
},
},
{
@@ -126,6 +131,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "1.2.3.4",
Domain: "1.2.3.4",
Timeout: 0,
IPStack: IpStackBoth,
},
},
{
@@ -145,6 +151,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "1.2.3.4",
Domain: "1.2.3.4",
Timeout: 0,
IPStack: IpStackBoth,
},
},
{
@@ -157,6 +164,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
Domain: "",
Timeout: 0,
SendClientInfo: ptrBool(false),
IPStack: IpStackBoth,
},
&UpstreamConfig{
Name: "doh",
@@ -166,6 +174,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
Domain: "example.com",
Timeout: 0,
SendClientInfo: ptrBool(false),
IPStack: IpStackBoth,
u: u2,
},
},

View File

@@ -7,10 +7,15 @@ import (
"crypto/tls"
"errors"
"net"
"net/http"
"sync"
"time"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
func (uc *UpstreamConfig) setupDOH3Transport() {
@@ -18,9 +23,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
uc.pingUpstream()
}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
rt := &http3.RoundTripper{}
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
@@ -40,20 +43,57 @@ func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
}
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
}
addrs := make([]string, len(uc.bootstrapIPs))
for i := range uc.bootstrapIPs {
addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port)
dialAddrs := make([]string, len(addrs))
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
pd := &quicParallelDialer{}
conn, err := pd.Dial(ctx, domain, addrs, tlsCfg, cfg)
conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg)
if err != nil {
return nil, err
}
ProxyLog.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
return conn, err
}
return rt
}
uc.http3RoundTripper = rt
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
switch uc.IPStack {
case IpStackBoth, "":
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
case IpStackV4:
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
case IpStackV6:
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
case IpStackSplit:
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if ctrldnet.IPv6Available(ctx) {
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
} else {
uc.http3RoundTripper6 = uc.http3RoundTripper4
}
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
}
}
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.http3RoundTripper
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.http3RoundTripper4
default:
return uc.http3RoundTripper6
}
}
return uc.http3RoundTripper
}
// Putting the code for quic parallel dialer here:

View File

@@ -2,6 +2,9 @@
package ctrld
import "net/http"
func (uc *UpstreamConfig) setupDOH3Transport() {}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {}
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { return nil }

View File

@@ -227,9 +227,27 @@ Value `0` means no timeout.
The protocol that `ctrld` will use to send DNS requests to upstream.
- Type: string
- required: yes
- Required: yes
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os`
### ip_stack
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.
- Type: string
- Required: no
- Valid values:
- `both`: using either ipv4 or ipv6.
- `v4`: only dial upstream via IPv4, never dial IPv6.
- `v6`: only dial upstream via IPv6, never dial IPv4.
- `split`:
- If `A` record is requested -> dial via ipv4.
- If `AAAA` or any other record is requested -> dial ipv6 (if available, otherwise ipv4)
If `ip_stack` is empty, or undefined:
- Default value is `both` for non-Control D resolvers.
- Default value is `split` for Control D resolvers.
## Network
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.

16
doh.go
View File

@@ -36,12 +36,12 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver {
}
type dohResolver struct {
uc *UpstreamConfig
endpoint *url.URL
isDoH3 bool
transport *http.Transport
http3RoundTripper http.RoundTripper
sendClientInfo bool
uc *UpstreamConfig
}
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
@@ -61,18 +61,22 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return nil, fmt.Errorf("could not create request: %w", err)
}
addHeader(ctx, req, r.sendClientInfo)
c := http.Client{Transport: r.transport}
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
c := http.Client{Transport: r.uc.dohTransport(dnsTyp)}
if r.isDoH3 {
if r.http3RoundTripper == nil {
transport := r.uc.doh3Transport(dnsTyp)
if transport == nil {
return nil, errors.New("DoH3 is not supported")
}
c.Transport = r.http3RoundTripper
c.Transport = transport
}
resp, err := c.Do(req)
if err != nil {
if r.isDoH3 {
if closer, ok := r.http3RoundTripper.(io.Closer); ok {
if closer, ok := c.Transport.(io.Closer); ok {
closer.Close()
}
}

14
doq.go
View File

@@ -20,11 +20,17 @@ type doqResolver struct {
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
endpoint := r.uc.Endpoint
tlsConfig := &tls.Config{NextProtos: []string{"doq"}}
if r.uc.BootstrapIP != "" {
tlsConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
ip := r.uc.BootstrapIP
if ip == "" {
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
ip = r.uc.bootstrapIPForDNSType(dnsTyp)
}
tlsConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(ip, port)
return resolve(ctx, msg, endpoint, tlsConfig)
}

10
dot.go
View File

@@ -14,13 +14,19 @@ type dotResolver struct {
func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// The dialer is used to prevent bootstrapping cycle.
// If r.endpoing is set to dns.controld.dev, we need to resolve
// If r.endpoint is set to dns.controld.dev, we need to resolve
// dns.controld.dev first. By using a dialer with custom resolver,
// we ensure that we can always resolve the bootstrap domain
// regardless of the machine DNS status.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: "tcp-tls",
Net: tcpNet,
Dialer: dialer,
TLSConfig: &tls.Config{RootCAs: r.uc.certPool},
}

View File

@@ -34,7 +34,7 @@ var errUnknownResolver = errors.New("unknown resolver")
// NewResolver creates a Resolver based on the given upstream config.
func NewResolver(uc *UpstreamConfig) (Resolver, error) {
typ, endpoint := uc.Type, uc.Endpoint
typ := uc.Type
switch typ {
case ResolverTypeDOH, ResolverTypeDOH3:
return newDohResolver(uc), nil
@@ -45,7 +45,7 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
case ResolverTypeOS:
return or, nil
case ResolverTypeLegacy:
return &legacyResolver{endpoint: endpoint}, nil
return &legacyResolver{uc: uc}, nil
}
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
}
@@ -110,17 +110,22 @@ func newDialer(dnsAddress string) *net.Dialer {
}
type legacyResolver struct {
endpoint string
uc *UpstreamConfig
}
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// See comment in (*dotResolver).resolve method.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
_, udpNet := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: "udp",
Net: udpNet,
Dialer: dialer,
}
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint)
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.uc.Endpoint)
return answer, err
}