package store import ( "context" "fmt" "reflect" "sort" "sync" "sync/atomic" "testing" "time" ) func TestUpsert_CreatesHost(t *testing.T) { s := NewMemoryStore() ctx := context.Background() err := s.Upsert(ctx, "api.example.com", func(h *Host) { h.IPs = []string{"1.2.3.4"} h.StatusCode = 200 }) if err != nil { t.Fatal(err) } h, ok := s.Get(ctx, "api.example.com") if !ok { t.Fatal("Get returned !ok after Upsert") } if h.Subdomain != "api.example.com" { t.Errorf("Subdomain = %q", h.Subdomain) } if !reflect.DeepEqual(h.IPs, []string{"1.2.3.4"}) { t.Errorf("IPs = %v", h.IPs) } if h.StatusCode != 200 { t.Errorf("StatusCode = %d", h.StatusCode) } if h.FirstSeen.IsZero() { t.Error("FirstSeen not populated") } if h.LastUpdated.IsZero() { t.Error("LastUpdated not populated") } } func TestUpsert_UpdatesExistingHost(t *testing.T) { s := NewMemoryStore() ctx := context.Background() s.Upsert(ctx, "api.example.com", func(h *Host) { h.StatusCode = 200 }) firstSeen, _ := s.Get(ctx, "api.example.com") time.Sleep(5 * time.Millisecond) // ensure LastUpdated differs s.Upsert(ctx, "api.example.com", func(h *Host) { h.Title = "API" }) h, _ := s.Get(ctx, "api.example.com") if h.StatusCode != 200 { t.Errorf("StatusCode lost: %d", h.StatusCode) } if h.Title != "API" { t.Errorf("Title not set: %q", h.Title) } if !h.FirstSeen.Equal(firstSeen.FirstSeen) { t.Error("FirstSeen changed on update") } if !h.LastUpdated.After(firstSeen.LastUpdated) { t.Error("LastUpdated did not advance") } } func TestUpsert_EmptySubdomainNoop(t *testing.T) { s := NewMemoryStore() ctx := context.Background() if err := s.Upsert(ctx, "", func(h *Host) {}); err != nil { t.Errorf("unexpected error: %v", err) } if s.Count(ctx) != 0 { t.Error("empty subdomain should be a noop") } } func TestUpsert_CanceledContext(t *testing.T) { s := NewMemoryStore() ctx, cancel := context.WithCancel(context.Background()) cancel() if err := s.Upsert(ctx, "a.example.com", func(h *Host) {}); err == nil { t.Error("expected error for canceled context") } } func TestGet_Missing(t *testing.T) { s := NewMemoryStore() _, ok := s.Get(context.Background(), "none.example.com") if ok { t.Error("expected !ok for missing host") } } func TestGet_ReturnsCopy(t *testing.T) { s := NewMemoryStore() ctx := context.Background() s.Upsert(ctx, "a.example.com", func(h *Host) { h.IPs = []string{"1.2.3.4"} h.Technologies = []string{"nginx"} h.Headers = map[string]string{"X-Test": "yes"} h.TLSFingerprint = &TLSFingerprint{Vendor: "Fortinet", InternalHosts: []string{"internal.local"}} }) a, _ := s.Get(ctx, "a.example.com") // mutate returned host aggressively a.IPs[0] = "MUTATED" a.Technologies = append(a.Technologies, "INJECTED") a.Headers["X-Test"] = "MUTATED" a.TLSFingerprint.Vendor = "MUTATED" a.TLSFingerprint.InternalHosts[0] = "MUTATED" b, _ := s.Get(ctx, "a.example.com") if b.IPs[0] != "1.2.3.4" { t.Errorf("IPs corrupted: %v", b.IPs) } if len(b.Technologies) != 1 { t.Errorf("Technologies corrupted: %v", b.Technologies) } if b.Headers["X-Test"] != "yes" { t.Errorf("Headers corrupted: %v", b.Headers) } if b.TLSFingerprint.Vendor != "Fortinet" { t.Errorf("TLSFingerprint.Vendor corrupted: %q", b.TLSFingerprint.Vendor) } if b.TLSFingerprint.InternalHosts[0] != "internal.local" { t.Errorf("InternalHosts corrupted: %v", b.TLSFingerprint.InternalHosts) } } func TestAll_Sorted(t *testing.T) { s := NewMemoryStore() ctx := context.Background() for _, name := range []string{"zeta.example.com", "alpha.example.com", "mid.example.com"} { s.Upsert(ctx, name, func(h *Host) {}) } all := s.All(ctx) got := make([]string, len(all)) for i, h := range all { got[i] = h.Subdomain } want := []string{"alpha.example.com", "mid.example.com", "zeta.example.com"} if !reflect.DeepEqual(got, want) { t.Errorf("All order = %v, want %v", got, want) } } func TestCount(t *testing.T) { s := NewMemoryStore() ctx := context.Background() if s.Count(ctx) != 0 { t.Error("initial Count != 0") } s.Upsert(ctx, "a.example.com", func(h *Host) {}) s.Upsert(ctx, "b.example.com", func(h *Host) {}) s.Upsert(ctx, "a.example.com", func(h *Host) {}) // update, not new if got := s.Count(ctx); got != 2 { t.Errorf("Count = %d, want 2", got) } } func TestConcurrentUpserts_SameHost(t *testing.T) { // All writers target the same host; only one value wins per field but // no race should fire. s := NewMemoryStore() ctx := context.Background() var wg sync.WaitGroup const writers = 50 var counter atomic.Int32 for i := 0; i < writers; i++ { wg.Add(1) go func(i int) { defer wg.Done() s.Upsert(ctx, "hot.example.com", func(h *Host) { h.Technologies = append(h.Technologies, fmt.Sprintf("t%d", i)) counter.Add(1) }) }(i) } wg.Wait() if counter.Load() != writers { t.Errorf("not all mutators ran: %d/%d", counter.Load(), writers) } h, _ := s.Get(ctx, "hot.example.com") if len(h.Technologies) != writers { t.Errorf("expected %d technologies, got %d", writers, len(h.Technologies)) } } func TestConcurrentUpserts_DifferentHosts(t *testing.T) { s := NewMemoryStore() ctx := context.Background() var wg sync.WaitGroup const hosts = 200 for i := 0; i < hosts; i++ { wg.Add(1) go func(i int) { defer wg.Done() s.Upsert(ctx, fmt.Sprintf("h%d.example.com", i), func(h *Host) { h.IPs = []string{"1.2.3.4"} }) }(i) } wg.Wait() if got := s.Count(ctx); got != hosts { t.Errorf("expected %d hosts, got %d", hosts, got) } } func TestClose_Idempotent(t *testing.T) { s := NewMemoryStore() if err := s.Close(); err != nil { t.Fatal(err) } if err := s.Close(); err != nil { t.Fatal(err) } } // ---------- Helper tests ---------- func TestAddDiscoveryMethod(t *testing.T) { h := &Host{} AddDiscoveryMethod(h, "passive:crt.sh") AddDiscoveryMethod(h, "brute") AddDiscoveryMethod(h, "passive:crt.sh") // duplicate if !reflect.DeepEqual(h.DiscoveredVia, []string{"passive:crt.sh", "brute"}) { t.Errorf("DiscoveredVia = %v", h.DiscoveredVia) } } func TestAddIPs_Dedup(t *testing.T) { h := &Host{IPs: []string{"1.1.1.1"}} AddIPs(h, []string{"1.1.1.1", "2.2.2.2", "", "3.3.3.3", "2.2.2.2"}) sort.Strings(h.IPs) want := []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} if !reflect.DeepEqual(h.IPs, want) { t.Errorf("IPs = %v, want %v", h.IPs, want) } } func TestAddTechnologies_Dedup(t *testing.T) { h := &Host{Technologies: []string{"nginx"}} AddTechnologies(h, []string{"nginx", "Go", "", "React", "Go"}) sort.Strings(h.Technologies) want := []string{"Go", "React", "nginx"} if !reflect.DeepEqual(h.Technologies, want) { t.Errorf("Technologies = %v, want %v", h.Technologies, want) } } func TestCloneHost_Nil(t *testing.T) { if got := cloneHost(nil); got != nil { t.Errorf("cloneHost(nil) = %v, want nil", got) } }