diff --git a/internal/clientinfo/client_info.go b/internal/clientinfo/client_info.go index a026d1d..ad17121 100644 --- a/internal/clientinfo/client_info.go +++ b/internal/clientinfo/client_info.go @@ -48,10 +48,11 @@ type Table struct { ptr *ptrDiscover mdns *mdns cfg *ctrld.Config + quitCh chan struct{} } func NewTable(cfg *ctrld.Config) *Table { - return &Table{cfg: cfg} + return &Table{cfg: cfg, quitCh: make(chan struct{})} } func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) { @@ -70,6 +71,7 @@ func (t *Table) RefreshLoop(stopCh chan struct{}) { _ = r.refresh() } case <-stopCh: + close(t.quitCh) return } } @@ -122,7 +124,7 @@ func (t *Table) Init() { if t.discoverMDNS() { t.mdns = &mdns{} ctrld.ProxyLog.Debug().Msg("start mdns discovery") - if err := t.mdns.init(); err != nil { + if err := t.mdns.init(t.quitCh); err != nil { ctrld.ProxyLog.Error().Err(err).Msg("could not init mDNS discover") } else { t.hostnameResolvers = append(t.hostnameResolvers, t.mdns) diff --git a/internal/clientinfo/mdns.go b/internal/clientinfo/mdns.go index ce92d50..ac34713 100644 --- a/internal/clientinfo/mdns.go +++ b/internal/clientinfo/mdns.go @@ -4,7 +4,9 @@ import ( "context" "errors" "net" + "os" "sync" + "syscall" "time" "github.com/miekg/dns" @@ -41,7 +43,7 @@ func (m *mdns) LookupHostnameByMac(mac string) string { return "" } -func (m *mdns) init() error { +func (m *mdns) init(quitCh chan struct{}) error { ifaces, err := multicastInterfaces() if err != nil { return err @@ -65,20 +67,35 @@ func (m *mdns) init() error { } } - go func() { - bo := backoff.NewBackoff("mdns probe", func(format string, args ...any) {}, time.Second*30) - for { - err := m.probe(v4ConnList, v6ConnList) - if err != nil { - ctrld.ProxyLog.Warn().Err(err).Msg("error while probing mdns") - } - bo.BackOff(context.Background(), errors.New("mdns probe backoff")) - } - }() + go m.probeLoop(v4ConnList, mdnsV4Addr, quitCh) + go m.probeLoop(v6ConnList, mdnsV6Addr, quitCh) return nil } +func (m *mdns) probeLoop(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan struct{}) { + bo := backoff.NewBackoff("mdns probe", func(format string, args ...any) {}, time.Second*30) + for { + err := m.probe(conns, remoteAddr, quitCh) + if isErrNetUnreachableOrInvalid(err) { + ctrld.ProxyLog.Warn().Msgf("stop probing %q: network unreachable or invalid", remoteAddr) + break + } + if err != nil { + ctrld.ProxyLog.Warn().Err(err).Msg("error while probing mdns") + bo.BackOff(context.Background(), errors.New("mdns probe backoff")) + } + select { + case <-quitCh: + break + } + } + <-quitCh + for _, conn := range conns { + _ = conn.Close() + } +} + func (m *mdns) readLoop(conn *net.UDPConn) { defer conn.Close() buf := make([]byte, dns.MaxMsgSize) @@ -87,12 +104,10 @@ func (m *mdns) readLoop(conn *net.UDPConn) { _ = conn.SetReadDeadline(time.Now().Add(time.Second * 30)) n, _, err := conn.ReadFromUDP(buf) if err != nil { - if err, ok := err.(*net.OpError); ok { - if err.Timeout() || err.Temporary() { - continue - } - ctrld.ProxyLog.Debug().Err(err).Msg("mdns readLoop error") + if err, ok := err.(*net.OpError); ok && (err.Timeout() || err.Temporary()) { + continue } + ctrld.ProxyLog.Debug().Err(err).Msg("mdns readLoop error") return } @@ -111,14 +126,22 @@ func (m *mdns) readLoop(conn *net.UDPConn) { } if ip != "" && name != "" { name = normalizeHostname(name) - ctrld.ProxyLog.Debug().Msgf("Found hostname: %q, ip: %q via mdns", name, ip) - m.name.Store(ip, name) + if val, loaded := m.name.LoadOrStore(ip, name); !loaded { + ctrld.ProxyLog.Debug().Msgf("found hostname: %q, ip: %q via mdns", name, ip) + } else { + old := val.(string) + if old != name { + ctrld.ProxyLog.Debug().Msgf("update hostname: %q, ip: %q, old: %q via mdns", name, ip, old) + m.name.Store(ip, name) + } + } + ip, name = "", "" } } } } -func (m *mdns) probe(v4connList, v6connList []*net.UDPConn) error { +func (m *mdns) probe(conns []*net.UDPConn, remoteAddr net.Addr, quitCh chan struct{}) error { msg := new(dns.Msg) msg.Question = make([]dns.Question, len(services)) for i, service := range services { @@ -133,16 +156,13 @@ func (m *mdns) probe(v4connList, v6connList []*net.UDPConn) error { if err != nil { return err } - do := func(connList []*net.UDPConn, remoteAddr net.Addr) error { - for _, conn := range connList { - _ = conn.SetWriteDeadline(time.Now().Add(time.Second * 30)) - if _, err := conn.WriteTo(buf, remoteAddr); err != nil { - return err - } + for _, conn := range conns { + _ = conn.SetWriteDeadline(time.Now().Add(time.Second * 30)) + if _, werr := conn.WriteTo(buf, remoteAddr); werr != nil { + err = werr } - return nil } - return errors.Join(do(v4connList, mdnsV4Addr), do(v6connList, mdnsV6Addr)) + return err } func multicastInterfaces() ([]net.Interface, error) { @@ -161,3 +181,11 @@ func multicastInterfaces() ([]net.Interface, error) { } return interfaces, nil } + +func isErrNetUnreachableOrInvalid(err error) bool { + var se *os.SyscallError + if errors.As(err, &se) { + return se.Err == syscall.ENETUNREACH || se.Err == syscall.EINVAL + } + return false +}