package ctrld import ( "context" "crypto/rand" "encoding/hex" "net" "sync" "sync/atomic" "testing" "time" "github.com/miekg/dns" ) func Test_osResolver_Resolve(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { defer cancel() resolver := newResolverWithNameserver([]string{"127.0.0.127:5353"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true _, _ = resolver.Resolve(context.Background(), m) }() select { case <-time.After(10 * time.Second): t.Error("os resolver hangs") case <-ctx.Done(): } } func Test_osResolver_ResolveLanHostname(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() reqId := "req-id" ctx = context.WithValue(ctx, ReqIdCtxKey{}, reqId) ctx = LanQueryCtx(ctx) go func(ctx context.Context) { defer cancel() id, ok := ctx.Value(ReqIdCtxKey{}).(string) if !ok || id != reqId { t.Error("missing request id") return } lan, ok := ctx.Value(LanQueryCtxKey{}).(bool) if !ok || !lan { t.Error("not a LAN query") return } resolver := newResolverWithNameserver([]string{"76.76.2.0:53"}) m := new(dns.Msg) m.SetQuestion("controld.com.", dns.TypeA) m.RecursionDesired = true _, err := resolver.Resolve(ctx, m) if err == nil { t.Error("os resolver succeeded unexpectedly") return } }(ctx) select { case <-time.After(10 * time.Second): t.Error("os resolver hangs") case <-ctx.Done(): } } func Test_osResolver_ResolveWithNonSuccessAnswer(t *testing.T) { // Set up a LAN nameserver that returns a success response. lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") // 127.0.0.1 is considered LAN (loopback) if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, successHandler()) if err != nil { t.Fatalf("failed to run LAN test server: %v", err) } defer lanServer.Shutdown() // Set up two public nameservers that return non-success responses. publicHandlers := []dns.Handler{ nonSuccessHandlerWithRcode(dns.RcodeRefused), nonSuccessHandlerWithRcode(dns.RcodeNameError), } var publicNS []string var publicServers []*dns.Server for _, handler := range publicHandlers { pc, err := net.ListenPacket("udp", ":0") if err != nil { t.Fatalf("failed to listen on public address: %v", err) } s, addr, err := runLocalPacketConnTestServer(t, pc, handler) if err != nil { t.Fatalf("failed to run public test server: %v", err) } publicNS = append(publicNS, addr) publicServers = append(publicServers, s) } defer func() { for _, s := range publicServers { s.Shutdown() } }() // We now create an osResolver which has both a LAN and public nameserver. nss := []string{lanAddr} nss = append(nss, publicNS...) resolver := newResolverWithNameserver(nss) msg := new(dns.Msg) msg.SetQuestion(".", dns.TypeNS) answer, err := resolver.Resolve(context.Background(), msg) if err != nil { t.Fatal(err) } // Since a LAN nameserver is available and returns a success answer, we expect RcodeSuccess. if answer.Rcode != dns.RcodeSuccess { t.Errorf("expected a success answer from LAN nameserver (RcodeSuccess) but got: %s", dns.RcodeToString[answer.Rcode]) } } func Test_osResolver_InitializationRace(t *testing.T) { var wg sync.WaitGroup n := 10 wg.Add(n) for range n { go func() { defer wg.Done() InitializeOsResolver(false) }() } wg.Wait() } func Test_osResolver_Singleflight(t *testing.T) { lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) if err != nil { t.Fatalf("failed to run LAN test server: %v", err) } defer lanServer.Shutdown() or := newResolverWithNameserver([]string{lanAddr}) domain := "controld.com" n := 10 var wg sync.WaitGroup wg.Add(n) for i := 0; i < n; i++ { go func() { defer wg.Done() m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true _, err := or.Resolve(context.Background(), m) if err != nil { t.Error(err) } }() } wg.Wait() // All above queries should only make 1 call to server. if call.Load() != 1 { t.Fatalf("expected 1 result from singleflight lookup, got %d", call) } } func Test_osResolver_HotCache(t *testing.T) { lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) if err != nil { t.Fatalf("failed to run LAN test server: %v", err) } defer lanServer.Shutdown() or := newResolverWithNameserver([]string{lanAddr}) domain := "controld.com" m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true // Make 2 repeated queries to server, should hit hot cache. for i := 0; i < 2; i++ { if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { t.Fatal(err) } } if call.Load() != 1 { t.Fatalf("cache not hit, server was called: %d", call.Load()) } timeoutChan := make(chan struct{}) time.AfterFunc(5*time.Second, func() { close(timeoutChan) }) for { select { case <-timeoutChan: t.Fatal("timed out waiting for cache cleaned") default: count := 0 or.cache.Range(func(key, value interface{}) bool { count++ return true }) if count != 0 { t.Logf("hot cache is not empty: %d elements", count) continue } } break } if _, err := or.Resolve(context.Background(), m.Copy()); err != nil { t.Fatal(err) } if call.Load() != 2 { t.Fatal("cache hit unexpectedly") } } func Test_Edns0_CacheReply(t *testing.T) { lanPC, err := net.ListenPacket("udp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to listen on LAN address: %v", err) } call := &atomic.Int64{} lanServer, lanAddr, err := runLocalPacketConnTestServer(t, lanPC, countHandler(call)) if err != nil { t.Fatalf("failed to run LAN test server: %v", err) } defer lanServer.Shutdown() or := newResolverWithNameserver([]string{lanAddr}) domain := "controld.com" m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeA) m.RecursionDesired = true do := func() *dns.Msg { msg := m.Copy() msg.SetEdns0(4096, true) cookieOption := new(dns.EDNS0_COOKIE) cookieOption.Code = dns.EDNS0COOKIE cookieOption.Cookie = generateEdns0ClientCookie() msg.IsEdns0().Option = append(msg.IsEdns0().Option, cookieOption) answer, err := or.Resolve(context.Background(), msg) if err != nil { t.Fatal(err) } return answer } answer1 := do() answer2 := do() // Ensure the cache was hit, so we can check that edns0 cookie must be modified. if call.Load() != 1 { t.Fatalf("cache not hit, server was called: %d", call.Load()) } cookie1 := getEdns0Cookie(answer1.IsEdns0()) cookie2 := getEdns0Cookie(answer2.IsEdns0()) if cookie1 == nil || cookie2 == nil { t.Fatalf("unexpected nil cookie value (cookie1: %v, cookie2: %v)", cookie1, cookie2) } if cookie1.Cookie == cookie2.Cookie { t.Fatalf("edns0 cookie is not modified: %v", cookie1) } } func Test_upstreamTypeFromEndpoint(t *testing.T) { tests := []struct { name string endpoint string resolverType string }{ {"doh", "https://freedns.controld.com/p2", ResolverTypeDOH}, {"doq", "quic://p2.freedns.controld.com", ResolverTypeDOQ}, {"dot", "p2.freedns.controld.com", ResolverTypeDOT}, {"legacy", "8.8.8.8:53", ResolverTypeLegacy}, {"legacy ipv6", "[2404:6800:4005:809::200e]:53", ResolverTypeLegacy}, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() if rt := ResolverTypeFromEndpoint(tc.endpoint); rt != tc.resolverType { t.Errorf("mismatch, want: %s, got: %s", tc.resolverType, rt) } }) } } func runLocalPacketConnTestServer(t *testing.T, pc net.PacketConn, handler dns.Handler, opts ...func(*dns.Server)) (*dns.Server, string, error) { t.Helper() server := &dns.Server{ PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour, Handler: handler, } waitLock := sync.Mutex{} waitLock.Lock() server.NotifyStartedFunc = waitLock.Unlock for _, opt := range opts { opt(server) } addr, closer := pc.LocalAddr().String(), pc go func() { if err := server.ActivateAndServe(); err != nil { t.Error(err) } closer.Close() }() waitLock.Lock() return server, addr, nil } func successHandler() dns.HandlerFunc { return func(w dns.ResponseWriter, msg *dns.Msg) { m := new(dns.Msg) m.SetRcode(msg, dns.RcodeSuccess) w.WriteMsg(m) } } func nonSuccessHandlerWithRcode(rcode int) dns.HandlerFunc { return func(w dns.ResponseWriter, msg *dns.Msg) { m := new(dns.Msg) m.SetRcode(msg, rcode) w.WriteMsg(m) } } func countHandler(call *atomic.Int64) dns.HandlerFunc { return func(w dns.ResponseWriter, msg *dns.Msg) { m := new(dns.Msg) m.SetRcode(msg, dns.RcodeSuccess) if cookie := getEdns0Cookie(msg.IsEdns0()); cookie != nil { if m.IsEdns0() == nil { m.SetEdns0(4096, false) } cookieOption := new(dns.EDNS0_COOKIE) cookieOption.Code = dns.EDNS0COOKIE cookieOption.Cookie = generateEdns0ServerCookie(cookie.Cookie) m.IsEdns0().Option = append(m.IsEdns0().Option, cookieOption) } w.WriteMsg(m) call.Add(1) } } func generateEdns0ClientCookie() string { cookie := make([]byte, 8) if _, err := rand.Read(cookie); err != nil { panic(err) } return hex.EncodeToString(cookie) } func generateEdns0ServerCookie(clientCookie string) string { cookie := make([]byte, 32) if _, err := rand.Read(cookie); err != nil { panic(err) } return clientCookie + hex.EncodeToString(cookie) }