all: implement split upstreams

This commit introduces split upstreams feature, allowing to configure
what ip stack that ctrld will use to connect to upstream.
This commit is contained in:
Cuong Manh Le
2023-04-28 01:12:59 +07:00
committed by Cuong Manh Le
parent 5cad0d6be1
commit b267572b38
10 changed files with 286 additions and 68 deletions
+2 -1
View File
@@ -242,7 +242,8 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
}
}
answer, err := resolve1(n, upstreamConfig, msg)
if err != nil {
// Only do re-bootstrapping if bootstrap ip is not explicitly set by user.
if err != nil && upstreamConfig.BootstrapIP == "" {
ctrld.Log(ctx, mainLog.Debug().Err(err), "could not resolve query on first attempt, retrying...")
// If any error occurred, re-bootstrap transport/ip, retry the request.
upstreamConfig.ReBootstrap()
+166 -40
View File
@@ -4,13 +4,13 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/go-playground/validator/v10"
@@ -22,6 +22,15 @@ import (
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
const (
IpStackBoth = "both"
IpStackV4 = "v4"
IpStackV6 = "v6"
IpStackSplit = "split"
)
var controldParentDomains = []string{"controld.com", "controld.net", "controld.dev"}
// SetConfigName set the config name that ctrld will look for.
// DEPRECATED: use SetConfigNameWithPath instead.
func SetConfigName(v *viper.Viper, name string) {
@@ -118,19 +127,25 @@ type UpstreamConfig struct {
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
Domain string `mapstructure:"-" toml:"-"`
IPStack string `mapstructure:"ip_stack" toml:"ip_stack,omitempty" validate:"ipstack"`
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
// The caller should not access this field directly.
// Use UpstreamSendClientInfo instead.
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
transport *http.Transport `mapstructure:"-" toml:"-"`
http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"`
certPool *x509.CertPool `mapstructure:"-" toml:"-"`
u *url.URL `mapstructure:"-" toml:"-"`
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
g singleflight.Group
mu sync.Mutex
bootstrapIPs []string
nextBootstrapIP atomic.Uint32
g singleflight.Group
mu sync.Mutex
bootstrapIPs []string
bootstrapIPs4 []string
bootstrapIPs6 []string
transport *http.Transport
transport4 *http.Transport
transport6 *http.Transport
http3RoundTripper http.RoundTripper
http3RoundTripper4 http.RoundTripper
http3RoundTripper6 http.RoundTripper
certPool *x509.CertPool
u *url.URL
}
// ListenerConfig specifies the networks configuration that ctrld will run on.
@@ -164,18 +179,23 @@ func (uc *UpstreamConfig) Init() {
uc.u = u
}
}
if uc.Domain != "" {
return
if uc.Domain == "" {
if !strings.Contains(uc.Endpoint, ":") {
uc.Domain = uc.Endpoint
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
}
host, _, _ := net.SplitHostPort(uc.Endpoint)
uc.Domain = host
if net.ParseIP(uc.Domain) != nil {
uc.BootstrapIP = uc.Domain
}
}
if !strings.Contains(uc.Endpoint, ":") {
uc.Domain = uc.Endpoint
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
}
host, _, _ := net.SplitHostPort(uc.Endpoint)
uc.Domain = host
if net.ParseIP(uc.Domain) != nil {
uc.BootstrapIP = uc.Domain
if uc.IPStack == "" {
if uc.isControlD() {
uc.IPStack = IpStackSplit
} else {
uc.IPStack = IpStackBoth
}
}
}
@@ -195,13 +215,8 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
}
switch uc.Type {
case ResolverTypeDOH, ResolverTypeDOH3:
if u, err := url.Parse(uc.Endpoint); err == nil {
domain := u.Hostname()
for _, parent := range []string{"controld.com", "controld.net"} {
if dns.IsSubDomain(parent, domain) {
return true
}
}
if uc.isControlD() {
return true
}
}
return false
@@ -226,6 +241,13 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
// The first usable IP will be used as bootstrap IP of the upstream.
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
for _, ip := range uc.bootstrapIPs {
if ctrldnet.IsIPv6(ip) {
uc.bootstrapIPs6 = append(uc.bootstrapIPs6, ip)
} else {
uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip)
}
}
ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs)
}
@@ -238,7 +260,6 @@ func (uc *UpstreamConfig) ReBootstrap() {
}
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
ProxyLog.Debug().Msg("re-bootstrapping upstream ip")
uc.BootstrapIP = ""
uc.setupTransportWithoutPingUpstream()
return true, nil
})
@@ -269,19 +290,17 @@ func (uc *UpstreamConfig) setupDOHTransport() {
uc.pingUpstream()
}
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
uc.transport = http.DefaultTransport.(*http.Transport).Clone()
uc.transport.IdleConnTimeout = 5 * time.Second
uc.transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.IdleConnTimeout = 5 * time.Second
transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
dialerTimeoutMs := 2000
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
dialerTimeoutMs = uc.Timeout
}
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
_, port, _ := net.SplitHostPort(addr)
if uc.BootstrapIP != "" {
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
@@ -292,17 +311,42 @@ func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
pd := &ctrldnet.ParallelDialer{}
pd.Timeout = dialerTimeout
pd.KeepAlive = dialerTimeout
addrs := make([]string, len(uc.bootstrapIPs))
for i := range uc.bootstrapIPs {
addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port)
dialAddrs := make([]string, len(addrs))
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
conn, err := pd.DialContext(ctx, network, addrs)
conn, err := pd.DialContext(ctx, network, dialAddrs)
if err != nil {
return nil, err
}
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", conn.RemoteAddr())
return conn, nil
}
return transport
}
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
switch uc.IPStack {
case IpStackBoth, "":
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
case IpStackV4:
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
case IpStackV6:
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
case IpStackSplit:
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if ctrldnet.IPv6Available(ctx) {
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
} else {
uc.transport6 = uc.transport4
}
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
}
}
func (uc *UpstreamConfig) pingUpstream() {
@@ -320,6 +364,74 @@ func (uc *UpstreamConfig) pingUpstream() {
_, _ = dnsResolver.Resolve(ctx, msg)
}
func (uc *UpstreamConfig) isControlD() bool {
domain := uc.Domain
if domain == "" {
if u, err := url.Parse(uc.Endpoint); err == nil {
domain = u.Hostname()
}
}
for _, parent := range controldParentDomains {
if dns.IsSubDomain(parent, domain) {
return true
}
}
return false
}
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.transport
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.transport4
default:
return uc.transport6
}
}
return uc.transport
}
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
switch uc.IPStack {
case IpStackBoth:
return pick(uc.bootstrapIPs)
case IpStackV4:
return pick(uc.bootstrapIPs4)
case IpStackV6:
return pick(uc.bootstrapIPs6)
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return pick(uc.bootstrapIPs4)
default:
return pick(uc.bootstrapIPs6)
}
}
return pick(uc.bootstrapIPs)
}
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
switch uc.IPStack {
case IpStackBoth:
return "tcp-tls", "udp"
case IpStackV4:
return "tcp4-tls", "udp4"
case IpStackV6:
return "tcp6-tls", "udp6"
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return "tcp4-tls", "udp4"
default:
return "tcp6-tls", "udp6"
}
}
return "tcp-tls", "udp"
}
// Init initialized necessary values for an ListenerConfig.
func (lc *ListenerConfig) Init() {
if lc.Policy != nil {
@@ -333,6 +445,7 @@ func (lc *ListenerConfig) Init() {
// ValidateConfig validates the given config.
func ValidateConfig(validate *validator.Validate, cfg *Config) error {
_ = validate.RegisterValidation("dnsrcode", validateDnsRcode)
_ = validate.RegisterValidation("ipstack", validateIpStack)
_ = validate.RegisterValidation("iporempty", validateIpOrEmpty)
return validate.Struct(cfg)
}
@@ -341,6 +454,15 @@ func validateDnsRcode(fl validator.FieldLevel) bool {
return dnsrcode.FromString(fl.Field().String()) != -1
}
func validateIpStack(fl validator.FieldLevel) bool {
switch fl.Field().String() {
case IpStackBoth, IpStackV4, IpStackV6, IpStackSplit, "":
return true
default:
return false
}
}
func validateIpOrEmpty(fl validator.FieldLevel) bool {
val := fl.Field().String()
if val == "" {
@@ -384,3 +506,7 @@ func ResolverTypeFromEndpoint(endpoint string) string {
}
return ResolverTypeDOT
}
func pick(s []string) string {
return s[rand.Intn(len(s))]
}
+9
View File
@@ -48,6 +48,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "example.com",
Timeout: 0,
IPStack: IpStackBoth,
u: u1,
},
},
@@ -68,6 +69,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "example.com",
Timeout: 0,
IPStack: IpStackBoth,
u: u2,
},
},
@@ -88,6 +90,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "freedns.controld.com",
Timeout: 0,
IPStack: IpStackSplit,
},
},
{
@@ -99,6 +102,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "",
Timeout: 0,
IPStack: IpStackSplit,
},
&UpstreamConfig{
Name: "dot",
@@ -107,6 +111,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "",
Domain: "freedns.controld.com",
Timeout: 0,
IPStack: IpStackSplit,
},
},
{
@@ -126,6 +131,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "1.2.3.4",
Domain: "1.2.3.4",
Timeout: 0,
IPStack: IpStackBoth,
},
},
{
@@ -145,6 +151,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
BootstrapIP: "1.2.3.4",
Domain: "1.2.3.4",
Timeout: 0,
IPStack: IpStackBoth,
},
},
{
@@ -157,6 +164,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
Domain: "",
Timeout: 0,
SendClientInfo: ptrBool(false),
IPStack: IpStackBoth,
},
&UpstreamConfig{
Name: "doh",
@@ -166,6 +174,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
Domain: "example.com",
Timeout: 0,
SendClientInfo: ptrBool(false),
IPStack: IpStackBoth,
u: u2,
},
},
+48 -8
View File
@@ -7,10 +7,15 @@ import (
"crypto/tls"
"errors"
"net"
"net/http"
"sync"
"time"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
)
func (uc *UpstreamConfig) setupDOH3Transport() {
@@ -18,9 +23,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
uc.pingUpstream()
}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
rt := &http3.RoundTripper{}
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
@@ -40,20 +43,57 @@ func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
}
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
}
addrs := make([]string, len(uc.bootstrapIPs))
for i := range uc.bootstrapIPs {
addrs[i] = net.JoinHostPort(uc.bootstrapIPs[i], port)
dialAddrs := make([]string, len(addrs))
for i := range addrs {
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
}
pd := &quicParallelDialer{}
conn, err := pd.Dial(ctx, domain, addrs, tlsCfg, cfg)
conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg)
if err != nil {
return nil, err
}
ProxyLog.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
return conn, err
}
return rt
}
uc.http3RoundTripper = rt
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
uc.mu.Lock()
defer uc.mu.Unlock()
switch uc.IPStack {
case IpStackBoth, "":
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
case IpStackV4:
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
case IpStackV6:
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
case IpStackSplit:
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if ctrldnet.IPv6Available(ctx) {
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
} else {
uc.http3RoundTripper6 = uc.http3RoundTripper4
}
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
}
}
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
switch uc.IPStack {
case IpStackBoth, IpStackV4, IpStackV6:
return uc.http3RoundTripper
case IpStackSplit:
switch dnsType {
case dns.TypeA:
return uc.http3RoundTripper4
default:
return uc.http3RoundTripper6
}
}
return uc.http3RoundTripper
}
// Putting the code for quic parallel dialer here:
+4 -1
View File
@@ -2,6 +2,9 @@
package ctrld
import "net/http"
func (uc *UpstreamConfig) setupDOH3Transport() {}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {}
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {}
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { return nil }
+19 -1
View File
@@ -227,9 +227,27 @@ Value `0` means no timeout.
The protocol that `ctrld` will use to send DNS requests to upstream.
- Type: string
- required: yes
- Required: yes
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os`
### ip_stack
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.
- Type: string
- Required: no
- Valid values:
- `both`: using either ipv4 or ipv6.
- `v4`: only dial upstream via IPv4, never dial IPv6.
- `v6`: only dial upstream via IPv6, never dial IPv4.
- `split`:
- If `A` record is requested -> dial via ipv4.
- If `AAAA` or any other record is requested -> dial ipv6 (if available, otherwise ipv4)
If `ip_stack` is empty, or undefined:
- Default value is `both` for non-Control D resolvers.
- Default value is `split` for Control D resolvers.
## Network
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.
+10 -6
View File
@@ -36,12 +36,12 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver {
}
type dohResolver struct {
uc *UpstreamConfig
endpoint *url.URL
isDoH3 bool
transport *http.Transport
http3RoundTripper http.RoundTripper
sendClientInfo bool
uc *UpstreamConfig
}
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
@@ -61,18 +61,22 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return nil, fmt.Errorf("could not create request: %w", err)
}
addHeader(ctx, req, r.sendClientInfo)
c := http.Client{Transport: r.transport}
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
c := http.Client{Transport: r.uc.dohTransport(dnsTyp)}
if r.isDoH3 {
if r.http3RoundTripper == nil {
transport := r.uc.doh3Transport(dnsTyp)
if transport == nil {
return nil, errors.New("DoH3 is not supported")
}
c.Transport = r.http3RoundTripper
c.Transport = transport
}
resp, err := c.Do(req)
if err != nil {
if r.isDoH3 {
if closer, ok := r.http3RoundTripper.(io.Closer); ok {
if closer, ok := c.Transport.(io.Closer); ok {
closer.Close()
}
}
+10 -4
View File
@@ -20,11 +20,17 @@ type doqResolver struct {
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
endpoint := r.uc.Endpoint
tlsConfig := &tls.Config{NextProtos: []string{"doq"}}
if r.uc.BootstrapIP != "" {
tlsConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
ip := r.uc.BootstrapIP
if ip == "" {
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
ip = r.uc.bootstrapIPForDNSType(dnsTyp)
}
tlsConfig.ServerName = r.uc.Domain
_, port, _ := net.SplitHostPort(endpoint)
endpoint = net.JoinHostPort(ip, port)
return resolve(ctx, msg, endpoint, tlsConfig)
}
+8 -2
View File
@@ -14,13 +14,19 @@ type dotResolver struct {
func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// The dialer is used to prevent bootstrapping cycle.
// If r.endpoing is set to dns.controld.dev, we need to resolve
// If r.endpoint is set to dns.controld.dev, we need to resolve
// dns.controld.dev first. By using a dialer with custom resolver,
// we ensure that we can always resolve the bootstrap domain
// regardless of the machine DNS status.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: "tcp-tls",
Net: tcpNet,
Dialer: dialer,
TLSConfig: &tls.Config{RootCAs: r.uc.certPool},
}
+10 -5
View File
@@ -34,7 +34,7 @@ var errUnknownResolver = errors.New("unknown resolver")
// NewResolver creates a Resolver based on the given upstream config.
func NewResolver(uc *UpstreamConfig) (Resolver, error) {
typ, endpoint := uc.Type, uc.Endpoint
typ := uc.Type
switch typ {
case ResolverTypeDOH, ResolverTypeDOH3:
return newDohResolver(uc), nil
@@ -45,7 +45,7 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
case ResolverTypeOS:
return or, nil
case ResolverTypeLegacy:
return &legacyResolver{endpoint: endpoint}, nil
return &legacyResolver{uc: uc}, nil
}
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
}
@@ -110,17 +110,22 @@ func newDialer(dnsAddress string) *net.Dialer {
}
type legacyResolver struct {
endpoint string
uc *UpstreamConfig
}
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
// See comment in (*dotResolver).resolve method.
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
dnsTyp := uint16(0)
if len(msg.Question) > 0 {
dnsTyp = msg.Question[0].Qtype
}
_, udpNet := r.uc.netForDNSType(dnsTyp)
dnsClient := &dns.Client{
Net: "udp",
Net: udpNet,
Dialer: dialer,
}
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint)
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.uc.Endpoint)
return answer, err
}