mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Compare commits
88 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
998b9a5c5d | ||
|
|
0084e9ef26 | ||
|
|
122600bff2 | ||
|
|
41846b6d4c | ||
|
|
dfbcb1489d | ||
|
|
684019c2e3 | ||
|
|
e92619620d | ||
|
|
cebfd12d5c | ||
|
|
874ff01ab8 | ||
|
|
0bb8703f78 | ||
|
|
0bb51aa71d | ||
|
|
af2c1c87e0 | ||
|
|
8939debbc0 | ||
|
|
7591a0ccc6 | ||
|
|
c3ff8182af | ||
|
|
5897c174d3 | ||
|
|
f9a3f4c045 | ||
|
|
a2cb895cdc | ||
|
|
2bebe93e47 | ||
|
|
28ec1869fc | ||
|
|
17f6d7a77b | ||
|
|
9e6e647ff8 | ||
|
|
a2116e5eb5 | ||
|
|
564c9ef712 | ||
|
|
856abb71b7 | ||
|
|
0a30fdea69 | ||
|
|
4f125cf107 | ||
|
|
494d8be777 | ||
|
|
cd9c750884 | ||
|
|
91d319804b | ||
|
|
180eae60f2 | ||
|
|
d01f5c2777 | ||
|
|
294a90a807 | ||
|
|
c3b4ae9c79 | ||
|
|
09188bedf7 | ||
|
|
4614b98e94 | ||
|
|
990bc620f7 | ||
|
|
efb5a92571 | ||
|
|
8e0a96a44c | ||
|
|
43ff2f648c | ||
|
|
4816a09e3a | ||
|
|
3fea92c8b1 | ||
|
|
63f959c951 | ||
|
|
44ba6aadd9 | ||
|
|
d88cf52b4e | ||
|
|
58a00ea24a | ||
|
|
712b23a4bb | ||
|
|
baf836557c | ||
|
|
904b23eeac | ||
|
|
6aafe445f5 | ||
|
|
ebd516855b | ||
|
|
df4e04719e | ||
|
|
2440d922c6 | ||
|
|
f1b8d1c4ad | ||
|
|
79076bda35 | ||
|
|
9d2ea15346 | ||
|
|
77c1113ff7 | ||
|
|
e03ad4cd77 | ||
|
|
6e28517454 | ||
|
|
8ddbf881b3 | ||
|
|
c58516cfb0 | ||
|
|
34758f6205 | ||
|
|
a9959a6f3d | ||
|
|
511c4e696f | ||
|
|
bed7435b0c | ||
|
|
507c1afd59 | ||
|
|
2765487f10 | ||
|
|
80a88811cd | ||
|
|
823195c504 | ||
|
|
0f3e8c7ada | ||
|
|
ee5eb4fc4e | ||
|
|
d58d8074f4 | ||
|
|
94a0530991 | ||
|
|
073af0f89c | ||
|
|
6028b8f186 | ||
|
|
126477ef88 | ||
|
|
13391fd469 | ||
|
|
82e44b01af | ||
|
|
e355fd70ab | ||
|
|
d5c171735e | ||
|
|
b175368794 | ||
|
|
bcf4c25ba8 | ||
|
|
11b09af76d | ||
|
|
af0380a96a | ||
|
|
f39512b4c0 | ||
|
|
7ce62ccaec | ||
|
|
44c0a06996 | ||
|
|
f7d3db06c6 |
10
README.md
10
README.md
@@ -76,8 +76,8 @@ $ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
or
|
||||
|
||||
```
|
||||
$ docker build -t controld/ctrld .
|
||||
$ docker run -d --name=ctrld -p 53:53/tcp -p 53:53/udp controld/ctrld --cd=RESOLVER_ID_GOES_HERE -vv
|
||||
$ docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
$ docker run -d --name=ctrld -p 53:53/tcp -p 53:53/udp controldns/ctrld --cd=RESOLVER_ID_GOES_HERE -vv
|
||||
```
|
||||
|
||||
|
||||
@@ -188,8 +188,8 @@ See [Configuration Docs](docs/config.md).
|
||||
[listener]
|
||||
|
||||
[listener.0]
|
||||
ip = "127.0.0.1"
|
||||
port = 53
|
||||
ip = ""
|
||||
port = 0
|
||||
restricted = false
|
||||
|
||||
[network]
|
||||
@@ -220,6 +220,8 @@ See [Configuration Docs](docs/config.md).
|
||||
|
||||
```
|
||||
|
||||
`ctrld` will pick a working config for `listener.0` then writing the default config to disk for the first run.
|
||||
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
|
||||
|
||||
@@ -5,9 +5,11 @@ type ClientInfoCtxKey struct{}
|
||||
|
||||
// ClientInfo represents ctrld's clients information.
|
||||
type ClientInfo struct {
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Self bool
|
||||
ClientIDPref string
|
||||
}
|
||||
|
||||
// LeaseFileFormat specifies the format of DHCP lease file.
|
||||
|
||||
798
cmd/cli/cli.go
798
cmd/cli/cli.go
File diff suppressed because it is too large
Load Diff
@@ -6,14 +6,18 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
contentTypeJson = "application/json"
|
||||
listClientsPath = "/clients"
|
||||
startedPath = "/started"
|
||||
reloadPath = "/reload"
|
||||
)
|
||||
|
||||
type controlServer struct {
|
||||
@@ -75,6 +79,52 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusRequestTimeout)
|
||||
}
|
||||
}))
|
||||
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
listeners := make(map[string]*ctrld.ListenerConfig)
|
||||
p.mu.Lock()
|
||||
for k, v := range p.cfg.Listener {
|
||||
listeners[k] = &ctrld.ListenerConfig{
|
||||
IP: v.IP,
|
||||
Port: v.Port,
|
||||
}
|
||||
}
|
||||
oldSvc := p.cfg.Service
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.reloadDoneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Checking for cases that we could not do a reload.
|
||||
|
||||
// 1. Listener config ip or port changes.
|
||||
for k, v := range p.cfg.Listener {
|
||||
l := listeners[k]
|
||||
if l == nil || l.IP != v.IP || l.Port != v.Port {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Service config changes.
|
||||
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, reload is done.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
}
|
||||
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
|
||||
@@ -4,11 +4,10 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -16,10 +15,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"go4.org/mem"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/util/lineread"
|
||||
"tailscale.com/net/netaddr"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
|
||||
const (
|
||||
staleTTL = 60 * time.Second
|
||||
localTTL = 3600 * time.Second
|
||||
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
|
||||
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
||||
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
||||
@@ -40,6 +40,29 @@ var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
var privateUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "Private resolver",
|
||||
Type: ctrld.ResolverTypePrivate,
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
// proxyRequest contains data for proxying a DNS query to upstream.
|
||||
type proxyRequest struct {
|
||||
msg *dns.Msg
|
||||
ci *ctrld.ClientInfo
|
||||
failoverRcodes []int
|
||||
ufr *upstreamForResult
|
||||
}
|
||||
|
||||
// upstreamForResult represents the result of processing rules for a request.
|
||||
type upstreamForResult struct {
|
||||
upstreams []string
|
||||
matchedPolicy string
|
||||
matchedNetwork string
|
||||
matchedRule string
|
||||
matched bool
|
||||
}
|
||||
|
||||
func (p *prog) serveDNS(listenerNum string) error {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
// make sure ip is allocated
|
||||
@@ -47,36 +70,58 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
mainLog.Load().Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
|
||||
return allocErr
|
||||
}
|
||||
var failoverRcodes []int
|
||||
if listenerConfig.Policy != nil {
|
||||
failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
p.sema.acquire()
|
||||
defer p.sema.release()
|
||||
if len(m.Question) == 0 {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeFormatError)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
reqId := requestID()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
if !listenerConfig.AllowWanClients && isWanClient(w.RemoteAddr()) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "query refused, listener does not allow WAN clients: %s", w.RemoteAddr().String())
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
go p.detectLoop(m)
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
reqId := requestID()
|
||||
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
|
||||
mac := macFromMsg(m)
|
||||
ci := p.getClientInfo(remoteIP, mac)
|
||||
ci := p.getClientInfo(remoteIP, m)
|
||||
ci.ClientIDPref = p.cfg.Service.ClientIDPref
|
||||
stripClientSubnet(m)
|
||||
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), ci)
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String())
|
||||
t := time.Now()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
||||
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, domain)
|
||||
res := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, ci.Mac, domain)
|
||||
var answer *dns.Msg
|
||||
if !matched && listenerConfig.Restricted {
|
||||
if !res.matched && listenerConfig.Restricted {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", remoteAddr.String())
|
||||
answer = new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
} else {
|
||||
answer = p.proxy(ctx, upstreams, failoverRcodes, m, ci)
|
||||
var failoverRcode []int
|
||||
if listenerConfig.Policy != nil {
|
||||
failoverRcode = listenerConfig.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
answer = p.proxy(ctx, &proxyRequest{
|
||||
msg: m,
|
||||
ci: ci,
|
||||
failoverRcodes: failoverRcode,
|
||||
ufr: res,
|
||||
})
|
||||
rtt := time.Since(t)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
}
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveUDP: failed to send DNS response to client")
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "serveDNS: failed to send DNS response to client")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -102,7 +147,7 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
// addresses of the machine. So ctrld could receive queries from LAN clients.
|
||||
if needRFC1918Listeners(listenerConfig) {
|
||||
g.Go(func() error {
|
||||
for _, addr := range rfc1918Addresses() {
|
||||
for _, addr := range ctrld.Rfc1918Addresses() {
|
||||
func() {
|
||||
listenAddr := net.JoinHostPort(addr, strconv.Itoa(listenerConfig.Port))
|
||||
s, errCh := runDNSServer(listenAddr, proto, handler)
|
||||
@@ -121,7 +166,8 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
})
|
||||
}
|
||||
g.Go(func() error {
|
||||
s, errCh := runDNSServer(dnsListenAddress(listenerConfig), proto, handler)
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
s, errCh := runDNSServer(addr, proto, handler)
|
||||
defer s.Shutdown()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
@@ -148,27 +194,24 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
// Though domain policy has higher priority than network policy, it is still
|
||||
// processed later, because policy logging want to know whether a network rule
|
||||
// is disregarded in favor of the domain level rule.
|
||||
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
|
||||
upstreams := []string{"upstream." + defaultUpstreamNum}
|
||||
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, srcMac, domain string) (res *upstreamForResult) {
|
||||
upstreams := []string{upstreamPrefix + defaultUpstreamNum}
|
||||
matchedPolicy := "no policy"
|
||||
matchedNetwork := "no network"
|
||||
matchedRule := "no rule"
|
||||
matched := false
|
||||
res = &upstreamForResult{}
|
||||
|
||||
defer func() {
|
||||
if !matched && lc.Restricted {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "query refused, %s does not match any network policy", addr.String())
|
||||
return
|
||||
}
|
||||
if matched {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams)
|
||||
} else {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
}
|
||||
res.upstreams = upstreams
|
||||
res.matched = matched
|
||||
res.matchedPolicy = matchedPolicy
|
||||
res.matchedNetwork = matchedNetwork
|
||||
res.matchedRule = matchedRule
|
||||
}()
|
||||
|
||||
if lc.Policy == nil {
|
||||
return upstreams, false
|
||||
return
|
||||
}
|
||||
|
||||
do := func(policyUpstreams []string) {
|
||||
@@ -204,6 +247,19 @@ networkRules:
|
||||
}
|
||||
}
|
||||
|
||||
macRules:
|
||||
for _, rule := range lc.Policy.Macs {
|
||||
for source, targets := range rule {
|
||||
if source != "" && strings.EqualFold(source, srcMac) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
matchedNetwork = source
|
||||
networkTargets = targets
|
||||
matched = true
|
||||
break macRules
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
// There's only one entry per rule, config validation ensures this.
|
||||
for source, targets := range rule {
|
||||
@@ -215,7 +271,7 @@ networkRules:
|
||||
matchedRule = source
|
||||
do(targets)
|
||||
matched = true
|
||||
return upstreams, matched
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -224,26 +280,134 @@ networkRules:
|
||||
do(networkTargets)
|
||||
}
|
||||
|
||||
return upstreams, matched
|
||||
return
|
||||
}
|
||||
|
||||
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg, ci *ctrld.ClientInfo) *dns.Msg {
|
||||
func (p *prog) proxyPrivatePtrLookup(ctx context.Context, msg *dns.Msg) *dns.Msg {
|
||||
cDomainName := msg.Question[0].Name
|
||||
locked := p.ptrLoopGuard.TryLock(cDomainName)
|
||||
defer p.ptrLoopGuard.Unlock(cDomainName)
|
||||
if !locked {
|
||||
return nil
|
||||
}
|
||||
ip := ipFromARPA(cDomainName)
|
||||
if name := p.ciTable.LookupHostname(ip.String(), ""); name != "" {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetReply(msg)
|
||||
answer.Compress = true
|
||||
answer.Answer = []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Ptr: dns.Fqdn(name),
|
||||
}}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using client info table")
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
Mac: p.ciTable.LookupMac(ip.String()),
|
||||
IP: ip.String(),
|
||||
Hostname: name,
|
||||
})
|
||||
return answer
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) proxyLanHostnameQuery(ctx context.Context, msg *dns.Msg) *dns.Msg {
|
||||
q := msg.Question[0]
|
||||
hostname := strings.TrimSuffix(q.Name, ".")
|
||||
locked := p.lanLoopGuard.TryLock(hostname)
|
||||
defer p.lanLoopGuard.Unlock(hostname)
|
||||
if !locked {
|
||||
return nil
|
||||
}
|
||||
if ip := p.ciTable.LookupIPByHostname(hostname, q.Qtype == dns.TypeAAAA); ip != nil {
|
||||
answer := new(dns.Msg)
|
||||
answer.SetReply(msg)
|
||||
answer.Compress = true
|
||||
switch {
|
||||
case ip.Is4():
|
||||
answer.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(localTTL.Seconds()),
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}}
|
||||
case ip.Is6():
|
||||
answer.Answer = []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: uint32(localTTL.Seconds()),
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}}
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using client info table")
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "client info: %v", ctrld.ClientInfo{
|
||||
Mac: p.ciTable.LookupMac(ip.String()),
|
||||
IP: ip.String(),
|
||||
Hostname: hostname,
|
||||
})
|
||||
return answer
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) proxy(ctx context.Context, req *proxyRequest) *dns.Msg {
|
||||
var staleAnswer *dns.Msg
|
||||
upstreams := req.ufr.upstreams
|
||||
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{"upstream.os"}
|
||||
upstreams = []string{upstreamOS}
|
||||
}
|
||||
|
||||
// LAN/PTR lookup flow:
|
||||
//
|
||||
// 1. If there's matching rule, follow it.
|
||||
// 2. Try from client info table.
|
||||
// 3. Try private resolver.
|
||||
// 4. Try remote upstream.
|
||||
isLanOrPtrQuery := false
|
||||
if req.ufr.matched {
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
} else {
|
||||
switch {
|
||||
case isPrivatePtrLookup(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
|
||||
return answer
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "private PTR lookup, using upstreams: %v", upstreams)
|
||||
case isLanHostnameQuery(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
if answer := p.proxyLanHostnameQuery(ctx, req.msg); answer != nil {
|
||||
return answer
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "lan hostname lookup, using upstreams: %v", upstreams)
|
||||
default:
|
||||
ctrld.Log(ctx, mainLog.Load().Info(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
}
|
||||
}
|
||||
|
||||
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
|
||||
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
|
||||
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
|
||||
for _, upstream := range upstreams {
|
||||
cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream))
|
||||
cachedValue := p.cache.Get(dnscache.NewKey(req.msg, upstream))
|
||||
if cachedValue == nil {
|
||||
continue
|
||||
}
|
||||
answer := cachedValue.Msg.Copy()
|
||||
answer.SetRcode(msg, answer.Rcode)
|
||||
answer.SetRcode(req.msg, answer.Rcode)
|
||||
now := time.Now()
|
||||
if cachedValue.Expire.After(now) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
|
||||
@@ -270,13 +434,24 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
return dnsResolver.Resolve(resolveCtx, msg)
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
if upstreamConfig.UpstreamSendClientInfo() && ci != nil {
|
||||
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci)
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
|
||||
}
|
||||
answer, err := resolve1(n, upstreamConfig, msg)
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
|
||||
if errNetworkError(err) {
|
||||
p.um.increaseFailureCount(upstreams[n])
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
go p.um.checkUpstream(upstreams[n], upstreamConfig)
|
||||
}
|
||||
}
|
||||
// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
|
||||
var e net.Error
|
||||
if errors.As(err, &e) && e.Timeout() {
|
||||
upstreamConfig.ReBootstrap()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return answer
|
||||
@@ -285,7 +460,15 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, msg)
|
||||
if p.isLoop(upstreamConfig) {
|
||||
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint)
|
||||
continue
|
||||
}
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, req.msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
|
||||
@@ -295,7 +478,13 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
}
|
||||
continue
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
|
||||
// We are doing LAN/PTR lookup using private resolver, so always process next one.
|
||||
// Except for the last, we want to send response instead of saying all upstream failed.
|
||||
if answer.Rcode != dns.RcodeSuccess && isLanOrPtrQuery && n != len(upstreamConfigs)-1 {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "no response from %s, process to next upstream", upstreams[n])
|
||||
continue
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(req.failoverRcodes, answer.Rcode) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "failover rcode matched, process to next upstream")
|
||||
continue
|
||||
}
|
||||
@@ -303,7 +492,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
// set compression, as it is not set by default when unpacking
|
||||
answer.Compress = true
|
||||
|
||||
if p.cache != nil {
|
||||
if p.cache != nil && req.msg.Question[0].Qtype != dns.TypePTR {
|
||||
ttl := ttlFromMsg(answer)
|
||||
now := time.Now()
|
||||
expired := now.Add(time.Duration(ttl) * time.Second)
|
||||
@@ -311,21 +500,31 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
expired = now.Add(time.Duration(cachedTTL) * time.Second)
|
||||
}
|
||||
setCachedAnswerTTL(answer, now, expired)
|
||||
p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired))
|
||||
p.cache.Add(dnscache.NewKey(req.msg, upstreams[n]), dnscache.NewValue(answer, expired))
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "add cached response")
|
||||
}
|
||||
return answer
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all upstreams failed")
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(msg, dns.RcodeServerFailure)
|
||||
answer.SetRcode(req.msg, dns.RcodeServerFailure)
|
||||
return answer
|
||||
}
|
||||
|
||||
func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
|
||||
if len(p.localUpstreams) > 0 {
|
||||
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
|
||||
tmp = append(tmp, p.localUpstreams...)
|
||||
tmp = append(tmp, upstreams...)
|
||||
return tmp, p.upstreamConfigsFromUpstreamNumbers(tmp)
|
||||
}
|
||||
return append([]string{upstreamOS}, upstreams...), append([]*ctrld.UpstreamConfig{privateUpstreamConfig}, upstreamConfigs...)
|
||||
}
|
||||
|
||||
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
|
||||
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
|
||||
for _, upstream := range upstreams {
|
||||
upstreamNum := strings.TrimPrefix(upstream, "upstream.")
|
||||
upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix)
|
||||
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
|
||||
}
|
||||
return upstreamConfigs
|
||||
@@ -422,29 +621,41 @@ func needLocalIPv6Listener() bool {
|
||||
return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
|
||||
}
|
||||
|
||||
func dnsListenAddress(lc *ctrld.ListenerConfig) string {
|
||||
// If we are inside container and the listener loopback address, change
|
||||
// the address to something like 0.0.0.0:53, so user can expose the port to outside.
|
||||
if inContainer() {
|
||||
if ip := net.ParseIP(lc.IP); ip != nil && ip.IsLoopback() {
|
||||
return net.JoinHostPort("0.0.0.0", strconv.Itoa(lc.Port))
|
||||
}
|
||||
}
|
||||
return net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
|
||||
}
|
||||
|
||||
func macFromMsg(msg *dns.Msg) string {
|
||||
// ipAndMacFromMsg extracts IP and MAC information included in a DNS message, if any.
|
||||
func ipAndMacFromMsg(msg *dns.Msg) (string, string) {
|
||||
ip, mac := "", ""
|
||||
if opt := msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
switch e := s.(type) {
|
||||
case *dns.EDNS0_LOCAL:
|
||||
if e.Code == EDNS0_OPTION_MAC {
|
||||
return net.HardwareAddr(e.Data).String()
|
||||
mac = net.HardwareAddr(e.Data).String()
|
||||
}
|
||||
case *dns.EDNS0_SUBNET:
|
||||
if len(e.Address) > 0 && !e.Address.IsLoopback() {
|
||||
ip = e.Address.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
return ip, mac
|
||||
}
|
||||
|
||||
// stripClientSubnet removes EDNS0_SUBNET from DNS message if the IP is RFC1918 or loopback address,
|
||||
// passing them to upstream is pointless, these cannot be used by anything on the WAN.
|
||||
func stripClientSubnet(msg *dns.Msg) {
|
||||
if opt := msg.IsEdns0(); opt != nil {
|
||||
opts := make([]dns.EDNS0, 0, len(opt.Option))
|
||||
for _, s := range opt.Option {
|
||||
if e, ok := s.(*dns.EDNS0_SUBNET); ok && (e.Address.IsPrivate() || e.Address.IsLoopback()) {
|
||||
continue
|
||||
}
|
||||
opts = append(opts, s)
|
||||
}
|
||||
if len(opts) != len(opt.Option) {
|
||||
opt.Option = opts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
|
||||
@@ -498,72 +709,180 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha
|
||||
return s, errCh
|
||||
}
|
||||
|
||||
// inContainer reports whether we're running in a container.
|
||||
//
|
||||
// Copied from https://github.com/tailscale/tailscale/blob/v1.42.0/hostinfo/hostinfo.go#L260
|
||||
// with modification for ctrld usage.
|
||||
func inContainer() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
func (p *prog) getClientInfo(remoteIP string, msg *dns.Msg) *ctrld.ClientInfo {
|
||||
ci := &ctrld.ClientInfo{}
|
||||
if p.appCallback != nil {
|
||||
ci.IP = p.appCallback.LanIp()
|
||||
ci.Mac = p.appCallback.MacAddress()
|
||||
ci.Hostname = p.appCallback.HostName()
|
||||
ci.Self = true
|
||||
return ci
|
||||
}
|
||||
ci.IP, ci.Mac = ipAndMacFromMsg(msg)
|
||||
switch {
|
||||
case ci.IP != "" && ci.Mac != "":
|
||||
// Nothing to do.
|
||||
case ci.IP == "" && ci.Mac != "":
|
||||
// Have MAC, no IP.
|
||||
ci.IP = p.ciTable.LookupIP(ci.Mac)
|
||||
case ci.IP == "" && ci.Mac == "":
|
||||
// Have nothing, use remote IP then lookup MAC.
|
||||
ci.IP = remoteIP
|
||||
fallthrough
|
||||
case ci.IP != "" && ci.Mac == "":
|
||||
// Have IP, no MAC.
|
||||
ci.Mac = p.ciTable.LookupMac(ci.IP)
|
||||
}
|
||||
|
||||
var ret bool
|
||||
if _, err := os.Stat("/.dockerenv"); err == nil {
|
||||
return true
|
||||
}
|
||||
if _, err := os.Stat("/run/.containerenv"); err == nil {
|
||||
// See https://github.com/cri-o/cri-o/issues/5461
|
||||
return true
|
||||
}
|
||||
lineread.File("/proc/1/cgroup", func(line []byte) error {
|
||||
if mem.Contains(mem.B(line), mem.S("/docker/")) ||
|
||||
mem.Contains(mem.B(line), mem.S("/lxc/")) {
|
||||
ret = true
|
||||
return io.EOF // arbitrary non-nil error to stop loop
|
||||
// If MAC is still empty here, that mean the requests are made from virtual interface,
|
||||
// like VPN/Wireguard clients, so we use ci.IP as hostname to distinguish those clients.
|
||||
if ci.Mac == "" {
|
||||
if hostname := p.ciTable.LookupHostname(ci.IP, ""); hostname != "" {
|
||||
ci.Hostname = hostname
|
||||
} else {
|
||||
// Only use IP as hostname for IPv4 clients.
|
||||
// For Android devices, when it joins the network, it uses ctrld to resolve
|
||||
// its private DNS once and never reaches ctrld again. For each time, it uses
|
||||
// a different IPv6 address, which causes hundreds/thousands different client
|
||||
// IDs created for the same device, which is pointless.
|
||||
//
|
||||
// TODO(cuonglm): investigate whether this can be a false positive for other clients?
|
||||
if !ctrldnet.IsIPv6(ci.IP) {
|
||||
ci.Hostname = ci.IP
|
||||
p.ciTable.StoreVPNClient(ci)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
lineread.File("/proc/mounts", func(line []byte) error {
|
||||
if mem.Contains(mem.B(line), mem.S("lxcfs /proc/cpuinfo fuse.lxcfs")) {
|
||||
ret = true
|
||||
return io.EOF
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return ret
|
||||
} else {
|
||||
ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
|
||||
}
|
||||
ci.Self = queryFromSelf(ci.IP)
|
||||
p.spoofLoopbackIpInClientInfo(ci)
|
||||
return ci
|
||||
}
|
||||
|
||||
func (p *prog) getClientInfo(ip, mac string) *ctrld.ClientInfo {
|
||||
ci := &ctrld.ClientInfo{}
|
||||
if mac != "" {
|
||||
ci.Mac = mac
|
||||
ci.IP = p.ciTable.LookupIP(mac)
|
||||
} else {
|
||||
// spoofLoopbackIpInClientInfo replaces loopback IPs in client info.
|
||||
//
|
||||
// - Preference IPv4.
|
||||
// - Preference RFC1918.
|
||||
func (p *prog) spoofLoopbackIpInClientInfo(ci *ctrld.ClientInfo) {
|
||||
if ip := net.ParseIP(ci.IP); ip == nil || !ip.IsLoopback() {
|
||||
return
|
||||
}
|
||||
if ip := p.ciTable.LookupRFC1918IPv4(ci.Mac); ip != "" {
|
||||
ci.IP = ip
|
||||
ci.Mac = p.ciTable.LookupMac(ip)
|
||||
if ip == "127.0.0.1" || ip == "::1" {
|
||||
ci.IP = p.ciTable.LookupIP(ci.Mac)
|
||||
}
|
||||
}
|
||||
|
||||
// queryFromSelf reports whether the input IP is from device running ctrld.
|
||||
func queryFromSelf(ip string) bool {
|
||||
netIP := netip.MustParseAddr(ip)
|
||||
ifaces, err := interfaces.GetList()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not get interfaces list")
|
||||
return false
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not get interfaces addresses: %s", iface.Name)
|
||||
continue
|
||||
}
|
||||
for _, a := range addrs {
|
||||
switch v := a.(type) {
|
||||
case *net.IPNet:
|
||||
if pfx, ok := netaddr.FromStdIPNet(v); ok && pfx.Addr().Compare(netIP) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ci.Hostname = p.ciTable.LookupHostname(ci.IP, ci.Mac)
|
||||
return ci
|
||||
return false
|
||||
}
|
||||
|
||||
func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
|
||||
return lc.IP == "127.0.0.1" && lc.Port == 53
|
||||
}
|
||||
|
||||
func rfc1918Addresses() []string {
|
||||
var res []string
|
||||
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
addrs, _ := i.Addrs()
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok || !ipNet.IP.IsPrivate() {
|
||||
continue
|
||||
}
|
||||
res = append(res, ipNet.IP.String())
|
||||
// ipFromARPA parses a FQDN arpa domain and return the IP address if valid.
|
||||
func ipFromARPA(arpa string) net.IP {
|
||||
if arpa, ok := strings.CutSuffix(arpa, ".in-addr.arpa."); ok {
|
||||
if ptrIP := net.ParseIP(arpa); ptrIP != nil {
|
||||
return net.IP{ptrIP[15], ptrIP[14], ptrIP[13], ptrIP[12]}
|
||||
}
|
||||
})
|
||||
return res
|
||||
}
|
||||
if arpa, ok := strings.CutSuffix(arpa, ".ip6.arpa."); ok {
|
||||
l := net.IPv6len * 2
|
||||
base := 16
|
||||
ip := make(net.IP, net.IPv6len)
|
||||
for i := 0; i < l && arpa != ""; i++ {
|
||||
idx := strings.LastIndexByte(arpa, '.')
|
||||
off := idx + 1
|
||||
if idx == -1 {
|
||||
idx = 0
|
||||
off = 0
|
||||
} else if idx == len(arpa)-1 {
|
||||
return nil
|
||||
}
|
||||
n, err := strconv.ParseUint(arpa[off:], base, 8)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
b := byte(n)
|
||||
ii := i / 2
|
||||
if i&1 == 1 {
|
||||
b |= ip[ii] << 4
|
||||
}
|
||||
ip[ii] = b
|
||||
arpa = arpa[:idx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isPrivatePtrLookup reports whether DNS message is an PTR query for LAN/CGNAT network.
|
||||
func isPrivatePtrLookup(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
q := m.Question[0]
|
||||
if ip := ipFromARPA(q.Name); ip != nil {
|
||||
if addr, ok := netip.AddrFromSlice(ip); ok {
|
||||
return addr.IsPrivate() ||
|
||||
addr.IsLoopback() ||
|
||||
addr.IsLinkLocalUnicast() ||
|
||||
tsaddr.CGNATRange().Contains(addr)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isLanHostnameQuery reports whether DNS message is an A/AAAA query with LAN hostname.
|
||||
func isLanHostnameQuery(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
q := m.Question[0]
|
||||
switch q.Qtype {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
return !strings.Contains(name, ".") ||
|
||||
strings.HasSuffix(name, ".domain") ||
|
||||
strings.HasSuffix(name, ".lan")
|
||||
}
|
||||
|
||||
// isWanClient reports whether the input is a WAN address.
|
||||
func isWanClient(na net.Addr) bool {
|
||||
var ip netip.Addr
|
||||
if ap, err := netip.ParseAddrPort(na.String()); err == nil {
|
||||
ip = ap.Addr()
|
||||
}
|
||||
return !ip.IsLoopback() &&
|
||||
!ip.IsPrivate() &&
|
||||
!ip.IsLinkLocalUnicast() &&
|
||||
!ip.IsLinkLocalMulticast() &&
|
||||
!tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
|
||||
@@ -67,8 +67,11 @@ func Test_canonicalName(t *testing.T) {
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
for _, nc := range p.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
@@ -81,6 +84,7 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
mac string
|
||||
defaultUpstreamNum string
|
||||
lc *ctrld.ListenerConfig
|
||||
domain string
|
||||
@@ -88,11 +92,14 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
|
||||
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -111,9 +118,13 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
|
||||
assert.Equal(t, tc.matched, matched)
|
||||
assert.Equal(t, tc.upstreams, upstreams)
|
||||
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
|
||||
p.proxy(ctx, &proxyRequest{
|
||||
msg: newDnsMsgWithHostname("foo", dns.TypeA),
|
||||
ufr: ufr,
|
||||
})
|
||||
assert.Equal(t, tc.matched, ufr.matched)
|
||||
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
@@ -149,26 +160,58 @@ func TestCache(t *testing.T) {
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil)
|
||||
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil)
|
||||
req1 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.1"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
req2 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.0"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
got1 := prog.proxy(context.Background(), req1)
|
||||
got2 := prog.proxy(context.Background(), req2)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
}
|
||||
|
||||
func Test_macFromMsg(t *testing.T) {
|
||||
func Test_ipAndMacFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
wantIp bool
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has mac", "4c:20:b8:ab:87:1b", true},
|
||||
{"no mac", "4c:20:b8:ab:87:1b", false},
|
||||
{"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
{"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatal("missing IP")
|
||||
}
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -180,13 +223,23 @@ func Test_macFromMsg(t *testing.T) {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
got := macFromMsg(m)
|
||||
if tc.wantMac && got != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, got)
|
||||
if tc.wantIp {
|
||||
ec2 := &dns.EDNS0_SUBNET{Address: ip}
|
||||
o.Option = append(o.Option, ec2)
|
||||
}
|
||||
if !tc.wantMac && got != "" {
|
||||
t.Errorf("unexpected mac: %q", got)
|
||||
m.Extra = append(m.Extra, o)
|
||||
gotIP, gotMac := ipAndMacFromMsg(m)
|
||||
if tc.wantMac && gotMac != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac)
|
||||
}
|
||||
if !tc.wantMac && gotMac != "" {
|
||||
t.Errorf("unexpected mac: %q", gotMac)
|
||||
}
|
||||
if tc.wantIp && gotIP != tc.ip {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP)
|
||||
}
|
||||
if !tc.wantIp && gotIP != "" {
|
||||
t.Errorf("unexpected ip: %q", gotIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -216,3 +269,165 @@ func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipFromARPA(t *testing.T) {
|
||||
tests := []struct {
|
||||
IP string
|
||||
ARPA string
|
||||
}{
|
||||
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
|
||||
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
|
||||
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
|
||||
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
|
||||
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
|
||||
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"", "asd.in-addr.arpa."},
|
||||
{"", "asd.ip6.arpa."},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.IP, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
|
||||
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithClientIP(ip string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_stripClientSubnet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
wantSubnet bool
|
||||
}{
|
||||
{"no edns0", new(dns.Msg), false},
|
||||
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
|
||||
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
|
||||
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
|
||||
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
|
||||
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
|
||||
{"invalid IP", newDnsMsgWithClientIP(""), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
stripClientSubnet(tc.msg)
|
||||
hasSubnet := false
|
||||
if opt := tc.msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
|
||||
hasSubnet = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if tc.wantSubnet != hasSubnet {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(hostname, typ)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isLanHostnameQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isLanHostnameQuery bool
|
||||
}{
|
||||
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
|
||||
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
|
||||
t.Helper()
|
||||
m := new(dns.Msg)
|
||||
ptr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.SetQuestion(ptr, dns.TypePTR)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isPrivatePtrLookup bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
|
||||
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
|
||||
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
|
||||
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
|
||||
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
|
||||
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
|
||||
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
isWanClient bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
|
||||
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
|
||||
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
|
||||
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
|
||||
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
|
||||
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
|
||||
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isWanClient(tc.addr); tc.isWanClient != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
18
cmd/cli/library.go
Normal file
18
cmd/cli/library.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
type AppCallback struct {
|
||||
HostName func() string
|
||||
LanIp func() string
|
||||
MacAddress func() string
|
||||
Exit func(error string)
|
||||
}
|
||||
|
||||
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
|
||||
type AppConfig struct {
|
||||
CdUID string
|
||||
HomeDir string
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
141
cmd/cli/loop.go
Normal file
141
cmd/cli/loop.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
loopTestDomain = ".test"
|
||||
loopTestQtype = dns.TypeTXT
|
||||
)
|
||||
|
||||
// newLoopGuard returns new loopGuard.
|
||||
func newLoopGuard() *loopGuard {
|
||||
return &loopGuard{inflight: make(map[string]struct{})}
|
||||
}
|
||||
|
||||
// loopGuard guards against DNS loop, ensuring only one query
|
||||
// for a given domain is processed at a time.
|
||||
type loopGuard struct {
|
||||
mu sync.Mutex
|
||||
inflight map[string]struct{}
|
||||
}
|
||||
|
||||
// TryLock marks the domain as being processed.
|
||||
func (lg *loopGuard) TryLock(domain string) bool {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
if _, inflight := lg.inflight[domain]; !inflight {
|
||||
lg.inflight[domain] = struct{}{}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Unlock marks the domain as being done.
|
||||
func (lg *loopGuard) Unlock(domain string) {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
delete(lg.inflight, domain)
|
||||
}
|
||||
|
||||
// isLoop reports whether the given upstream config is detected as having DNS loop.
|
||||
func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool {
|
||||
p.loopMu.Lock()
|
||||
defer p.loopMu.Unlock()
|
||||
return p.loop[uc.UID()]
|
||||
}
|
||||
|
||||
// detectLoop checks if the given DNS message is initialized sent by ctrld.
|
||||
// If yes, marking the corresponding upstream as loop, prevent infinite DNS
|
||||
// forwarding loop.
|
||||
//
|
||||
// See p.checkDnsLoop for more details how it works.
|
||||
func (p *prog) detectLoop(msg *dns.Msg) {
|
||||
if len(msg.Question) != 1 {
|
||||
return
|
||||
}
|
||||
q := msg.Question[0]
|
||||
if q.Qtype != loopTestQtype {
|
||||
return
|
||||
}
|
||||
unFQDNname := strings.TrimSuffix(q.Name, ".")
|
||||
uid := strings.TrimSuffix(unFQDNname, loopTestDomain)
|
||||
p.loopMu.Lock()
|
||||
if _, loop := p.loop[uid]; loop {
|
||||
p.loop[uid] = loop
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
}
|
||||
|
||||
// checkDnsLoop sends a message to check if there's any DNS forwarding loop
|
||||
// with all the upstreams. The way it works based on dnsmasq --dns-loop-detect.
|
||||
//
|
||||
// - Generating a TXT test query and sending it to all upstream.
|
||||
// - The test query is formed by upstream UID and test domain: <uid>.test
|
||||
// - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop).
|
||||
//
|
||||
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
|
||||
func (p *prog) checkDnsLoop() {
|
||||
mainLog.Load().Debug().Msg("start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for n, uc := range p.cfg.Upstream {
|
||||
if p.um.isDown("upstream." + n) {
|
||||
continue
|
||||
}
|
||||
// Do not send test query to external upstream.
|
||||
if !canBeLocalUpstream(uc.Domain) {
|
||||
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
|
||||
continue
|
||||
}
|
||||
uid := uc.UID()
|
||||
p.loop[uid] = false
|
||||
upstream[uid] = uc
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
}
|
||||
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msg("end checking DNS loop")
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
func (p *prog) checkDnsLoopTicker(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
p.checkDnsLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loopTestMsg generates DNS message for checking loop.
|
||||
func loopTestMsg(uid string) *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)
|
||||
return msg
|
||||
}
|
||||
42
cmd/cli/loop_test.go
Normal file
42
cmd/cli/loop_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_loopGuard(t *testing.T) {
|
||||
lg := newLoopGuard()
|
||||
key := "foo"
|
||||
|
||||
var i atomic.Int64
|
||||
var started atomic.Int64
|
||||
n := 1000
|
||||
do := func() {
|
||||
locked := lg.TryLock(key)
|
||||
defer lg.Unlock(key)
|
||||
started.Add(1)
|
||||
for started.Load() < 2 {
|
||||
// Wait until at least 2 goroutines started, otherwise, on system with heavy load,
|
||||
// or having only 1 CPU, all goroutines can be scheduled to run consequently.
|
||||
}
|
||||
if locked {
|
||||
i.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
do()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if i.Load() == int64(n) {
|
||||
t.Fatalf("i must not be increased %d times", n)
|
||||
}
|
||||
}
|
||||
@@ -32,9 +32,18 @@ var (
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
nextdns string
|
||||
cdUpstreamProto string
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
noConfigStart bool
|
||||
)
|
||||
|
||||
const (
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
nextdnsFlagName = "nextdns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -65,6 +74,7 @@ func normalizeLogFilePath(logFilePath string) string {
|
||||
return filepath.Join(dir, logFilePath)
|
||||
}
|
||||
|
||||
// initConsoleLogging initializes console logging, then storing to mainLog.
|
||||
func initConsoleLogging() {
|
||||
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
@@ -86,6 +96,7 @@ func initConsoleLogging() {
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func initLogging() {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(true)
|
||||
}
|
||||
|
||||
@@ -124,7 +135,7 @@ func initLoggingWithBackup(doBackup bool) {
|
||||
}
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState() {
|
||||
func (p *prog) watchLinkState(ctx context.Context) {
|
||||
ch := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
@@ -13,14 +15,19 @@ func (p *prog) watchLinkState() {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
|
||||
return
|
||||
}
|
||||
for lu := range ch {
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case lu := <-ch:
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,4 +2,6 @@
|
||||
|
||||
package cli
|
||||
|
||||
func (p *prog) watchLinkState() {}
|
||||
import "context"
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {}
|
||||
|
||||
@@ -3,6 +3,7 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
@@ -16,13 +17,21 @@ const (
|
||||
dns=none
|
||||
systemd-resolved=false
|
||||
`
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
systemdEnabledState = "enabled"
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
)
|
||||
|
||||
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
|
||||
|
||||
// hasNetworkManager reports whether NetworkManager executable found.
|
||||
func hasNetworkManager() bool {
|
||||
exe, _ := exec.LookPath("NetworkManager")
|
||||
return exe != ""
|
||||
}
|
||||
|
||||
func setupNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
|
||||
mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
return nil
|
||||
@@ -43,6 +52,9 @@ func setupNetworkManager() error {
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
err := os.Remove(networkManagerCtrldConfFile)
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
@@ -71,6 +83,7 @@ func reloadNetworkManager() {
|
||||
waitCh := make(chan string)
|
||||
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
}
|
||||
|
||||
31
cmd/cli/nextdns.go
Normal file
31
cmd/cli/nextdns.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const nextdnsURL = "https://dns.nextdns.io"
|
||||
|
||||
func generateNextDNSConfig(uid string) {
|
||||
if uid == "" {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver")
|
||||
cfg = ctrld.Config{
|
||||
Listener: map[string]*ctrld.ListenerConfig{
|
||||
"0": {
|
||||
IP: "0.0.0.0",
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Upstream: map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
Type: ctrld.ResolverTypeDOH3,
|
||||
Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, uid),
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -9,11 +9,12 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/client6"
|
||||
@@ -24,7 +25,10 @@ import (
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
const (
|
||||
resolvConfPath = "/etc/resolv.conf"
|
||||
resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
)
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ip a add 127.0.0.2/24 dev lo
|
||||
@@ -65,6 +69,11 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
defer func() {
|
||||
if r.Mode() == "direct" {
|
||||
go watchResolveConf(osConfig)
|
||||
}
|
||||
}()
|
||||
|
||||
trySystemdResolve := false
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
@@ -85,8 +94,13 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
if useSystemdResolved {
|
||||
if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out))
|
||||
}
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if reflect.DeepEqual(currentNS, nameservers) {
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -104,7 +118,7 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if reflect.DeepEqual(currentNS, nameservers) {
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -265,3 +279,89 @@ func ignoringEINTR(fn func() error) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isSubSet reports whether s2 contains all elements of s1.
|
||||
func isSubSet(s1, s2 []string) bool {
|
||||
ok := true
|
||||
for _, ns := range s1 {
|
||||
// TODO(cuonglm): use slices.Contains once upgrading to go1.21
|
||||
if sliceContains(s2, ns) {
|
||||
continue
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// sliceContains reports whether v is present in s.
|
||||
func sliceContains[S ~[]E, E comparable](s S, v E) bool {
|
||||
return sliceIndex(s, v) >= 0
|
||||
}
|
||||
|
||||
// sliceIndex returns the index of the first occurrence of v in s,
|
||||
// or -1 if not present.
|
||||
func sliceIndex[S ~[]E, E comparable](s S, v E) int {
|
||||
for i := range s {
|
||||
if v == s[i] {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// watchResolveConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func watchResolveConf(oc dns.OSConfig) {
|
||||
mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file")
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
return
|
||||
}
|
||||
|
||||
// We watch /etc instead of /etc/resolv.conf directly,
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list")
|
||||
return
|
||||
}
|
||||
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter.
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Name != resolvConfPath { // skip if not /etc/resolv.conf changes.
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
if err := r.SetDNS(oc); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
459
cmd/cli/prog.go
459
cmd/cli/prog.go
@@ -1,19 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/viper"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
||||
@@ -25,6 +31,9 @@ const (
|
||||
defaultSemaphoreCap = 256
|
||||
ctrldLogUnixSock = "ctrld_start.sock"
|
||||
ctrldControlUnixSock = "ctrld_control.sock"
|
||||
upstreamPrefix = "upstream."
|
||||
upstreamOS = upstreamPrefix + "os"
|
||||
upstreamPrivate = upstreamPrefix + "private"
|
||||
)
|
||||
|
||||
var logf = func(format string, args ...any) {
|
||||
@@ -40,17 +49,28 @@ var svcConfig = &service.Config{
|
||||
var useSystemdResolved = false
|
||||
|
||||
type prog struct {
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
reloadCh chan struct{} // For Windows.
|
||||
reloadDoneCh chan struct{}
|
||||
logConn net.Conn
|
||||
cs *controlServer
|
||||
|
||||
cfg *ctrld.Config
|
||||
cache dnscache.Cacher
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
router router.Router
|
||||
cfg *ctrld.Config
|
||||
localUpstreams []string
|
||||
ptrNameservers []string
|
||||
appCallback *AppCallback
|
||||
cache dnscache.Cacher
|
||||
sema semaphore
|
||||
ciTable *clientinfo.Table
|
||||
um *upstreamMonitor
|
||||
router router.Router
|
||||
ptrLoopGuard *loopGuard
|
||||
lanLoopGuard *loopGuard
|
||||
|
||||
loopMu sync.Mutex
|
||||
loop map[string]bool
|
||||
|
||||
started chan struct{}
|
||||
onStartedDone chan struct{}
|
||||
@@ -59,11 +79,106 @@ type prog struct {
|
||||
}
|
||||
|
||||
func (p *prog) Start(s service.Service) error {
|
||||
p.cfg = &cfg
|
||||
go p.run()
|
||||
go p.runWait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// runWait runs ctrld components, waiting for signal to reload.
|
||||
func (p *prog) runWait() {
|
||||
p.mu.Lock()
|
||||
p.cfg = &cfg
|
||||
p.mu.Unlock()
|
||||
reloadSigCh := make(chan os.Signal, 1)
|
||||
notifyReloadSigCh(reloadSigCh)
|
||||
|
||||
reload := false
|
||||
logger := mainLog.Load()
|
||||
for {
|
||||
reloadCh := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
p.run(reload, reloadCh)
|
||||
reload = true
|
||||
}()
|
||||
select {
|
||||
case sig := <-reloadSigCh:
|
||||
logger.Notice().Msgf("got signal: %s, reloading...", sig.String())
|
||||
case <-p.reloadCh:
|
||||
logger.Notice().Msg("reloading...")
|
||||
case <-p.stopCh:
|
||||
close(reloadCh)
|
||||
return
|
||||
}
|
||||
|
||||
waitOldRunDone := func() {
|
||||
close(reloadCh)
|
||||
<-done
|
||||
}
|
||||
newCfg := &ctrld.Config{}
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
logger.Err(err).Msg("could not read new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if err := v.Unmarshal(&newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not unmarshal new config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
if cdUID != "" {
|
||||
if err := processCDFlags(newCfg); err != nil {
|
||||
logger.Err(err).Msg("could not fetch ControlD config")
|
||||
waitOldRunDone()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
waitOldRunDone()
|
||||
|
||||
p.mu.Lock()
|
||||
curListener := p.cfg.Listener
|
||||
p.mu.Unlock()
|
||||
|
||||
for n, lc := range newCfg.Listener {
|
||||
curLc := curListener[n]
|
||||
if curLc == nil {
|
||||
continue
|
||||
}
|
||||
if lc.IP == "" {
|
||||
lc.IP = curLc.IP
|
||||
}
|
||||
if lc.Port == 0 {
|
||||
lc.Port = curLc.Port
|
||||
}
|
||||
}
|
||||
if err := validateConfig(newCfg); err != nil {
|
||||
logger.Err(err).Msg("invalid config")
|
||||
continue
|
||||
}
|
||||
|
||||
// This needs to be done here, otherwise, the DNS handler may observe an invalid
|
||||
// upstream config because its initialization function have not been called yet.
|
||||
mainLog.Load().Debug().Msg("setup upstream with new config")
|
||||
p.setupUpstream(newCfg)
|
||||
|
||||
p.mu.Lock()
|
||||
*p.cfg = *newCfg
|
||||
p.mu.Unlock()
|
||||
|
||||
logger.Notice().Msg("reloading config successfully")
|
||||
select {
|
||||
case p.reloadDoneCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if !service.Interactive() {
|
||||
p.setDNS()
|
||||
@@ -77,13 +192,54 @@ func (p *prog) preRun() {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *prog) run() {
|
||||
func (p *prog) setupUpstream(cfg *ctrld.Config) {
|
||||
localUpstreams := make([]string, 0, len(cfg.Upstream))
|
||||
ptrNameservers := make([]string, 0, len(cfg.Upstream))
|
||||
for n := range cfg.Upstream {
|
||||
uc := cfg.Upstream[n]
|
||||
uc.Init()
|
||||
if uc.BootstrapIP == "" {
|
||||
uc.SetupBootstrapIP()
|
||||
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
} else {
|
||||
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
go uc.Ping()
|
||||
|
||||
if canBeLocalUpstream(uc.Domain) {
|
||||
localUpstreams = append(localUpstreams, upstreamPrefix+n)
|
||||
}
|
||||
if uc.IsDiscoverable() {
|
||||
ptrNameservers = append(ptrNameservers, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
p.localUpstreams = localUpstreams
|
||||
p.ptrNameservers = ptrNameservers
|
||||
}
|
||||
|
||||
// run runs the ctrld main components.
|
||||
//
|
||||
// The reload boolean indicates that the function is run when ctrld first start
|
||||
// or when ctrld receive reloading signal. Platform specifics setup is only done
|
||||
// on started, mean reload is "false".
|
||||
//
|
||||
// The reloadCh is used to signal ctrld listeners that ctrld is going to be reloaded,
|
||||
// so all listeners could be terminated and re-spawned again.
|
||||
func (p *prog) run(reload bool, reloadCh chan struct{}) {
|
||||
// Wait the caller to signal that we can do our logic.
|
||||
<-p.waitCh
|
||||
p.preRun()
|
||||
if !reload {
|
||||
p.preRun()
|
||||
}
|
||||
numListeners := len(p.cfg.Listener)
|
||||
p.started = make(chan struct{}, numListeners)
|
||||
if !reload {
|
||||
p.started = make(chan struct{}, numListeners)
|
||||
}
|
||||
p.onStartedDone = make(chan struct{})
|
||||
p.loop = make(map[string]bool)
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
if p.cfg.Service.CacheEnable {
|
||||
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
||||
if err != nil {
|
||||
@@ -92,15 +248,7 @@ func (p *prog) run() {
|
||||
p.cache = cacher
|
||||
}
|
||||
}
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
|
||||
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
|
||||
n := *mcr
|
||||
if n == 0 {
|
||||
p.sema = &noopSemaphore{}
|
||||
} else {
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(p.cfg.Listener))
|
||||
|
||||
@@ -114,67 +262,104 @@ func (p *prog) run() {
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
for n := range p.cfg.Upstream {
|
||||
uc := p.cfg.Upstream[n]
|
||||
uc.Init()
|
||||
if uc.BootstrapIP == "" {
|
||||
uc.SetupBootstrapIP()
|
||||
mainLog.Load().Info().Msgf("bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
} else {
|
||||
mainLog.Load().Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("using bootstrap IP for upstream.%s", n)
|
||||
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
|
||||
if !reload {
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)}
|
||||
if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil {
|
||||
n := *mcr
|
||||
if n == 0 {
|
||||
p.sema = &noopSemaphore{}
|
||||
} else {
|
||||
p.sema = &chanSemaphore{ready: make(chan struct{}, n)}
|
||||
}
|
||||
}
|
||||
p.setupUpstream(p.cfg)
|
||||
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID, p.ptrNameservers)
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
go uc.Ping()
|
||||
}
|
||||
|
||||
p.ciTable = clientinfo.NewTable(&cfg, defaultRouteIP(), cdUID)
|
||||
if leaseFile := p.cfg.Service.DHCPLeaseFile; leaseFile != "" {
|
||||
mainLog.Load().Debug().Msgf("watching custom lease file: %s", leaseFile)
|
||||
format := ctrld.LeaseFileFormat(p.cfg.Service.DHCPLeaseFileFormat)
|
||||
p.ciTable.AddLeaseFile(leaseFile, format)
|
||||
}
|
||||
// context for managing spawn goroutines.
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
go func() {
|
||||
p.ciTable.Init()
|
||||
p.ciTable.RefreshLoop(p.stopCh)
|
||||
}()
|
||||
go p.watchLinkState()
|
||||
// Newer versions of android and iOS denies permission which breaks connectivity.
|
||||
if !isMobile() && !reload {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
p.ciTable.Init()
|
||||
p.ciTable.RefreshLoop(ctx)
|
||||
}()
|
||||
go p.watchLinkState(ctx)
|
||||
}
|
||||
|
||||
for listenerNum := range p.cfg.Listener {
|
||||
p.cfg.Listener[listenerNum].Init()
|
||||
go func(listenerNum string) {
|
||||
defer wg.Done()
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
upstreamConfig := p.cfg.Upstream[listenerNum]
|
||||
if upstreamConfig == nil {
|
||||
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
||||
if !reload {
|
||||
go func(listenerNum string) {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
upstreamConfig := p.cfg.Upstream[listenerNum]
|
||||
if upstreamConfig == nil {
|
||||
mainLog.Load().Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
||||
}
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
if err := p.serveDNS(listenerNum); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
}(listenerNum)
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
cancelFunc()
|
||||
wg.Done()
|
||||
}()
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
case <-ctx.Done():
|
||||
case <-reloadCh:
|
||||
}
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
mainLog.Load().Info().Msgf("starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
if err := p.serveDNS(listenerNum); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msgf("unable to start dns proxy on listener.%s", listenerNum)
|
||||
}
|
||||
}(listenerNum)
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < numListeners; i++ {
|
||||
<-p.started
|
||||
}
|
||||
for _, f := range p.onStarted {
|
||||
f()
|
||||
if !reload {
|
||||
for i := 0; i < numListeners; i++ {
|
||||
<-p.started
|
||||
}
|
||||
for _, f := range p.onStarted {
|
||||
f()
|
||||
}
|
||||
}
|
||||
|
||||
close(p.onStartedDone)
|
||||
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
initLoggingWithBackup(false)
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
if p.cs != nil {
|
||||
p.registerControlServerHandler()
|
||||
if err := p.cs.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// Check for possible DNS loop.
|
||||
p.checkDnsLoop()
|
||||
// Start check DNS loop ticker.
|
||||
p.checkDnsLoopTicker(ctx)
|
||||
}()
|
||||
|
||||
if !reload {
|
||||
// Stop writing log to unix socket.
|
||||
consoleWriter.Out = os.Stdout
|
||||
initLoggingWithBackup(false)
|
||||
if p.logConn != nil {
|
||||
_ = p.logConn.Close()
|
||||
}
|
||||
if p.cs != nil {
|
||||
p.registerControlServerHandler()
|
||||
if err := p.cs.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start control server")
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
@@ -256,7 +441,7 @@ func (p *prog) setDNS() {
|
||||
|
||||
nameservers := []string{ns}
|
||||
if needRFC1918Listeners(lc) {
|
||||
nameservers = append(nameservers, rfc1918Addresses()...)
|
||||
nameservers = append(nameservers, ctrld.Rfc1918Addresses()...)
|
||||
}
|
||||
if err := setDNS(netIface, nameservers); err != nil {
|
||||
logger.Error().Err(err).Msgf("could not set DNS for interface")
|
||||
@@ -345,48 +530,108 @@ var (
|
||||
func errUrlNetworkError(err error) bool {
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
var opErr *net.OpError
|
||||
if errors.As(urlErr.Err, &opErr) {
|
||||
if opErr.Temporary() {
|
||||
return true
|
||||
}
|
||||
switch {
|
||||
case errors.Is(opErr.Err, syscall.ECONNREFUSED),
|
||||
errors.Is(opErr.Err, syscall.EINVAL),
|
||||
errors.Is(opErr.Err, syscall.ENETUNREACH),
|
||||
errors.Is(opErr.Err, windowsENETUNREACH),
|
||||
errors.Is(opErr.Err, windowsEINVAL),
|
||||
errors.Is(opErr.Err, windowsECONNREFUSED):
|
||||
return true
|
||||
}
|
||||
return errNetworkError(urlErr.Err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func errNetworkError(err error) bool {
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
if opErr.Temporary() {
|
||||
return true
|
||||
}
|
||||
switch {
|
||||
case errors.Is(opErr.Err, syscall.ECONNREFUSED),
|
||||
errors.Is(opErr.Err, syscall.EINVAL),
|
||||
errors.Is(opErr.Err, syscall.ENETUNREACH),
|
||||
errors.Is(opErr.Err, windowsENETUNREACH),
|
||||
errors.Is(opErr.Err, windowsEINVAL),
|
||||
errors.Is(opErr.Err, windowsECONNREFUSED):
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// defaultRouteIP returns IP string of the default route if present, prefer IPv4 over IPv6.
|
||||
func defaultRouteIP() string {
|
||||
if dr, err := interfaces.DefaultRoute(); err == nil {
|
||||
if netIface, err := netInterface(dr.InterfaceName); err == nil {
|
||||
addrs, _ := netIface.Addrs()
|
||||
do := func(v4 bool) net.IP {
|
||||
for _, addr := range addrs {
|
||||
if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() {
|
||||
if v4 {
|
||||
return netIP.IP.To4()
|
||||
}
|
||||
return netIP.IP
|
||||
}
|
||||
func ifaceFirstPrivateIP(iface *net.Interface) string {
|
||||
if iface == nil {
|
||||
return ""
|
||||
}
|
||||
do := func(addrs []net.Addr, v4 bool) net.IP {
|
||||
for _, addr := range addrs {
|
||||
if netIP, ok := addr.(*net.IPNet); ok && netIP.IP.IsPrivate() {
|
||||
if v4 {
|
||||
return netIP.IP.To4()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if ip := do(true); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
if ip := do(false); ip != nil {
|
||||
return ip.String()
|
||||
return netIP.IP
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
addrs, _ := iface.Addrs()
|
||||
if ip := do(addrs, true); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
if ip := do(addrs, false); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// defaultRouteIP returns private IP string of the default route if present, prefer IPv4 over IPv6.
|
||||
func defaultRouteIP() string {
|
||||
dr, err := interfaces.DefaultRoute()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
drNetIface, err := netInterface(dr.InterfaceName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
mainLog.Load().Debug().Str("iface", drNetIface.Name).Msg("checking default route interface")
|
||||
if ip := ifaceFirstPrivateIP(drNetIface); ip != "" {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("found ip with default route interface")
|
||||
return ip
|
||||
}
|
||||
|
||||
// If we reach here, it means the default route interface is connected directly to ISP.
|
||||
// We need to find the LAN interface with the same Mac address with drNetIface.
|
||||
//
|
||||
// There could be multiple LAN interfaces with the same Mac address, so we find all private
|
||||
// IPs then using the smallest one.
|
||||
var addrs []netip.Addr
|
||||
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
if i.Name == drNetIface.Name {
|
||||
return
|
||||
}
|
||||
if bytes.Equal(i.HardwareAddr, drNetIface.HardwareAddr) {
|
||||
for _, pfx := range prefixes {
|
||||
addr := pfx.Addr()
|
||||
if addr.IsPrivate() {
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
if len(addrs) == 0 {
|
||||
mainLog.Load().Warn().Msg("no default route IP found")
|
||||
return ""
|
||||
}
|
||||
sort.Slice(addrs, func(i, j int) bool {
|
||||
return addrs[i].Less(addrs[j])
|
||||
})
|
||||
|
||||
ip := addrs[0].String()
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("found LAN interface IP")
|
||||
return ip
|
||||
}
|
||||
|
||||
// canBeLocalUpstream reports whether the IP address can be used as a local upstream.
|
||||
func canBeLocalUpstream(addr string) bool {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -20,10 +19,8 @@ func setDependencies(svc *service.Config) {
|
||||
"Wants=NetworkManager-wait-online.service",
|
||||
"After=NetworkManager-wait-online.service",
|
||||
"Wants=systemd-networkd-wait-online.service",
|
||||
"After=systemd-networkd-wait-online.service",
|
||||
}
|
||||
if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 {
|
||||
svc.Dependencies = append(svc.Dependencies, routerDeps...)
|
||||
"Wants=nss-lookup.target",
|
||||
"After=nss-lookup.target",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
17
cmd/cli/reload_others.go
Normal file
17
cmd/cli/reload_others.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}
|
||||
18
cmd/cli/reload_windows.go
Normal file
18
cmd/cli/reload_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
select {
|
||||
case p.reloadCh <- struct{}{}:
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
return errors.New("timeout while sending reload signal")
|
||||
}
|
||||
101
cmd/cli/upstream_monitor.go
Normal file
101
cmd/cli/upstream_monitor.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 100
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
|
||||
down map[string]*atomic.Bool
|
||||
failureReq map[string]*atomic.Uint64
|
||||
|
||||
mu sync.Mutex
|
||||
checking map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
down: make(map[string]*atomic.Bool),
|
||||
failureReq: make(map[string]*atomic.Uint64),
|
||||
checking: make(map[string]bool),
|
||||
}
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
um.down[upstream] = new(atomic.Bool)
|
||||
um.failureReq[upstream] = new(atomic.Uint64)
|
||||
}
|
||||
um.down[upstreamOS] = new(atomic.Bool)
|
||||
um.failureReq[upstreamOS] = new(atomic.Uint64)
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increase failed queries count for an upstream by 1.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
failedCount := um.failureReq[upstream].Add(1)
|
||||
um.down[upstream].Store(failedCount >= maxFailureRequest)
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
return um.down[upstream].Load()
|
||||
}
|
||||
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.failureReq[upstream].Store(0)
|
||||
um.down[upstream].Store(false)
|
||||
}
|
||||
|
||||
// checkUpstream checks the given upstream status, periodically sending query to upstream
|
||||
// until successfully. An upstream status/counter will be reset once it becomes reachable.
|
||||
func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
|
||||
um.mu.Lock()
|
||||
isChecking := um.checking[upstream]
|
||||
if isChecking {
|
||||
um.mu.Unlock()
|
||||
return
|
||||
}
|
||||
um.checking[upstream] = true
|
||||
um.mu.Unlock()
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
|
||||
return
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
|
||||
check := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
uc.ReBootstrap()
|
||||
_, err := resolver.Resolve(ctx, msg)
|
||||
return err
|
||||
}
|
||||
for {
|
||||
if err := check(); err == nil {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
|
||||
um.reset(upstream)
|
||||
return
|
||||
}
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
}
|
||||
}
|
||||
74
cmd/ctrld_library/main.go
Normal file
74
cmd/ctrld_library/main.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package ctrld_library
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
// Controller holds global state
|
||||
type Controller struct {
|
||||
stopCh chan struct{}
|
||||
AppCallback AppCallback
|
||||
Config cli.AppConfig
|
||||
}
|
||||
|
||||
// NewController provides reference to global state to be managed by android vpn service and iOS network extension.
|
||||
// reference is not safe for concurrent use.
|
||||
func NewController(appCallback AppCallback) *Controller {
|
||||
return &Controller{AppCallback: appCallback}
|
||||
}
|
||||
|
||||
// AppCallback provides access to app instance.
|
||||
type AppCallback interface {
|
||||
Hostname() string
|
||||
LanIp() string
|
||||
MacAddress() string
|
||||
Exit(error string)
|
||||
}
|
||||
|
||||
// Start configures utility with config.toml from provided directory.
|
||||
// This function will block until Stop is called
|
||||
// Check port availability prior to calling it.
|
||||
func (c *Controller) Start(CdUID string, HomeDir string, logLevel int, logPath string) {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
c.Config = cli.AppConfig{
|
||||
CdUID: CdUID,
|
||||
HomeDir: HomeDir,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
}
|
||||
appCallback := mapCallback(c.AppCallback)
|
||||
cli.RunMobile(&c.Config, &appCallback, c.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
// As workaround to avoid circular dependency between cli and ctrld_library module
|
||||
func mapCallback(callback AppCallback) cli.AppCallback {
|
||||
return cli.AppCallback{
|
||||
HostName: func() string {
|
||||
return callback.Hostname()
|
||||
},
|
||||
LanIp: func() string {
|
||||
return callback.LanIp()
|
||||
},
|
||||
MacAddress: func() string {
|
||||
return callback.MacAddress()
|
||||
},
|
||||
Exit: func(err string) {
|
||||
callback.Exit(err)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) Stop() bool {
|
||||
if c.stopCh != nil {
|
||||
close(c.stopCh)
|
||||
c.stopCh = nil
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Controller) IsRunning() bool {
|
||||
return c.stopCh != nil
|
||||
}
|
||||
113
config.go
113
config.go
@@ -2,13 +2,16 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
@@ -24,6 +27,7 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
@@ -78,8 +82,18 @@ func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
|
||||
func InitConfig(v *viper.Viper, name string) {
|
||||
v.SetDefault("listener", map[string]*ListenerConfig{
|
||||
"0": {
|
||||
IP: "127.0.0.1",
|
||||
Port: 53,
|
||||
IP: "",
|
||||
Port: 0,
|
||||
Policy: &ListenerPolicyConfig{
|
||||
Name: "Main Policy",
|
||||
Networks: []Rule{
|
||||
{"network.0": []string{"upstream.0"}},
|
||||
},
|
||||
Rules: []Rule{
|
||||
{"example.com": []string{"upstream.0"}},
|
||||
{"*.ads.com": []string{"upstream.1"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
v.SetDefault("network", map[string]*NetworkConfig{
|
||||
@@ -178,6 +192,8 @@ type ServiceConfig struct {
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
@@ -193,7 +209,7 @@ type NetworkConfig struct {
|
||||
type UpstreamConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
|
||||
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
|
||||
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty"`
|
||||
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
|
||||
Domain string `mapstructure:"-" toml:"-"`
|
||||
IPStack string `mapstructure:"ip_stack" toml:"ip_stack,omitempty" validate:"ipstack"`
|
||||
@@ -201,6 +217,9 @@ type UpstreamConfig struct {
|
||||
// The caller should not access this field directly.
|
||||
// Use UpstreamSendClientInfo instead.
|
||||
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
|
||||
// The caller should not access this field directly.
|
||||
// Use IsDiscoverable instead.
|
||||
Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"`
|
||||
|
||||
g singleflight.Group
|
||||
rebootstrap atomic.Bool
|
||||
@@ -216,14 +235,16 @@ type UpstreamConfig struct {
|
||||
http3RoundTripper6 http.RoundTripper
|
||||
certPool *x509.CertPool
|
||||
u *url.URL
|
||||
uid string
|
||||
}
|
||||
|
||||
// ListenerConfig specifies the networks configuration that ctrld will run on.
|
||||
type ListenerConfig struct {
|
||||
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
|
||||
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
|
||||
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
|
||||
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
|
||||
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
|
||||
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
|
||||
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
|
||||
AllowWanClients bool `mapstructure:"allow_wan_clients" toml:"allow_wan_clients,omitempty"`
|
||||
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
|
||||
}
|
||||
|
||||
// IsDirectDnsListener reports whether ctrld can be a direct listener on port 53.
|
||||
@@ -249,6 +270,7 @@ type ListenerPolicyConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
|
||||
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
@@ -260,6 +282,7 @@ type Rule map[string][]string
|
||||
|
||||
// Init initialized necessary values for an UpstreamConfig.
|
||||
func (uc *UpstreamConfig) Init() {
|
||||
uc.uid = upstreamUID()
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
uc.Domain = u.Host
|
||||
switch uc.Type {
|
||||
@@ -317,13 +340,28 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
if uc.isControlD() {
|
||||
if uc.isControlD() || uc.isNextDNS() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDiscoverable reports whether the upstream can be used for PTR discovery.
|
||||
// The caller must ensure uc.Init() was called before calling this.
|
||||
func (uc *UpstreamConfig) IsDiscoverable() bool {
|
||||
if uc.Discoverable != nil {
|
||||
return *uc.Discoverable
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate:
|
||||
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BootstrapIPs returns the bootstrap IPs list of upstreams.
|
||||
func (uc *UpstreamConfig) BootstrapIPs() []string {
|
||||
return uc.bootstrapIPs
|
||||
@@ -340,6 +378,11 @@ func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
uc.setupBootstrapIP(true)
|
||||
}
|
||||
|
||||
// UID returns the unique identifier of the upstream.
|
||||
func (uc *UpstreamConfig) UID() string {
|
||||
return uc.uid
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
@@ -384,8 +427,9 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
return
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
uc.rebootstrap.Store(true)
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
@@ -509,6 +553,16 @@ func (uc *UpstreamConfig) isControlD() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) isNextDNS() bool {
|
||||
domain := uc.Domain
|
||||
if domain == "" {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
domain = u.Hostname()
|
||||
}
|
||||
}
|
||||
return domain == "dns.nextdns.io"
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
@@ -589,6 +643,7 @@ func ValidateConfig(validate *validator.Validate, cfg *Config) error {
|
||||
_ = validate.RegisterValidation("dnsrcode", validateDnsRcode)
|
||||
_ = validate.RegisterValidation("ipstack", validateIpStack)
|
||||
_ = validate.RegisterValidation("iporempty", validateIpOrEmpty)
|
||||
validate.RegisterStructValidation(upstreamConfigStructLevelValidation, UpstreamConfig{})
|
||||
return validate.Struct(cfg)
|
||||
}
|
||||
|
||||
@@ -613,6 +668,32 @@ func validateIpOrEmpty(fl validator.FieldLevel) bool {
|
||||
return net.ParseIP(val) != nil
|
||||
}
|
||||
|
||||
func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
|
||||
uc := sl.Current().Addr().Interface().(*UpstreamConfig)
|
||||
if uc.Type == ResolverTypeOS {
|
||||
return
|
||||
}
|
||||
|
||||
// Endpoint is required for non os resolver.
|
||||
if uc.Endpoint == "" {
|
||||
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "required_unless", "")
|
||||
return
|
||||
}
|
||||
|
||||
// DoH/DoH3 requires endpoint is an HTTP url.
|
||||
if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 {
|
||||
u, err := url.Parse(uc.Endpoint)
|
||||
if err != nil || u.Host == "" {
|
||||
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
|
||||
return
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func defaultPortFor(typ string) string {
|
||||
switch typ {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
@@ -652,3 +733,15 @@ func ResolverTypeFromEndpoint(endpoint string) string {
|
||||
func pick(s []string) string {
|
||||
return s[rand.Intn(len(s))]
|
||||
}
|
||||
|
||||
// upstreamUID generates an unique identifier for an upstream.
|
||||
func upstreamUID() string {
|
||||
b := make([]byte, 4)
|
||||
for {
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
|
||||
continue
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +185,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.uid = "" // we don't care about the uid.
|
||||
assert.Equal(t, tc.expected, tc.uc)
|
||||
})
|
||||
}
|
||||
@@ -278,6 +279,61 @@ func TestUpstreamConfig_UpstreamSendClientInfo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
discoverable bool
|
||||
}{
|
||||
{
|
||||
"loopback",
|
||||
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"rfc1918",
|
||||
&UpstreamConfig{Endpoint: "192.168.1.1", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"CGNAT",
|
||||
&UpstreamConfig{Endpoint: "100.66.67.68", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Public IP",
|
||||
&UpstreamConfig{Endpoint: "8.8.8.8", Type: ResolverTypeLegacy},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"override discoverable",
|
||||
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(false)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"override non-public",
|
||||
&UpstreamConfig{Endpoint: "1.1.1.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(true)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-legacy upstream",
|
||||
&UpstreamConfig{Endpoint: "https://192.168.1.1/custom-doh", Type: ResolverTypeDOH},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ptrBool(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
@@ -8,14 +8,12 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
@@ -28,9 +26,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
if hasIPv6() {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
@@ -43,7 +39,6 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
rt := &http3.RoundTripper{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
domain := addr
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
@@ -57,20 +52,23 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
|
||||
return quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
||||
}
|
||||
dialAddrs := make([]string, len(addrs))
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
pd := &quicParallelDialer{}
|
||||
conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg)
|
||||
conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) {
|
||||
rt.CloseIdleConnections()
|
||||
})
|
||||
return rt
|
||||
}
|
||||
|
||||
@@ -107,13 +105,15 @@ type parallelDialerResult struct {
|
||||
type quicParallelDialer struct{}
|
||||
|
||||
// Dial performs parallel dialing to the given address list.
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
ch := make(chan *parallelDialerResult, len(addrs))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(addrs))
|
||||
@@ -122,11 +122,6 @@ func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []st
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
go func(addr string) {
|
||||
defer wg.Done()
|
||||
@@ -135,9 +130,22 @@ func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []st
|
||||
ch <- ¶llelDialerResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
|
||||
ch <- ¶llelDialerResult{conn: conn, err: err}
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
ch <- ¶llelDialerResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
||||
select {
|
||||
case ch <- ¶llelDialerResult{conn: conn, err: err}:
|
||||
case <-done:
|
||||
if conn != nil {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
|
||||
}
|
||||
if udpConn != nil {
|
||||
udpConn.Close()
|
||||
}
|
||||
}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
|
||||
@@ -54,7 +54,12 @@ func TestLoadDefaultConfig(t *testing.T) {
|
||||
cfg := defaultConfig(t)
|
||||
validate := validator.New()
|
||||
require.NoError(t, ctrld.ValidateConfig(validate, cfg))
|
||||
assert.Len(t, cfg.Listener, 1)
|
||||
if assert.Len(t, cfg.Listener, 1) {
|
||||
l0 := cfg.Listener["0"]
|
||||
require.NotNil(t, l0.Policy)
|
||||
assert.Len(t, l0.Policy.Networks, 1)
|
||||
assert.Len(t, l0.Policy.Rules, 2)
|
||||
}
|
||||
assert.Len(t, cfg.Upstream, 2)
|
||||
}
|
||||
|
||||
@@ -95,6 +100,8 @@ func TestConfigValidation(t *testing.T) {
|
||||
{"non-existed lease file", configWithNonExistedLeaseFile(t), true},
|
||||
{"lease file format required if lease file exist", configWithExistedLeaseFile(t), true},
|
||||
{"invalid lease file format", configWithInvalidLeaseFileFormat(t), true},
|
||||
{"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true},
|
||||
{"invalid client id pref", configWithInvalidClientIDPref(t), true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -225,3 +232,16 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
|
||||
cfg.Service.DHCPLeaseFileFormat = "invalid"
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "1.1.1.1"
|
||||
cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.ClientIDPref = "foo"
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
# - Non-cgo ctrld binary.
|
||||
#
|
||||
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
||||
FROM golang:bullseye as base
|
||||
FROM golang:1.20-bullseye as base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
32
docker/Dockerfile.debug
Normal file
32
docker/Dockerfile.debug
Normal file
@@ -0,0 +1,32 @@
|
||||
# Using Debian bullseye for building regular image.
|
||||
# Using scratch image for minimal image size.
|
||||
# The final image has:
|
||||
#
|
||||
# - Timezone info file.
|
||||
# - CA certs file.
|
||||
# - /etc/{passwd,group} file.
|
||||
# - Non-cgo ctrld binary.
|
||||
#
|
||||
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
||||
FROM golang:1.20-bullseye as base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y upx-ucl
|
||||
|
||||
COPY . .
|
||||
|
||||
ARG tag=master
|
||||
ENV CI_COMMIT_TAG=$tag
|
||||
RUN CTRLD_NO_QF=yes CGO_ENABLED=0 ./scripts/build.sh
|
||||
|
||||
FROM alpine
|
||||
|
||||
COPY --from=base /usr/share/zoneinfo /usr/share/zoneinfo
|
||||
COPY --from=base /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
COPY --from=base /etc/passwd /etc/passwd
|
||||
COPY --from=base /etc/group /etc/group
|
||||
|
||||
COPY --from=base /app/ctrld-linux-*-nocgo ctrld
|
||||
|
||||
ENTRYPOINT ["./ctrld", "run"]
|
||||
@@ -193,6 +193,13 @@ Perform LAN client discovery using PTR queries.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
### discover_hosts
|
||||
Perform LAN client discovery using hosts file.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
### dhcp_lease_file_path
|
||||
Relative or absolute path to a custom DHCP leases file location.
|
||||
|
||||
@@ -208,6 +215,17 @@ DHCP leases file format.
|
||||
- Valid values: `dnsmasq`, `isc-dhcp`
|
||||
- Default: ""
|
||||
|
||||
### client_id_preference
|
||||
Decide how the client ID is generated
|
||||
|
||||
If `host` -> client id will only use the hostname i.e.`hash(hostname)`.
|
||||
If `mac` -> client id will only use the MAC address `hash(mac)`.
|
||||
Else -> client ID will use both Mac and Hostname i.e. `hash(mac + host)
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Valid values: `mac`, `host`
|
||||
- Default: ""
|
||||
|
||||
## Upstream
|
||||
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
||||
|
||||
@@ -312,6 +330,24 @@ If `ip_stack` is empty, or undefined:
|
||||
- Default value is `both` for non-Control D resolvers.
|
||||
- Default value is `split` for Control D resolvers.
|
||||
|
||||
### send_client_info
|
||||
Specifying whether to include client info when sending query to upstream.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default:
|
||||
- `true` for ControlD upstreams.
|
||||
- `false` for other upstreams.
|
||||
|
||||
### discoverable
|
||||
Specifying whether the upstream can be used for PTR discovery.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default:
|
||||
- `true` for loopback/RFC1918/CGNAT IP address.
|
||||
- `false` for public IP address.
|
||||
|
||||
## Network
|
||||
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.
|
||||
|
||||
@@ -369,7 +405,14 @@ Port number that the listener will listen on for incoming requests. If `port` is
|
||||
- Default: 0 or 53 or 5354 (depending on platform)
|
||||
|
||||
### restricted
|
||||
If set to `true` makes the listener `REFUSE` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
|
||||
If set to `true`, makes the listener `REFUSED` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
|
||||
|
||||
- Type: bool
|
||||
- Required: no
|
||||
- Default: false
|
||||
|
||||
### allow_wan_clients
|
||||
The listener `REFUSED` DNS queries from WAN clients by default. If set to `true`, makes the listener replies to them.
|
||||
|
||||
- Type: bool
|
||||
- Required: no
|
||||
@@ -379,7 +422,15 @@ If set to `true` makes the listener `REFUSE` DNS queries from all source IP addr
|
||||
Allows `ctrld` to set policy rules to determine which upstreams the requests will be forwarded to.
|
||||
If no `policy` is defined or the requests do not match any policy rules, it will be forwarded to corresponding upstream of the listener. For example, the request to `listener.0` will be forwarded to `upstream.0`.
|
||||
|
||||
The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either the `network` or a domain. Value is the list of the upstreams. For example:
|
||||
The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either:
|
||||
|
||||
- Network.
|
||||
- Domain.
|
||||
- Mac Address.
|
||||
|
||||
Value is the list of the upstreams.
|
||||
|
||||
For example:
|
||||
|
||||
```toml
|
||||
[listener.0.policy]
|
||||
@@ -393,12 +444,18 @@ rules = [
|
||||
{"*.local" = ["upstream.1"]},
|
||||
{"test.com" = ["upstream.2", "upstream.1"]},
|
||||
]
|
||||
|
||||
macs = [
|
||||
{"14:54:4a:8e:08:2d" = ["upstream.3"]},
|
||||
]
|
||||
```
|
||||
|
||||
Above policy will:
|
||||
- Forward requests on `listener.0` from `network.0` to `upstream.1`.
|
||||
|
||||
- Forward requests on `listener.0` for `.local` suffixed domains to `upstream.1`.
|
||||
- Forward requests on `listener.0` for `test.com` to `upstream.2`. If timeout is reached, retry on `upstream.1`.
|
||||
- Forward requests on `listener.0` from client with Mac `14:54:4a:8e:08:2d` to `upstream.3`.
|
||||
- Forward requests on `listener.0` from `network.0` to `upstream.1`.
|
||||
- All other requests on `listener.0` that do not match above conditions will be forwarded to `upstream.0`.
|
||||
|
||||
An empty upstream would not route the request to any defined upstreams, and use the OS default resolver.
|
||||
@@ -412,6 +469,18 @@ rules = [
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Note that the order of matching preference:
|
||||
|
||||
```
|
||||
rules => macs => networks
|
||||
```
|
||||
|
||||
And within each policy, the rules are processed from top to bottom.
|
||||
|
||||
---
|
||||
|
||||
#### name
|
||||
`name` is the name for the policy.
|
||||
|
||||
@@ -433,6 +502,13 @@ rules = [
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
### macs:
|
||||
`macs` is the list of mac rules within the policy. Mac address value is case-insensitive.
|
||||
|
||||
- Type: array of macs
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
### failover_rcodes
|
||||
For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`.
|
||||
|
||||
|
||||
128
doh.go
128
doh.go
@@ -8,23 +8,75 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cuonglm/osinfo"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
dohMacHeader = "x-cd-mac"
|
||||
dohIPHeader = "x-cd-ip"
|
||||
dohHostHeader = "x-cd-host"
|
||||
headerApplicationDNS = "application/dns-message"
|
||||
dohMacHeader = "x-cd-mac"
|
||||
dohIPHeader = "x-cd-ip"
|
||||
dohHostHeader = "x-cd-host"
|
||||
dohOsHeader = "x-cd-os"
|
||||
dohClientIDPrefHeader = "x-cd-cpref"
|
||||
headerApplicationDNS = "application/dns-message"
|
||||
)
|
||||
|
||||
// EncodeOsNameMap provides mapping from OS name to a shorter string, used for encoding x-cd-os value.
|
||||
var EncodeOsNameMap = map[string]string{
|
||||
"windows": "1",
|
||||
"darwin": "2",
|
||||
"linux": "3",
|
||||
"freebsd": "4",
|
||||
}
|
||||
|
||||
// DecodeOsNameMap provides mapping from encoded OS name to real value, used for decoding x-cd-os value.
|
||||
var DecodeOsNameMap = map[string]string{}
|
||||
|
||||
// EncodeArchNameMap provides mapping from OS arch to a shorter string, used for encoding x-cd-os value.
|
||||
var EncodeArchNameMap = map[string]string{
|
||||
"amd64": "1",
|
||||
"arm64": "2",
|
||||
"arm": "3",
|
||||
"386": "4",
|
||||
"mips": "5",
|
||||
"mipsle": "6",
|
||||
"mips64": "7",
|
||||
}
|
||||
|
||||
// DecodeArchNameMap provides mapping from encoded OS arch to real value, used for decoding x-cd-os value.
|
||||
var DecodeArchNameMap = map[string]string{}
|
||||
|
||||
func init() {
|
||||
for k, v := range EncodeOsNameMap {
|
||||
DecodeOsNameMap[v] = k
|
||||
}
|
||||
for k, v := range EncodeArchNameMap {
|
||||
DecodeArchNameMap[v] = k
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: use sync.OnceValue when upgrading to go1.21
|
||||
var xCdOsValueOnce sync.Once
|
||||
var xCdOsValue string
|
||||
|
||||
func dohOsHeaderValue() string {
|
||||
xCdOsValueOnce.Do(func() {
|
||||
oi := osinfo.New()
|
||||
xCdOsValue = strings.Join([]string{EncodeOsNameMap[runtime.GOOS], EncodeArchNameMap[runtime.GOARCH], oi.Dist}, "-")
|
||||
})
|
||||
return xCdOsValue
|
||||
}
|
||||
|
||||
func newDohResolver(uc *UpstreamConfig) *dohResolver {
|
||||
r := &dohResolver{
|
||||
endpoint: uc.u,
|
||||
isDoH3: uc.Type == ResolverTypeDOH3,
|
||||
http3RoundTripper: uc.http3RoundTripper,
|
||||
sendClientInfo: uc.UpstreamSendClientInfo(),
|
||||
uc: uc,
|
||||
}
|
||||
return r
|
||||
@@ -35,9 +87,9 @@ type dohResolver struct {
|
||||
endpoint *url.URL
|
||||
isDoH3 bool
|
||||
http3RoundTripper http.RoundTripper
|
||||
sendClientInfo bool
|
||||
}
|
||||
|
||||
// Resolve performs DNS query with given DNS message using DOH protocol.
|
||||
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
@@ -54,7 +106,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create request: %w", err)
|
||||
}
|
||||
addHeader(ctx, req, r.sendClientInfo)
|
||||
addHeader(ctx, req, r.uc)
|
||||
dnsTyp := uint16(0)
|
||||
if len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
@@ -94,21 +146,61 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
|
||||
func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
|
||||
req.Header.Set("Content-Type", headerApplicationDNS)
|
||||
req.Header.Set("Accept", headerApplicationDNS)
|
||||
if sendClientInfo {
|
||||
|
||||
printed := false
|
||||
if uc.UpstreamSendClientInfo() {
|
||||
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
|
||||
if ci.Mac != "" {
|
||||
req.Header.Set(dohMacHeader, ci.Mac)
|
||||
}
|
||||
if ci.IP != "" {
|
||||
req.Header.Set(dohIPHeader, ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set(dohHostHeader, ci.Hostname)
|
||||
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
|
||||
switch {
|
||||
case uc.isControlD():
|
||||
addControlDHeaders(req, ci)
|
||||
case uc.isNextDNS():
|
||||
addNextDNSHeaders(req, ci)
|
||||
}
|
||||
}
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
|
||||
if printed {
|
||||
Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
|
||||
}
|
||||
}
|
||||
|
||||
// addControlDHeaders set DoH/Doh3 HTTP request headers for ControlD upstream.
|
||||
func addControlDHeaders(req *http.Request, ci *ClientInfo) {
|
||||
req.Header.Set(dohOsHeader, dohOsHeaderValue())
|
||||
if ci.Mac != "" {
|
||||
req.Header.Set(dohMacHeader, ci.Mac)
|
||||
}
|
||||
if ci.IP != "" {
|
||||
req.Header.Set(dohIPHeader, ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set(dohHostHeader, ci.Hostname)
|
||||
}
|
||||
if ci.Self {
|
||||
req.Header.Set(dohOsHeader, dohOsHeaderValue())
|
||||
}
|
||||
switch ci.ClientIDPref {
|
||||
case "mac":
|
||||
req.Header.Set(dohClientIDPrefHeader, "1")
|
||||
case "host":
|
||||
req.Header.Set(dohClientIDPrefHeader, "2")
|
||||
}
|
||||
}
|
||||
|
||||
// addNextDNSHeaders set DoH/Doh3 HTTP request headers for nextdns upstream.
|
||||
// https://github.com/nextdns/nextdns/blob/v1.41.0/resolver/doh.go#L100
|
||||
func addNextDNSHeaders(req *http.Request, ci *ClientInfo) {
|
||||
if ci.Mac != "" {
|
||||
// https: //github.com/nextdns/nextdns/blob/v1.41.0/run.go#L543
|
||||
req.Header.Set("X-Device-Model", "mac:"+ci.Mac[:8])
|
||||
}
|
||||
if ci.IP != "" {
|
||||
req.Header.Set("X-Device-Ip", ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set("X-Device-Name", ci.Hostname)
|
||||
}
|
||||
}
|
||||
|
||||
23
doh_test.go
Normal file
23
doh_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_dohOsHeaderValue(t *testing.T) {
|
||||
val := dohOsHeaderValue()
|
||||
if val == "" {
|
||||
t.Fatalf("empty %s", dohOsHeader)
|
||||
}
|
||||
t.Log(val)
|
||||
|
||||
encodedOs := EncodeOsNameMap[runtime.GOOS]
|
||||
if encodedOs == "" {
|
||||
t.Fatalf("missing encoding value for: %q", runtime.GOOS)
|
||||
}
|
||||
decodedOs := DecodeOsNameMap[encodedOs]
|
||||
if decodedOs == "" {
|
||||
t.Fatalf("missing decoding value for: %q", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
2
doq.go
2
doq.go
@@ -51,7 +51,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
|
||||
}
|
||||
|
||||
func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) {
|
||||
session, err := quic.DialAddr(endpoint, tlsConfig, nil)
|
||||
session, err := quic.DialAddr(ctx, endpoint, tlsConfig, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
23
go.mod
23
go.mod
@@ -4,7 +4,7 @@ go 1.20
|
||||
|
||||
require (
|
||||
github.com/coreos/go-systemd/v22 v22.5.0
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf
|
||||
github.com/frankban/quicktest v1.14.5
|
||||
github.com/fsnotify/fsnotify v1.6.0
|
||||
github.com/go-playground/validator/v10 v10.11.1
|
||||
@@ -12,22 +12,22 @@ require (
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.1
|
||||
github.com/illarion/gonotify v1.0.1
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16
|
||||
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
|
||||
github.com/kardianos/service v1.2.1
|
||||
github.com/miekg/dns v1.1.55
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/pelletier/go-toml/v2 v2.0.8
|
||||
github.com/quic-go/quic-go v0.32.0
|
||||
github.com/quic-go/quic-go v0.38.0
|
||||
github.com/rs/zerolog v1.28.0
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.16.0
|
||||
github.com/stretchr/testify v1.8.3
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13
|
||||
golang.org/x/net v0.10.0
|
||||
golang.org/x/net v0.17.0
|
||||
golang.org/x/sync v0.2.0
|
||||
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a
|
||||
golang.org/x/sys v0.13.0
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
tailscale.com v1.44.0
|
||||
)
|
||||
@@ -37,7 +37,7 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/go-playground/locales v0.14.0 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
@@ -56,13 +56,11 @@ require (
|
||||
github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.2.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.17 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qtls-go1-18 v0.2.0 // indirect
|
||||
github.com/quic-go/qtls-go1-19 v0.2.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.1.0 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.3.2 // indirect
|
||||
github.com/rivo/uniseg v0.4.4 // indirect
|
||||
github.com/rogpeppe/go-internal v1.10.0 // indirect
|
||||
github.com/spf13/afero v1.9.5 // indirect
|
||||
@@ -71,10 +69,11 @@ require (
|
||||
github.com/subosito/gotenv v1.4.2 // indirect
|
||||
github.com/u-root/uio v0.0.0-20230305220412-3e8cd9d6bf63 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/crypto v0.9.0 // indirect
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
|
||||
golang.org/x/crypto v0.14.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 // indirect
|
||||
golang.org/x/mod v0.10.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/tools v0.9.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
45
go.sum
45
go.sum
@@ -55,8 +55,8 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf h1:40DHYsri+d1bnroFDU2FQAeq68f3kAlOzlQ93kCf26Q=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -73,6 +73,7 @@ github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbS
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
|
||||
@@ -81,8 +82,8 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j
|
||||
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
|
||||
github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ=
|
||||
github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
@@ -162,6 +163,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16 h1:+aAGyK41KRn8jbF2Q7PLL0Sxwg6dShGcQSeCC7nZQ8E=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20230407062729-974c6f05fe16/go.mod h1:IKrnDWs3/Mqq5n0lI+RxA2sB7MvN/vbMBP3ehXg65UI=
|
||||
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c h1:kbTQ8oGf+BVFvt/fM+ECI+NbZDCqoi0vtZTfB2p2hrI=
|
||||
github.com/jaytaylor/go-hostsfile v0.0.0-20220426042432-61485ac1fa6c/go.mod h1:k6+89xKz7BSMJ+DzIerBdtpEUeTlBMugO/hcVSzahog=
|
||||
github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk=
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8=
|
||||
@@ -211,9 +214,9 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/onsi/ginkgo/v2 v2.2.0 h1:3ZNA3L1c5FYDFTTxbFeVGGD8jYvjYauHD30YgLxVsNI=
|
||||
github.com/onsi/ginkgo/v2 v2.2.0/go.mod h1:MEH45j8TBi6u9BMogfbp0stKC5cdGjumZj5Y7AG4VIk=
|
||||
github.com/onsi/gomega v1.20.1 h1:PA/3qinGoukvymdIDV8pii6tiZgC8kbmJO6Z5+b002Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
@@ -227,14 +230,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/qtls-go1-18 v0.2.0 h1:5ViXqBZ90wpUcZS0ge79rf029yx0dYB0McyPJwqqj7U=
|
||||
github.com/quic-go/qtls-go1-18 v0.2.0/go.mod h1:moGulGHK7o6O8lSPSZNoOwcLvJKJ85vVNc7oJFD65bc=
|
||||
github.com/quic-go/qtls-go1-19 v0.2.0 h1:Cvn2WdhyViFUHoOqK52i51k4nDX8EwIh5VJiVM4nttk=
|
||||
github.com/quic-go/qtls-go1-19 v0.2.0/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
||||
github.com/quic-go/qtls-go1-20 v0.1.0 h1:d1PK3ErFy9t7zxKsG3NXBJXZjp/kMLoIb3y/kV54oAI=
|
||||
github.com/quic-go/qtls-go1-20 v0.1.0/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM=
|
||||
github.com/quic-go/quic-go v0.32.0 h1:lY02md31s1JgPiiyfqJijpu/UX/Iun304FI3yUqX7tA=
|
||||
github.com/quic-go/quic-go v0.32.0/go.mod h1:/fCsKANhQIeD5l76c2JFU+07gVE3KaA0FP+0zMWwfwo=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.2 h1:rRgN3WfnKbyik4dBV8A6girlJVxGand/d+jVKbQq5GI=
|
||||
github.com/quic-go/qtls-go1-20 v0.3.2/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||
github.com/quic-go/quic-go v0.38.0 h1:T45lASr5q/TrVwt+jrVccmqHhPL2XuSyoCLVCpfOSLc=
|
||||
github.com/quic-go/quic-go v0.38.0/go.mod h1:MPCuRq7KBK2hNcfKj/1iD1BGuN3eAYMeNxp3T42LRUg=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
|
||||
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
@@ -301,8 +300,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -375,8 +374,8 @@ golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -449,8 +448,8 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a h1:qMsju+PNttu/NMbq8bQ9waDdxgJMu9QNoUDuhnBaYt0=
|
||||
golang.org/x/sys v0.8.1-0.20230609144347-5059a07aa46a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -460,8 +459,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
||||
@@ -33,6 +33,9 @@ func (a *arpDiscover) String() string {
|
||||
}
|
||||
|
||||
func (a *arpDiscover) List() []string {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
a.ip.Range(func(key, value any) bool {
|
||||
ips = append(ips, value.(string))
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -68,23 +71,27 @@ type Table struct {
|
||||
refreshers []refresher
|
||||
initOnce sync.Once
|
||||
|
||||
dhcp *dhcp
|
||||
merlin *merlinDiscover
|
||||
arp *arpDiscover
|
||||
ptr *ptrDiscover
|
||||
mdns *mdns
|
||||
cfg *ctrld.Config
|
||||
quitCh chan struct{}
|
||||
selfIP string
|
||||
cdUID string
|
||||
dhcp *dhcp
|
||||
merlin *merlinDiscover
|
||||
arp *arpDiscover
|
||||
ptr *ptrDiscover
|
||||
mdns *mdns
|
||||
hf *hostsFile
|
||||
vni *virtualNetworkIface
|
||||
svcCfg ctrld.ServiceConfig
|
||||
quitCh chan struct{}
|
||||
selfIP string
|
||||
cdUID string
|
||||
ptrNameservers []string
|
||||
}
|
||||
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string) *Table {
|
||||
func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
|
||||
return &Table{
|
||||
cfg: cfg,
|
||||
quitCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
svcCfg: cfg.Service,
|
||||
quitCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
ptrNameservers: ns,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +102,7 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) {
|
||||
clientInfoFiles[name] = format
|
||||
}
|
||||
|
||||
func (t *Table) RefreshLoop(stopCh chan struct{}) {
|
||||
func (t *Table) RefreshLoop(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute * 5)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
@@ -104,7 +111,7 @@ func (t *Table) RefreshLoop(stopCh chan struct{}) {
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
case <-stopCh:
|
||||
case <-ctx.Done():
|
||||
close(t.quitCh)
|
||||
return
|
||||
}
|
||||
@@ -116,6 +123,7 @@ func (t *Table) Init() {
|
||||
}
|
||||
|
||||
func (t *Table) init() {
|
||||
// Custom client ID presents, use it as the only source.
|
||||
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery")
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP}
|
||||
@@ -125,6 +133,11 @@ func (t *Table) init() {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.dhcp)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, process all possible sources in order, that means
|
||||
// the first result of IP/MAC/Hostname lookup will be used.
|
||||
//
|
||||
// Merlin custom clients.
|
||||
if t.discoverDHCP() || t.discoverARP() {
|
||||
t.merlin = &merlinDiscover{}
|
||||
if err := t.merlin.refresh(); err != nil {
|
||||
@@ -134,6 +147,19 @@ func (t *Table) init() {
|
||||
t.refreshers = append(t.refreshers, t.merlin)
|
||||
}
|
||||
}
|
||||
// Hosts file mapping.
|
||||
if t.discoverHosts() {
|
||||
t.hf = &hostsFile{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start hosts file discovery")
|
||||
if err := t.hf.init(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init hosts file discover")
|
||||
} else {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.hf)
|
||||
t.refreshers = append(t.refreshers, t.hf)
|
||||
}
|
||||
go t.hf.watchChanges()
|
||||
}
|
||||
// DHCP lease files.
|
||||
if t.discoverDHCP() {
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start dhcp discovery")
|
||||
@@ -146,6 +172,7 @@ func (t *Table) init() {
|
||||
}
|
||||
go t.dhcp.watchChanges()
|
||||
}
|
||||
// ARP table.
|
||||
if t.discoverARP() {
|
||||
t.arp = &arpDiscover{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start arp discovery")
|
||||
@@ -157,8 +184,29 @@ func (t *Table) init() {
|
||||
t.refreshers = append(t.refreshers, t.arp)
|
||||
}
|
||||
}
|
||||
// PTR lookup.
|
||||
if t.discoverPTR() {
|
||||
t.ptr = &ptrDiscover{resolver: ctrld.NewPrivateResolver()}
|
||||
if len(t.ptrNameservers) > 0 {
|
||||
nss := make([]string, 0, len(t.ptrNameservers))
|
||||
for _, ns := range t.ptrNameservers {
|
||||
host, port := ns, "53"
|
||||
if h, p, err := net.SplitHostPort(ns); err == nil {
|
||||
host, port = h, p
|
||||
}
|
||||
// Only use valid ip:port pair.
|
||||
if _, portErr := strconv.Atoi(port); portErr == nil && port != "0" && net.ParseIP(host) != nil {
|
||||
nss = append(nss, net.JoinHostPort(host, port))
|
||||
} else {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("ignoring invalid nameserver for ptr discover: %q", ns)
|
||||
}
|
||||
}
|
||||
if len(nss) > 0 {
|
||||
t.ptr.resolver = ctrld.NewResolverWithNameserver(nss)
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("using nameservers %v for ptr discovery", nss)
|
||||
}
|
||||
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start ptr discovery")
|
||||
if err := t.ptr.refresh(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not init PTR discover")
|
||||
@@ -167,6 +215,7 @@ func (t *Table) init() {
|
||||
t.refreshers = append(t.refreshers, t.ptr)
|
||||
}
|
||||
}
|
||||
// mdns.
|
||||
if t.discoverMDNS() {
|
||||
t.mdns = &mdns{}
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start mdns discovery")
|
||||
@@ -176,6 +225,11 @@ func (t *Table) init() {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.mdns)
|
||||
}
|
||||
}
|
||||
// VPN clients.
|
||||
if t.discoverDHCP() || t.discoverARP() {
|
||||
t.vni = &virtualNetworkIface{}
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.vni)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Table) LookupIP(mac string) string {
|
||||
@@ -211,6 +265,21 @@ func (t *Table) LookupHostname(ip, mac string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// LookupRFC1918IPv4 returns the RFC1918 IPv4 address for the given MAC address, if any.
|
||||
func (t *Table) LookupRFC1918IPv4(mac string) string {
|
||||
t.initOnce.Do(t.init)
|
||||
for _, r := range t.ipResolvers {
|
||||
ip, err := netip.ParseAddr(r.LookupIP(mac))
|
||||
if err != nil || ip.Is6() {
|
||||
continue
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type macEntry struct {
|
||||
mac string
|
||||
src string
|
||||
@@ -259,7 +328,7 @@ func (t *Table) ListClients() []*Client {
|
||||
_ = r.refresh()
|
||||
}
|
||||
ipMap := make(map[string]*Client)
|
||||
il := []ipLister{t.dhcp, t.arp, t.ptr, t.mdns}
|
||||
il := []ipLister{t.dhcp, t.arp, t.ptr, t.mdns, t.vni}
|
||||
for _, ir := range il {
|
||||
for _, ip := range ir.List() {
|
||||
c, ok := ipMap[ip]
|
||||
@@ -300,32 +369,69 @@ func (t *Table) ListClients() []*Client {
|
||||
return clients
|
||||
}
|
||||
|
||||
// StoreVPNClient stores client info for VPN clients.
|
||||
func (t *Table) StoreVPNClient(ci *ctrld.ClientInfo) {
|
||||
if ci == nil || t.vni == nil {
|
||||
return
|
||||
}
|
||||
t.vni.mac.Store(ci.IP, ci.Mac)
|
||||
t.vni.ip2name.Store(ci.IP, ci.Hostname)
|
||||
}
|
||||
|
||||
// ipFinder is the interface for retrieving IP address from hostname.
|
||||
type ipFinder interface {
|
||||
lookupIPByHostname(name string, v6 bool) string
|
||||
}
|
||||
|
||||
// LookupIPByHostname returns the ip address of given hostname.
|
||||
// If v6 is true, return IPv6 instead of default IPv4.
|
||||
func (t *Table) LookupIPByHostname(hostname string, v6 bool) *netip.Addr {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
for _, finder := range []ipFinder{t.hf, t.ptr, t.mdns, t.dhcp} {
|
||||
if addr := finder.lookupIPByHostname(hostname, v6); addr != "" {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil {
|
||||
return &ip
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Table) discoverDHCP() bool {
|
||||
if t.cfg.Service.DiscoverDHCP == nil {
|
||||
if t.svcCfg.DiscoverDHCP == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverDHCP
|
||||
return *t.svcCfg.DiscoverDHCP
|
||||
}
|
||||
|
||||
func (t *Table) discoverARP() bool {
|
||||
if t.cfg.Service.DiscoverARP == nil {
|
||||
if t.svcCfg.DiscoverARP == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverARP
|
||||
return *t.svcCfg.DiscoverARP
|
||||
}
|
||||
|
||||
func (t *Table) discoverMDNS() bool {
|
||||
if t.cfg.Service.DiscoverMDNS == nil {
|
||||
if t.svcCfg.DiscoverMDNS == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverMDNS
|
||||
return *t.svcCfg.DiscoverMDNS
|
||||
}
|
||||
|
||||
func (t *Table) discoverPTR() bool {
|
||||
if t.cfg.Service.DiscoverPtr == nil {
|
||||
if t.svcCfg.DiscoverPtr == nil {
|
||||
return true
|
||||
}
|
||||
return *t.cfg.Service.DiscoverPtr
|
||||
return *t.svcCfg.DiscoverPtr
|
||||
}
|
||||
|
||||
func (t *Table) discoverHosts() bool {
|
||||
if t.svcCfg.DiscoverHosts == nil {
|
||||
return true
|
||||
}
|
||||
return *t.svcCfg.DiscoverHosts
|
||||
}
|
||||
|
||||
// normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file.
|
||||
|
||||
@@ -25,3 +25,22 @@ func Test_normalizeIP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_LookupRFC1918IPv4(t *testing.T) {
|
||||
table := &Table{
|
||||
dhcp: &dhcp{},
|
||||
arp: &arpDiscover{},
|
||||
}
|
||||
|
||||
table.ipResolvers = append(table.ipResolvers, table.dhcp)
|
||||
table.ipResolvers = append(table.ipResolvers, table.arp)
|
||||
|
||||
macAddress := "cc:19:f9:8a:49:e6"
|
||||
rfc1918IPv4 := "10.0.10.245"
|
||||
table.dhcp.ip.Store(macAddress, "127.0.0.1")
|
||||
table.arp.ip.Store(macAddress, rfc1918IPv4)
|
||||
|
||||
if got := table.LookupRFC1918IPv4(macAddress); got != rfc1918IPv4 {
|
||||
t.Fatalf("unexpected result, want: %s, got: %s", rfc1918IPv4, got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -47,12 +48,25 @@ func (d *dhcp) watchChanges() {
|
||||
if d.watcher == nil {
|
||||
return
|
||||
}
|
||||
if dir := router.LeaseFilesDir(); dir != "" {
|
||||
if err := d.watcher.Add(dir); err != nil {
|
||||
ctrld.ProxyLogger.Load().Err(err).Str("dir", dir).Msg("could not watch lease dir")
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-d.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Create) {
|
||||
if format, ok := clientInfoFiles[event.Name]; ok {
|
||||
if err := d.addLeaseFile(event.Name, format); err != nil {
|
||||
ctrld.ProxyLogger.Load().Err(err).Str("file", event.Name).Msg("could not add lease file")
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) {
|
||||
format := clientInfoFiles[event.Name]
|
||||
if err := d.readLeaseFile(event.Name, format); err != nil && !os.IsNotExist(err) {
|
||||
@@ -106,6 +120,9 @@ func (d *dhcp) String() string {
|
||||
}
|
||||
|
||||
func (d *dhcp) List() []string {
|
||||
if d == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
d.ip.Range(func(key, value any) bool {
|
||||
ips = append(ips, value.(string))
|
||||
@@ -118,6 +135,39 @@ func (d *dhcp) List() []string {
|
||||
return ips
|
||||
}
|
||||
|
||||
func (d *dhcp) lookupIPByHostname(name string, v6 bool) string {
|
||||
if d == nil {
|
||||
return ""
|
||||
}
|
||||
var (
|
||||
rfc1918Addrs []netip.Addr
|
||||
others []netip.Addr
|
||||
)
|
||||
d.ip2name.Range(func(key, value any) bool {
|
||||
if value != name {
|
||||
return true
|
||||
}
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
if addr.IsPrivate() {
|
||||
rfc1918Addrs = append(rfc1918Addrs, addr)
|
||||
} else {
|
||||
others = append(others, addr)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
result := [][]netip.Addr{rfc1918Addrs, others}
|
||||
for _, addrs := range result {
|
||||
if len(addrs) > 0 {
|
||||
sort.Slice(addrs, func(i, j int) bool {
|
||||
return addrs[i].Less(addrs[j])
|
||||
})
|
||||
return addrs[0].String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// AddLeaseFile adds given lease file for reading/watching clients info.
|
||||
func (d *dhcp) addLeaseFile(name string, format ctrld.LeaseFileFormat) error {
|
||||
if d.watcher == nil {
|
||||
|
||||
@@ -86,3 +86,15 @@ lease 192.168.1.2 {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_dhcp_lookupIPByHostname(t *testing.T) {
|
||||
d := &dhcp{}
|
||||
want := "192.168.1.123"
|
||||
d.ip2name.Store(want, "foo")
|
||||
d.ip2name.Store("127.0.0.1", "foo")
|
||||
d.ip2name.Store("169.254.123.123", "foo")
|
||||
|
||||
if got := d.lookupIPByHostname("foo", false); got != want {
|
||||
t.Fatalf("unexpected result, want: %s, got: %s", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
139
internal/clientinfo/hostsfile.go
Normal file
139
internal/clientinfo/hostsfile.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/jaytaylor/go-hostsfile"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4LocalhostName = "localhost"
|
||||
ipv6LocalhostName = "ip6-localhost"
|
||||
ipv6LoopbackName = "ip6-loopback"
|
||||
)
|
||||
|
||||
// hostsFile provides client discovery functionality using system hosts file.
|
||||
type hostsFile struct {
|
||||
watcher *fsnotify.Watcher
|
||||
mu sync.Mutex
|
||||
m map[string][]string
|
||||
}
|
||||
|
||||
// init performs initialization works, which is necessary before hostsFile can be fully operated.
|
||||
func (hf *hostsFile) init() error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hf.watcher = watcher
|
||||
if err := hf.watcher.Add(hostsfile.HostsPath); err != nil {
|
||||
return err
|
||||
}
|
||||
m, err := hostsfile.ParseHosts(hostsfile.ReadHostsFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hf.mu.Lock()
|
||||
hf.m = m
|
||||
hf.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// refresh reloads hosts file entries.
|
||||
func (hf *hostsFile) refresh() error {
|
||||
m, err := hostsfile.ParseHosts(hostsfile.ReadHostsFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hf.mu.Lock()
|
||||
hf.m = m
|
||||
hf.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// watchChanges watches and updates hosts file data if any changes happens.
|
||||
func (hf *hostsFile) watchChanges() {
|
||||
if hf.watcher == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-hf.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Rename) || event.Has(fsnotify.Chmod) || event.Has(fsnotify.Remove) {
|
||||
if err := hf.refresh(); err != nil && !os.IsNotExist(err) {
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("hosts file changed but failed to update client info")
|
||||
}
|
||||
}
|
||||
case err, ok := <-hf.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Err(err).Msg("could not watch client info file")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// LookupHostnameByIP returns hostname for given IP from current hosts file entries.
|
||||
func (hf *hostsFile) LookupHostnameByIP(ip string) string {
|
||||
hf.mu.Lock()
|
||||
defer hf.mu.Unlock()
|
||||
if names := hf.m[ip]; len(names) > 0 {
|
||||
isLoopback := ip == "127.0.0.1" || ip == "::1"
|
||||
for _, hostname := range names {
|
||||
name := normalizeHostname(hostname)
|
||||
// Ignoring ipv4/ipv6 loopback entry.
|
||||
if isLoopback && isLocalhostName(name) {
|
||||
continue
|
||||
}
|
||||
return name
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// LookupHostnameByMac returns hostname for given Mac from current hosts file entries.
|
||||
func (hf *hostsFile) LookupHostnameByMac(mac string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// String returns human-readable format of hostsFile.
|
||||
func (hf *hostsFile) String() string {
|
||||
return "hosts"
|
||||
}
|
||||
|
||||
func (hf *hostsFile) lookupIPByHostname(name string, v6 bool) string {
|
||||
if hf == nil {
|
||||
return ""
|
||||
}
|
||||
hf.mu.Lock()
|
||||
defer hf.mu.Unlock()
|
||||
for addr, names := range hf.m {
|
||||
if ip, err := netip.ParseAddr(addr); err == nil && !ip.IsLoopback() {
|
||||
for _, n := range names {
|
||||
if n == name && ip.Is6() == v6 {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isLocalhostName reports whether the given hostname represents localhost.
|
||||
func isLocalhostName(hostname string) bool {
|
||||
switch hostname {
|
||||
case ipv4LocalhostName, ipv6LocalhostName, ipv6LoopbackName:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
33
internal/clientinfo/hostsfile_test.go
Normal file
33
internal/clientinfo/hostsfile_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_hostsFile_LookupHostnameByIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
hostnames []string
|
||||
expectedHostname string
|
||||
}{
|
||||
{"ipv4 loopback", "127.0.0.1", []string{ipv4LocalhostName}, ""},
|
||||
{"ipv6 loopback", "::1", []string{ipv6LocalhostName, ipv6LoopbackName}, ""},
|
||||
{"non-localhost", "::1", []string{"foo"}, "foo"},
|
||||
{"multiple hostnames", "::1", []string{ipv4LocalhostName, "foo"}, "foo"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hf := &hostsFile{m: make(map[string][]string)}
|
||||
hf.mu.Lock()
|
||||
hf.m[tc.ip] = tc.hostnames
|
||||
hf.mu.Unlock()
|
||||
if got := hf.LookupHostnameByIP(tc.ip); got != tc.expectedHostname {
|
||||
t.Errorf("unpexpected result, want: %q, got: %q", tc.expectedHostname, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -48,6 +49,9 @@ func (m *mdns) String() string {
|
||||
}
|
||||
|
||||
func (m *mdns) List() []string {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
m.name.Range(func(key, value any) bool {
|
||||
ips = append(ips, key.(string))
|
||||
@@ -56,6 +60,27 @@ func (m *mdns) List() []string {
|
||||
return ips
|
||||
}
|
||||
|
||||
func (m *mdns) lookupIPByHostname(name string, v6 bool) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
var ip string
|
||||
m.name.Range(func(key, value any) bool {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
//lint:ignore S1008 This is used for readable.
|
||||
if addr.IsLoopback() { // Continue searching if this is loopback address.
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return ip
|
||||
}
|
||||
|
||||
func (m *mdns) init(quitCh chan struct{}) error {
|
||||
ifaces, err := multicastInterfaces()
|
||||
if err != nil {
|
||||
@@ -120,6 +145,10 @@ func (m *mdns) readLoop(conn *net.UDPConn) {
|
||||
if err, ok := err.(*net.OpError); ok && (err.Timeout() || err.Temporary()) {
|
||||
continue
|
||||
}
|
||||
// Do not complain about use of closed network connection.
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Err(err).Msg("mdns readLoop error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,17 +2,21 @@ package clientinfo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
type ptrDiscover struct {
|
||||
hostname sync.Map // ip => hostname
|
||||
resolver ctrld.Resolver
|
||||
hostname sync.Map // ip => hostname
|
||||
resolver ctrld.Resolver
|
||||
serverDown atomic.Bool
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) refresh() error {
|
||||
@@ -41,6 +45,9 @@ func (p *ptrDiscover) String() string {
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) List() []string {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
p.hostname.Range(func(key, value any) bool {
|
||||
ips = append(ips, key.(string))
|
||||
@@ -57,18 +64,25 @@ func (p *ptrDiscover) lookupHostnameFromCache(ip string) string {
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) lookupHostname(ip string) string {
|
||||
// If nameserver is down, do nothing.
|
||||
if p.serverDown.Load() {
|
||||
return ""
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
msg := new(dns.Msg)
|
||||
addr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("invalid ip address")
|
||||
ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
|
||||
return ""
|
||||
}
|
||||
msg.SetQuestion(addr, dns.TypePTR)
|
||||
ans, err := p.resolver.Resolve(ctx, msg)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msg("could not lookup IP")
|
||||
if p.serverDown.CompareAndSwap(false, true) {
|
||||
ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
|
||||
go p.checkServer()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
for _, rr := range ans.Answer {
|
||||
@@ -80,3 +94,46 @@ func (p *ptrDiscover) lookupHostname(ip string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
var ip string
|
||||
p.hostname.Range(func(key, value any) bool {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
//lint:ignore S1008 This is used for readable.
|
||||
if addr.IsLoopback() { // Continue searching if this is loopback address.
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return ip
|
||||
}
|
||||
|
||||
// checkServer monitors if the resolver can reach its nameserver. When the nameserver
|
||||
// is reachable, set p.serverDown to false, so p.lookupHostname can continue working.
|
||||
func (p *ptrDiscover) checkServer() {
|
||||
bo := backoff.NewBackoff("ptrDiscover", func(format string, args ...any) {}, time.Minute*5)
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(".", dns.TypeNS)
|
||||
ping := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, err := p.resolver.Resolve(ctx, m)
|
||||
return err
|
||||
}
|
||||
for {
|
||||
if err := ping(); err != nil {
|
||||
bo.BackOff(context.Background(), err)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
p.serverDown.Store(false)
|
||||
}
|
||||
|
||||
43
internal/clientinfo/virtual_iface.go
Normal file
43
internal/clientinfo/virtual_iface.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package clientinfo
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// virtualNetworkIface is the manager for clients from virtual network interface.
|
||||
type virtualNetworkIface struct {
|
||||
ip2name sync.Map // ip => name
|
||||
mac sync.Map // ip => mac
|
||||
}
|
||||
|
||||
// LookupHostnameByIP returns hostname of the given VPN client ip.
|
||||
func (v *virtualNetworkIface) LookupHostnameByIP(ip string) string {
|
||||
val, ok := v.ip2name.Load(ip)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return val.(string)
|
||||
}
|
||||
|
||||
// LookupHostnameByMac always returns empty string.
|
||||
func (v *virtualNetworkIface) LookupHostnameByMac(mac string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// String returns the string representation of virtualNetworkIface struct.
|
||||
func (v *virtualNetworkIface) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// List lists all known VPN clients IP.
|
||||
func (v *virtualNetworkIface) List() []string {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
var ips []string
|
||||
v.mac.Range(func(key, value any) bool {
|
||||
ips = append(ips, key.(string))
|
||||
return true
|
||||
})
|
||||
return ips
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -119,7 +120,7 @@ func postUtilityAPI(version string, cdDev bool, body io.Reader) (*ResolverConfig
|
||||
return d.DialContext(ctx, network, addrs)
|
||||
}
|
||||
|
||||
if router.Name() == ddwrt.Name {
|
||||
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
|
||||
}
|
||||
client := http.Client{
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package dnsmasq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -17,6 +20,10 @@ server={{ .IP }}#{{ .Port }}
|
||||
{{- end}}
|
||||
{{- if .SendClientInfo}}
|
||||
add-mac
|
||||
add-subnet=32,128
|
||||
{{- end}}
|
||||
{{- if .CacheDisabled}}
|
||||
cache-size=0
|
||||
{{- end}}
|
||||
`
|
||||
|
||||
@@ -39,10 +46,15 @@ if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then
|
||||
pc_append "server={{ .IP }}#{{ .Port }}" "$config_file"
|
||||
{{- end}}
|
||||
{{- if .SendClientInfo}}
|
||||
pc_delete "add-mac" "$config_file"
|
||||
pc_delete "add-subnet" "$config_file"
|
||||
pc_append "add-mac" "$config_file" # add client mac
|
||||
pc_append "add-subnet=32,128" "$config_file" # add client ip
|
||||
{{- end}}
|
||||
pc_delete "dnssec" "$config_file" # disable DNSSEC
|
||||
pc_delete "trust-anchor=" "$config_file" # disable DNSSEC
|
||||
pc_delete "cache-size=" "$config_file"
|
||||
pc_append "cache-size=0" "$config_file" # disable cache
|
||||
|
||||
# For John fork
|
||||
pc_delete "resolv-file" "$config_file" # no WAN DNS settings
|
||||
@@ -61,6 +73,10 @@ type Upstream struct {
|
||||
}
|
||||
|
||||
func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
|
||||
return ConfTmplWitchCacheDisabled(tmplText, cfg, true)
|
||||
}
|
||||
|
||||
func ConfTmplWitchCacheDisabled(tmplText string, cfg *ctrld.Config, cacheDisabled bool) (string, error) {
|
||||
listener := cfg.FirstListener()
|
||||
if listener == nil {
|
||||
return "", errors.New("missing listener")
|
||||
@@ -70,24 +86,26 @@ func ConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
|
||||
ip = "127.0.0.1"
|
||||
}
|
||||
upstreams := []Upstream{{IP: ip, Port: listener.Port}}
|
||||
return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo())
|
||||
return confTmpl(tmplText, upstreams, cfg.HasUpstreamSendClientInfo(), cacheDisabled)
|
||||
}
|
||||
|
||||
func FirewallaConfTmpl(tmplText string, cfg *ctrld.Config) (string, error) {
|
||||
if lc := cfg.FirstListener(); lc != nil && (lc.IP == "0.0.0.0" || lc.IP == "") {
|
||||
return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo())
|
||||
return confTmpl(tmplText, firewallaUpstreams(lc.Port), cfg.HasUpstreamSendClientInfo(), true)
|
||||
}
|
||||
return ConfTmpl(tmplText, cfg)
|
||||
}
|
||||
|
||||
func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo bool) (string, error) {
|
||||
func confTmpl(tmplText string, upstreams []Upstream, sendClientInfo, cacheDisabled bool) (string, error) {
|
||||
tmpl := template.Must(template.New("").Parse(tmplText))
|
||||
var to = &struct {
|
||||
SendClientInfo bool
|
||||
Upstreams []Upstream
|
||||
CacheDisabled bool
|
||||
}{
|
||||
SendClientInfo: sendClientInfo,
|
||||
Upstreams: upstreams,
|
||||
CacheDisabled: cacheDisabled,
|
||||
}
|
||||
var sb strings.Builder
|
||||
if err := tmpl.Execute(&sb, to); err != nil {
|
||||
@@ -113,9 +131,28 @@ func firewallaUpstreams(port int) []Upstream {
|
||||
return upstreams
|
||||
}
|
||||
|
||||
// firewallaDnsmasqConfFiles returns dnsmasq config files of all firewalla interfaces.
|
||||
func firewallaDnsmasqConfFiles() ([]string, error) {
|
||||
return filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf")
|
||||
}
|
||||
|
||||
// firewallUpdateConf updates all firewall config files using given function.
|
||||
func firewallUpdateConf(update func(conf string) error) error {
|
||||
confFiles, err := firewallaDnsmasqConfFiles()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, conf := range confFiles {
|
||||
if err := update(conf); err != nil {
|
||||
return fmt.Errorf("%s: %w", conf, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FirewallaSelfInterfaces returns list of interfaces that will be configured with default dnsmasq setup on Firewalla.
|
||||
func FirewallaSelfInterfaces() []*net.Interface {
|
||||
matches, err := filepath.Glob("/home/pi/firerouter/etc/dnsmasq.dns.*.conf")
|
||||
matches, err := firewallaDnsmasqConfFiles()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -129,3 +166,32 @@ func FirewallaSelfInterfaces() []*net.Interface {
|
||||
}
|
||||
return ifaces
|
||||
}
|
||||
|
||||
// FirewallaDisableCache comments out "cache-size" line in all firewalla dnsmasq config files.
|
||||
func FirewallaDisableCache() error {
|
||||
return firewallUpdateConf(DisableCache)
|
||||
}
|
||||
|
||||
// FirewallaEnableCache un-comments out "cache-size" line in all firewalla dnsmasq config files.
|
||||
func FirewallaEnableCache() error {
|
||||
return firewallUpdateConf(EnableCache)
|
||||
}
|
||||
|
||||
// DisableCache comments out "cache-size" line in dnsmasq config file.
|
||||
func DisableCache(conf string) error {
|
||||
return replaceFileContent(conf, "\ncache-size=", "\n#cache-size=")
|
||||
}
|
||||
|
||||
// EnableCache un-comments "cache-size" line in dnsmasq config file.
|
||||
func EnableCache(conf string) error {
|
||||
return replaceFileContent(conf, "\n#cache-size=", "\ncache-size=")
|
||||
}
|
||||
|
||||
func replaceFileContent(filename, old, new string) error {
|
||||
content, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
content = bytes.ReplaceAll(content, []byte(old), []byte(new))
|
||||
return os.WriteFile(filename, content, 0644)
|
||||
}
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -95,7 +95,7 @@ func (e *EdgeOS) setupUSG() error {
|
||||
return fmt.Errorf("setupUSG: backup current config: %w", err)
|
||||
}
|
||||
|
||||
// Removing all configured upstreams.
|
||||
// Removing all configured upstreams and cache config.
|
||||
var sb strings.Builder
|
||||
scanner := bufio.NewScanner(bytes.NewReader(buf))
|
||||
for scanner.Scan() {
|
||||
@@ -109,7 +109,7 @@ func (e *EdgeOS) setupUSG() error {
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg)
|
||||
data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func (e *EdgeOS) setupUSG() error {
|
||||
}
|
||||
|
||||
func (e *EdgeOS) setupUDM() error {
|
||||
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, e.cfg)
|
||||
data, err := dnsmasq.ConfTmplWitchCacheDisabled(dnsmasq.ConfigContentTmpl, e.cfg, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -169,9 +169,16 @@ func ContentFilteringEnabled() bool {
|
||||
return err == nil && !st.IsDir()
|
||||
}
|
||||
|
||||
func LeaseFileDir() string {
|
||||
if checkUSG() {
|
||||
return ""
|
||||
}
|
||||
return "/run"
|
||||
}
|
||||
|
||||
func checkUSG() bool {
|
||||
out, _ := exec.Command("mca-cli-op", "info").Output()
|
||||
return bytes.Contains(out, []byte("UniFi-Gateway-"))
|
||||
out, _ := os.ReadFile("/etc/version")
|
||||
return bytes.HasPrefix(out, []byte("UniFiSecurityGateway."))
|
||||
}
|
||||
|
||||
func restartDNSMasq() error {
|
||||
|
||||
@@ -65,6 +65,11 @@ func (f *Firewalla) Setup() error {
|
||||
return fmt.Errorf("writing ctrld config: %w", err)
|
||||
}
|
||||
|
||||
// Disable dnsmasq cache.
|
||||
if err := dnsmasq.FirewallaDisableCache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return fmt.Errorf("restartDNSMasq: %w", err)
|
||||
@@ -82,6 +87,11 @@ func (f *Firewalla) Cleanup() error {
|
||||
return fmt.Errorf("removing ctrld config: %w", err)
|
||||
}
|
||||
|
||||
// Enable dnsmasq cache.
|
||||
if err := dnsmasq.FirewallaEnableCache(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return fmt.Errorf("restartDNSMasq: %w", err)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -44,8 +45,24 @@ func (m *Merlin) Uninstall(_ *service.Config) error {
|
||||
}
|
||||
|
||||
func (m *Merlin) PreRun() error {
|
||||
// Wait NTP ready.
|
||||
_ = m.Cleanup()
|
||||
return ntp.WaitNvram()
|
||||
if err := ntp.WaitNvram(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Wait until directories mounted.
|
||||
for _, dir := range []string{"/tmp", "/proc"} {
|
||||
waitDirExists(dir)
|
||||
}
|
||||
// Wait dnsmasq started.
|
||||
for {
|
||||
out, _ := exec.Command("pidof", "dnsmasq").CombinedOutput()
|
||||
if len(bytes.TrimSpace(out)) > 0 {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Merlin) Setup() error {
|
||||
@@ -56,9 +73,6 @@ func (m *Merlin) Setup() error {
|
||||
if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" {
|
||||
return nil
|
||||
}
|
||||
if _, err := nvram.Run("set", nvram.CtrldSetupKey+"=1"); err != nil {
|
||||
return err
|
||||
}
|
||||
buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath)
|
||||
// Already setup.
|
||||
if bytes.Contains(buf, []byte(dnsmasq.MerlinPostConfMarker)) {
|
||||
@@ -140,3 +154,12 @@ func merlinParsePostConf(buf []byte) []byte {
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func waitDirExists(dir string) {
|
||||
for {
|
||||
if _, err := os.Stat(dir); !os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,11 +8,10 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,10 +19,9 @@ const (
|
||||
openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
|
||||
)
|
||||
|
||||
var errUCIEntryNotFound = errors.New("uci: Entry not found")
|
||||
|
||||
type Openwrt struct {
|
||||
cfg *ctrld.Config
|
||||
cfg *ctrld.Config
|
||||
dnsmasqCacheSize string
|
||||
}
|
||||
|
||||
// New returns a router.Router for configuring/setup/run ctrld on Openwrt routers.
|
||||
@@ -52,6 +50,19 @@ func (o *Openwrt) Setup() error {
|
||||
if o.cfg.FirstListener().IsDirectDnsListener() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save current dnsmasq config cache size if present.
|
||||
if cs, err := uci("get", "dhcp.@dnsmasq[0].cachesize"); err == nil {
|
||||
o.dnsmasqCacheSize = cs
|
||||
if _, err := uci("delete", "dhcp.@dnsmasq[0].cachesize"); err != nil {
|
||||
return err
|
||||
}
|
||||
// Commit.
|
||||
if _, err := uci("commit", "dhcp"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
data, err := dnsmasq.ConfTmpl(dnsmasq.ConfigContentTmpl, o.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -59,10 +70,6 @@ func (o *Openwrt) Setup() error {
|
||||
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Commit.
|
||||
if _, err := uci("commit"); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
@@ -78,6 +85,18 @@ func (o *Openwrt) Cleanup() error {
|
||||
if err := os.Remove(openwrtDNSMasqConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore original value if present.
|
||||
if o.dnsmasqCacheSize != "" {
|
||||
if _, err := uci("set", fmt.Sprintf("dhcp.@dnsmasq[0].cachesize=%s", o.dnsmasqCacheSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Commit.
|
||||
if _, err := uci("commit", "dhcp"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
@@ -92,6 +111,8 @@ func restartDNSMasq() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errUCIEntryNotFound = errors.New("uci: Entry not found")
|
||||
|
||||
func uci(args ...string) (string, error) {
|
||||
cmd := exec.Command("uci", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
@@ -173,20 +173,6 @@ func CanListenLocalhost() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// ServiceDependencies returns list of dependencies that ctrld services needs on this router.
|
||||
// See https://pkg.go.dev/github.com/kardianos/service#Config for list format.
|
||||
func ServiceDependencies() []string {
|
||||
if Name() == edgeos.Name {
|
||||
// On EdeOS, ctrld needs to start after vyatta-dhcpd, so it can read leases file.
|
||||
return []string{
|
||||
"Wants=vyatta-dhcpd.service",
|
||||
"After=vyatta-dhcpd.service",
|
||||
"Wants=dnsmasq.service",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SelfInterfaces return list of *net.Interface that will be source of requests from router itself.
|
||||
func SelfInterfaces() []*net.Interface {
|
||||
switch Name() {
|
||||
@@ -197,6 +183,14 @@ func SelfInterfaces() []*net.Interface {
|
||||
}
|
||||
}
|
||||
|
||||
// LeaseFilesDir is the directory which contains lease files.
|
||||
func LeaseFilesDir() string {
|
||||
if Name() == edgeos.Name {
|
||||
edgeos.LeaseFileDir()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func distroName() string {
|
||||
switch {
|
||||
case bytes.HasPrefix(unameO(), []byte("DD-WRT")):
|
||||
|
||||
@@ -5,16 +5,17 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/edgeos"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "ubios"
|
||||
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
|
||||
Name = "ubios"
|
||||
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
|
||||
ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf"
|
||||
)
|
||||
|
||||
type Ubios struct {
|
||||
@@ -57,6 +58,10 @@ func (u *Ubios) Setup() error {
|
||||
if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Disable dnsmasq cache.
|
||||
if err := dnsmasq.DisableCache(ubiosDNSMasqDnsConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
@@ -72,6 +77,10 @@ func (u *Ubios) Cleanup() error {
|
||||
if err := os.Remove(ubiosDNSMasqConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Enable dnsmasq cache.
|
||||
if err := dnsmasq.EnableCache(ubiosDNSMasqDnsConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
|
||||
9
nameservers_unix.go
Normal file
9
nameservers_unix.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build unix
|
||||
|
||||
package ctrld
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
|
||||
func nameserversFromResolvconf() []string {
|
||||
return resolvconffile.NameServers("")
|
||||
}
|
||||
@@ -58,3 +58,7 @@ func dnsFromAdapter() []string {
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
func nameserversFromResolvconf() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
35
net.go
35
net.go
@@ -2,13 +2,10 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
@@ -17,30 +14,36 @@ var (
|
||||
ipv6Available atomic.Bool
|
||||
)
|
||||
|
||||
const ipv6ProbingInterval = 10 * time.Second
|
||||
|
||||
func hasIPv6() bool {
|
||||
hasIPv6Once.Do(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
val := ctrldnet.IPv6Available(ctx)
|
||||
ipv6Available.Store(val)
|
||||
go probingIPv6(val)
|
||||
go probingIPv6(context.TODO(), val)
|
||||
})
|
||||
return ipv6Available.Load()
|
||||
}
|
||||
|
||||
// TODO(cuonglm): doing poll check natively for supported platforms.
|
||||
func probingIPv6(old bool) {
|
||||
b := backoff.NewBackoff("probingIPv6", func(format string, args ...any) {}, 30*time.Second)
|
||||
bCtx := context.Background()
|
||||
func probingIPv6(ctx context.Context, old bool) {
|
||||
ticker := time.NewTicker(ipv6ProbingInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
cur := ctrldnet.IPv6Available(ctx)
|
||||
if ipv6Available.CompareAndSwap(old, cur) {
|
||||
old = cur
|
||||
}
|
||||
}()
|
||||
b.BackOff(bCtx, errors.New("no change"))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
cur := ctrldnet.IPv6Available(ctx)
|
||||
if ipv6Available.CompareAndSwap(old, cur) {
|
||||
old = cur
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
94
resolver.go
94
resolver.go
@@ -5,10 +5,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"tailscale.com/net/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,16 +26,20 @@ const (
|
||||
ResolverTypeOS = "os"
|
||||
// ResolverTypeLegacy specifies legacy resolver.
|
||||
ResolverTypeLegacy = "legacy"
|
||||
// ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only.
|
||||
ResolverTypePrivate = "private"
|
||||
)
|
||||
|
||||
var bootstrapDNS = "76.76.2.0"
|
||||
var or = &osResolver{nameservers: nameservers()}
|
||||
|
||||
func init() {
|
||||
if len(or.nameservers) == 0 {
|
||||
// Add bootstrap DNS in case we did not find any.
|
||||
or.nameservers = []string{net.JoinHostPort(bootstrapDNS, "53")}
|
||||
}
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = &osResolver{nameservers: defaultNameservers()}
|
||||
|
||||
// defaultNameservers returns OS nameservers plus ctrld bootstrap nameserver.
|
||||
func defaultNameservers() []string {
|
||||
ns := nameservers()
|
||||
ns = append(ns, net.JoinHostPort(bootstrapDNS, "53"))
|
||||
return ns
|
||||
}
|
||||
|
||||
// Resolver is the interface that wraps the basic DNS operations.
|
||||
@@ -59,6 +65,8 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
return or, nil
|
||||
case ResolverTypeLegacy:
|
||||
return &legacyResolver{uc: uc}, nil
|
||||
case ResolverTypePrivate:
|
||||
return NewPrivateResolver(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
|
||||
}
|
||||
@@ -72,8 +80,9 @@ type osResolverResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// Resolve performs DNS resolvers using OS default nameservers. Nameserver is chosen from
|
||||
// available nameservers with a roundrobin algorithm.
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
numServers := len(o.nameservers)
|
||||
if numServers == 0 {
|
||||
@@ -237,13 +246,33 @@ func NewBootstrapResolver(servers ...string) Resolver {
|
||||
return resolver
|
||||
}
|
||||
|
||||
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers.
|
||||
// NewPrivateResolver returns an OS resolver, which includes only private DNS servers,
|
||||
// excluding:
|
||||
//
|
||||
// - Nameservers from /etc/resolv.conf file.
|
||||
// - Nameservers which is local RFC1918 addresses.
|
||||
//
|
||||
// This is useful for doing PTR lookup in LAN network.
|
||||
func NewPrivateResolver() Resolver {
|
||||
nss := nameservers()
|
||||
resolveConfNss := nameserversFromResolvconf()
|
||||
localRfc1918Addrs := Rfc1918Addresses()
|
||||
n := 0
|
||||
for _, ns := range nss {
|
||||
host, _, _ := net.SplitHostPort(ns)
|
||||
// Ignore nameserver from resolve.conf file, because the nameserver can be either:
|
||||
//
|
||||
// - ctrld itself.
|
||||
// - Direct listener that has ctrld as an upstream (e.g: dnsmasq).
|
||||
//
|
||||
// causing the query always succeed.
|
||||
if sliceContains(resolveConfNss, host) {
|
||||
continue
|
||||
}
|
||||
// Ignoring local RFC 1918 addresses.
|
||||
if sliceContains(localRfc1918Addrs, host) {
|
||||
continue
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.IsPrivate() && !ip.IsLoopback() {
|
||||
nss[n] = ns
|
||||
@@ -251,11 +280,35 @@ func NewPrivateResolver() Resolver {
|
||||
}
|
||||
}
|
||||
nss = nss[:n]
|
||||
if len(nss) == 0 {
|
||||
return NewResolverWithNameserver(nss)
|
||||
}
|
||||
|
||||
// NewResolverWithNameserver returns an OS resolver which uses the given nameservers
|
||||
// for resolving DNS queries. If nameservers is empty, a dummy resolver will be returned.
|
||||
//
|
||||
// Each nameserver must be form "host:port". It's the caller responsibility to ensure all
|
||||
// nameservers are well formatted by using net.JoinHostPort function.
|
||||
func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
if len(nameservers) == 0 {
|
||||
return &dummyResolver{}
|
||||
}
|
||||
resolver := &osResolver{nameservers: nss}
|
||||
return resolver
|
||||
return &osResolver{nameservers: nameservers}
|
||||
}
|
||||
|
||||
// Rfc1918Addresses returns the list of local interfaces private IP addresses
|
||||
func Rfc1918Addresses() []string {
|
||||
var res []string
|
||||
interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
addrs, _ := i.Addrs()
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
if !ok || !ipNet.IP.IsPrivate() {
|
||||
continue
|
||||
}
|
||||
res = append(res, ipNet.IP.String())
|
||||
}
|
||||
})
|
||||
return res
|
||||
}
|
||||
|
||||
func newDialer(dnsAddress string) *net.Dialer {
|
||||
@@ -269,3 +322,20 @@ func newDialer(dnsAddress string) *net.Dialer {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(cuonglm): use slices.Contains once upgrading to go1.21
|
||||
// sliceContains reports whether v is present in s.
|
||||
func sliceContains[S ~[]E, E comparable](s S, v E) bool {
|
||||
return sliceIndex(s, v) >= 0
|
||||
}
|
||||
|
||||
// sliceIndex returns the index of the first occurrence of v in s,
|
||||
// or -1 if not present.
|
||||
func sliceIndex[S ~[]E, E comparable](s S, v E) int {
|
||||
for i := range s {
|
||||
if v == s[i] {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
179
scripts/build.sh
Executable file
179
scripts/build.sh
Executable file
@@ -0,0 +1,179 @@
|
||||
#!/bin/bash
|
||||
|
||||
go=${GOBIN:-go}
|
||||
executable_name="ctrld"
|
||||
os_archs=(
|
||||
darwin/arm64
|
||||
darwin/amd64
|
||||
windows/386
|
||||
windows/amd64
|
||||
windows/arm64
|
||||
windows/arm
|
||||
linux/amd64
|
||||
linux/386
|
||||
linux/mips
|
||||
linux/mipsle
|
||||
linux/mips64
|
||||
linux/arm
|
||||
linux/arm64
|
||||
freebsd/amd64
|
||||
freebsd/386
|
||||
freebsd/arm
|
||||
freebsd/arm64
|
||||
)
|
||||
|
||||
compress() {
|
||||
binary=$1
|
||||
|
||||
if [ -z "$binary" ]; then
|
||||
echo >&2 "missing binary"
|
||||
return 1
|
||||
fi
|
||||
|
||||
case "$binary" in
|
||||
*-freebsd-*)
|
||||
echo >&2 "upx does not work with freebsd binary yet"
|
||||
return 0
|
||||
;;
|
||||
*-windows-arm*)
|
||||
echo >&2 "upx does not work with windows arm/arm64 binary yet"
|
||||
return 0
|
||||
;;
|
||||
*-darwin-*)
|
||||
echo >&2 "upx claims to work with darwin binary, but testing show that it is broken"
|
||||
return 0
|
||||
;;
|
||||
*-linux-armv*)
|
||||
echo >&2 "upx does not work on arm routers"
|
||||
return 0
|
||||
;;
|
||||
*-linux-mips*)
|
||||
echo >&2 "upx does not work on mips routers"
|
||||
return 0
|
||||
;;
|
||||
esac
|
||||
|
||||
upx -- "$binary"
|
||||
}
|
||||
|
||||
build() {
|
||||
goos=$1
|
||||
goarch=$2
|
||||
ldflags="-s -w -X github.com/Control-D-Inc/ctrld/cmd/cli.version="${CI_COMMIT_TAG:-dev}" -X github.com/Control-D-Inc/ctrld/cmd/cli.commit=$(git rev-parse HEAD)"
|
||||
|
||||
case $3 in
|
||||
5 | 6 | 7)
|
||||
goarm=$3
|
||||
if [ "${goos}${goarm}" = "freebsd5" ]; then
|
||||
# freebsd/arm require ARMv6K or above: https://github.com/golang/go/wiki/GoArm#supported-operating-systems
|
||||
return
|
||||
fi
|
||||
binary=${executable_name}-${goos}-${goarch}v${3}
|
||||
if [ "$CGO_ENABLED" = "0" ]; then
|
||||
binary=${binary}-nocgo
|
||||
fi
|
||||
GOOS=${goos} GOARCH=${goarch} GOARM=${3} "$go" build -ldflags="$ldflags" -o "$binary" ./cmd/ctrld
|
||||
compress "$binary"
|
||||
|
||||
if [ -z "${CTRLD_NO_QF}" ]; then
|
||||
binary_qf=${executable_name}-qf-${goos}-${goarch}v${3}
|
||||
if [ "$CGO_ENABLED" = "0" ]; then
|
||||
binary_qf=${binary_qf}-nocgo
|
||||
fi
|
||||
GOOS=${goos} GOARCH=${goarch} GOARM=${3} "$go" build -ldflags="$ldflags" -tags=qf -o "$binary_qf" ./cmd/ctrld
|
||||
compress "$binary_qf"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
# GOMIPS is required for linux/mips: https://nileshgr.com/2020/02/16/golang-on-openwrt-mips/
|
||||
binary=${executable_name}-${goos}-${goarch}
|
||||
if [ "$CGO_ENABLED" = "0" ]; then
|
||||
binary=${binary}-nocgo
|
||||
fi
|
||||
GOOS=${goos} GOARCH=${goarch} GOMIPS=softfloat "$go" build -ldflags="$ldflags" -o "$binary" ./cmd/ctrld
|
||||
compress "$binary"
|
||||
|
||||
if [ -z "${CTRLD_NO_QF}" ]; then
|
||||
binary_qf=${executable_name}-qf-${goos}-${goarch}
|
||||
if [ "$CGO_ENABLED" = "0" ]; then
|
||||
binary_qf=${binary_qf}-nocgo
|
||||
fi
|
||||
GOOS=${goos} GOARCH=${goarch} GOMIPS=softfloat "$go" build -ldflags="$ldflags" -tags=qf -o "$binary_qf" ./cmd/ctrld
|
||||
compress "$binary_qf"
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
}
|
||||
echo "Building binaries..."
|
||||
|
||||
case $1 in
|
||||
all)
|
||||
for os_arch in "${os_archs[@]}"; do
|
||||
goos=${os_arch%/*}
|
||||
goarch=${os_arch#*/}
|
||||
|
||||
case goarch in
|
||||
arm)
|
||||
|
||||
echo "Building $goos/$goarch ARM5..."
|
||||
build "$goos" "$goarch" "5"
|
||||
|
||||
echo "Building $goos/$goarch ARM6..."
|
||||
build "$goos" "$goarch" "6"
|
||||
|
||||
echo "Building $goos/$goarch ARM7..."
|
||||
build "$goos" "$goarch" "7"
|
||||
|
||||
;;
|
||||
*)
|
||||
echo "Building $goos/$goarch..."
|
||||
build "$goos" "$goarch"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
;;
|
||||
linux/armv5)
|
||||
goos=${1%/*}
|
||||
goarch=${1#*/}
|
||||
echo "Building $goos/$goarch..."
|
||||
build "$goos" arm "5"
|
||||
;;
|
||||
linux/armv6)
|
||||
goos=${1%/*}
|
||||
goarch=${1#*/}
|
||||
echo "Building $goos/$goarch..."
|
||||
build "$goos" arm "6"
|
||||
;;
|
||||
linux/armv7)
|
||||
goos=${1%/*}
|
||||
goarch=${1#*/}
|
||||
echo "Building $goos/$goarch..."
|
||||
build "$goos" arm "7"
|
||||
;;
|
||||
freebsd/armv6)
|
||||
goos=${1%/*}
|
||||
goarch=${1#*/}
|
||||
echo "Building $goos/$goarch..."
|
||||
build "$goos" arm "6"
|
||||
;;
|
||||
freebsd/armv7)
|
||||
goos=${1%/*}
|
||||
goarch=${1#*/}
|
||||
echo "Building $goos/$goarch..."
|
||||
build "$goos" arm "7"
|
||||
;;
|
||||
*)
|
||||
goos=${1%/*}
|
||||
goarch=${1#*/}
|
||||
if [ -z "$goos" ]; then
|
||||
goos=$(go env GOOS)
|
||||
fi
|
||||
if [ -z "$goarch" ]; then
|
||||
goarch=$(go env GOARCH)
|
||||
fi
|
||||
echo "Building $goos/$goarch..."
|
||||
build "$goos" "$goarch"
|
||||
;;
|
||||
esac
|
||||
|
||||
printf 'Done \360\237\221\214\n'
|
||||
@@ -82,4 +82,8 @@ rules = [
|
||||
{"*.ru" = ["upstream.1"]},
|
||||
{"*.local.host" = ["upstream.2", "upstream.0"]},
|
||||
]
|
||||
macs = [
|
||||
{"14:45:A0:67:83:0A" = ["upstream.2"]},
|
||||
{"14:54:4a:8e:08:2d" = ["upstream.2"]},
|
||||
]
|
||||
`
|
||||
|
||||
Reference in New Issue
Block a user