Files
ctrld/internal/clientinfo/ptr_lookup.go
Cuong Manh Le fd48e6d795 fix: ensure upstream health checks can handle large DNS responses
- Add UpstreamConfig.VerifyMsg() method with proper EDNS0 support
- Replace hardcoded DNS messages in health checks with standardized verification method
- Set EDNS0 buffer size to 4096 bytes to handle large DNS responses
- Add test case for legacy resolver with extensive extra sections
2025-08-15 22:55:47 +07:00

139 lines
3.1 KiB
Go

package clientinfo
import (
"context"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"tailscale.com/logtail/backoff"
"github.com/Control-D-Inc/ctrld"
)
// ptrDiscover resolves client hostnames through reverse DNS (PTR) queries and
// remembers the results.
type ptrDiscover struct {
	// hostname caches lookup results: the key is the IP address string, the
	// value is the hostname string.
	hostname sync.Map
	// resolver is used to send the PTR queries and the reachability probes.
	resolver ctrld.Resolver
	// serverDown is set after a failed query; lookupHostname returns early
	// while it is true, and checkServer clears it.
	serverDown atomic.Bool
}
// refresh walks every IP address currently held in the hostname cache, looks
// it up again, and writes back any non-empty result. It always returns nil.
func (p *ptrDiscover) refresh() error {
	p.hostname.Range(func(k, _ any) bool {
		addr := k.(string)
		hostname := p.lookupHostname(addr)
		if hostname == "" {
			// Nothing to write back for this address; keep iterating.
			return true
		}
		// lookupHostname stores the same value on success, so this write
		// repeats it.
		p.hostname.Store(addr, hostname)
		return true
	})
	return nil
}
// LookupHostnameByIP returns the hostname for ip. A cached entry is returned
// as-is; when the cache has no entry for ip, a PTR lookup is performed.
func (p *ptrDiscover) LookupHostnameByIP(ip string) string {
	cached, found := p.hostname.Load(ip)
	if !found {
		return p.lookupHostname(ip)
	}
	return cached.(string)
}
// LookupHostnameByMac always returns an empty string; the hostname cache of
// ptrDiscover is keyed by IP address, so mac is ignored.
func (p *ptrDiscover) LookupHostnameByMac(mac string) string {
	const noHostname = ""
	return noHostname
}
// String returns the fixed name of this discoverer, "ptr".
func (p *ptrDiscover) String() string {
	const discoveryName = "ptr"
	return discoveryName
}
// List returns the IP addresses that currently have an entry in the hostname
// cache. The result is nil for a nil receiver or an empty cache.
func (p *ptrDiscover) List() []string {
	var addrs []string
	if p != nil {
		p.hostname.Range(func(k, _ any) bool {
			addrs = append(addrs, k.(string))
			return true
		})
	}
	return addrs
}
// lookupHostnameFromCache returns the cached hostname for ip, or an empty
// string when the cache has no entry. It only reads the cache.
func (p *ptrDiscover) lookupHostnameFromCache(ip string) string {
	cached, found := p.hostname.Load(ip)
	if !found {
		return ""
	}
	return cached.(string)
}
// lookupHostname sends a PTR query for ip through p.resolver, using a
// one-second timeout, and returns the normalized name from the first PTR
// record in the answer section. That name is also stored in the hostname
// cache.
//
// An empty string is returned when the server is flagged as down, when ip
// cannot be turned into a reverse-lookup name, when the query fails, or when
// the answer carries no PTR record. On a query failure the serverDown flag is
// raised; the caller that raises it logs the error and starts checkServer in
// a new goroutine.
func (p *ptrDiscover) lookupHostname(ip string) string {
	// Skip the query entirely while the nameserver is flagged as down.
	if p.serverDown.Load() {
		return ""
	}

	reverseName, err := dns.ReverseAddr(ip)
	if err != nil {
		ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("invalid ip address")
		return ""
	}

	query := new(dns.Msg)
	query.SetQuestion(reverseName, dns.TypePTR)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	reply, err := p.resolver.Resolve(ctx, query)
	if err != nil {
		// CompareAndSwap lets exactly one failing caller log the error and
		// start the checker; later failures see the flag already set.
		if p.serverDown.CompareAndSwap(false, true) {
			ctrld.ProxyLogger.Load().Info().Str("discovery", "ptr").Err(err).Msg("could not perform PTR lookup")
			go p.checkServer()
		}
		return ""
	}

	for _, record := range reply.Answer {
		ptr, isPTR := record.(*dns.PTR)
		if !isPTR {
			continue
		}
		name := normalizeHostname(ptr.Ptr)
		p.hostname.Store(ip, name)
		return name
	}
	return ""
}
// lookupIPByHostname scans the hostname cache for an entry whose value equals
// name and whose key parses as an address of the requested family (IPv6 when
// v6 is true, otherwise not IPv6). It returns that address in its canonical
// string form, or an empty string when nothing matches or p is nil.
//
// A loopback match is remembered but does not end the scan, so a later match
// can replace it; the first non-loopback match ends the scan.
func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string {
	if p == nil {
		return ""
	}
	var found string
	p.hostname.Range(func(k, v any) bool {
		if v != name {
			return true
		}
		addr, err := netip.ParseAddr(k.(string))
		if err != nil || addr.Is6() != v6 {
			return true
		}
		found = addr.String()
		// Keep searching only when the match is a loopback address.
		return addr.IsLoopback()
	})
	return found
}
// checkServer probes the resolver until a query succeeds, then clears
// p.serverDown so that p.lookupHostname resumes sending queries.
//
// Each probe sends the message built by UpstreamConfig.VerifyMsg with a
// one-second timeout. After a failed probe it waits via the backoff helper
// before trying again; the five-minute argument is presumably the maximum
// delay (confirm against the backoff package). The loop has no cancellation
// path: it returns only after a probe succeeds.
func (p *ptrDiscover) checkServer() {
	retry := backoff.NewBackoff("ptrDiscover", func(format string, args ...any) {}, time.Minute*5)
	probe := (&ctrld.UpstreamConfig{}).VerifyMsg()

	tryOnce := func() error {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		_, err := p.resolver.Resolve(ctx, probe)
		return err
	}

	for err := tryOnce(); err != nil; err = tryOnce() {
		retry.BackOff(context.Background(), err)
	}
	p.serverDown.Store(false)
}