mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-25 07:20:08 +02:00
297 lines
8.6 KiB
Go
297 lines
8.6 KiB
Go
package project
|
|
|
|
import (
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"cyberstrike-ai/internal/database"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
func TestParseFactLinksText(t *testing.T) {
|
|
t.Parallel()
|
|
inputs, err := ParseFactLinksText("discovered_on: target/api\nleads_to: finding/swagger")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(inputs) != 2 {
|
|
t.Fatalf("want 2 links, got %d", len(inputs))
|
|
}
|
|
if inputs[0].Type != "discovered_on" || inputs[0].From != "target/api" {
|
|
t.Fatalf("unexpected first link: %+v", inputs[0])
|
|
}
|
|
}
|
|
|
|
func TestParseFactIncomingLinksText(t *testing.T) {
|
|
t.Parallel()
|
|
inputs, err := ParseFactIncomingLinksText("leads_to: finding/swagger\ndepends_on: target/api")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(inputs) != 2 {
|
|
t.Fatalf("want 2 links, got %d", len(inputs))
|
|
}
|
|
if inputs[0].Type != "leads_to" || inputs[0].From != "finding/swagger" {
|
|
t.Fatalf("unexpected first link: %+v", inputs[0])
|
|
}
|
|
}
|
|
|
|
func TestFormatFactIncomingLinksText(t *testing.T) {
|
|
t.Parallel()
|
|
text := FormatFactIncomingLinksText([]*database.ProjectFactEdge{
|
|
{EdgeType: "leads_to", SourceFactKey: "finding/a"},
|
|
{EdgeType: "depends_on", SourceFactKey: "target/b"},
|
|
})
|
|
want := "leads_to: finding/a\ndepends_on: target/b"
|
|
if text != want {
|
|
t.Fatalf("got %q want %q", text, want)
|
|
}
|
|
}
|
|
|
|
func TestParseFactLinkInputsEmptyClears(t *testing.T) {
|
|
t.Parallel()
|
|
parsed, err := ParseFactLinkInputs([]interface{}{})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if parsed == nil || parsed.Incoming == nil || len(parsed.Incoming) != 0 {
|
|
t.Fatalf("empty array should clear incoming links, got %v", parsed)
|
|
}
|
|
}
|
|
|
|
func TestParseFactLinkInputsFrom(t *testing.T) {
|
|
t.Parallel()
|
|
raw := []interface{}{
|
|
map[string]interface{}{
|
|
"from": "target/primary_domain",
|
|
"type": "discovered_on",
|
|
},
|
|
}
|
|
parsed, err := ParseFactLinkInputs(raw)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(parsed.Incoming) != 1 || parsed.Incoming[0].From != "target/primary_domain" {
|
|
t.Fatalf("unexpected incoming: %+v", parsed.Incoming)
|
|
}
|
|
}
|
|
|
|
func TestParseFactLinkInputsRequiresFrom(t *testing.T) {
|
|
t.Parallel()
|
|
raw := []interface{}{
|
|
map[string]interface{}{
|
|
"to": "target/primary_domain",
|
|
"type": "discovered_on",
|
|
},
|
|
}
|
|
_, err := ParseFactLinkInputs(raw)
|
|
if err == nil {
|
|
t.Fatal("expected error when from is missing")
|
|
}
|
|
}
|
|
|
|
func TestGraphNodeType(t *testing.T) {
|
|
t.Parallel()
|
|
if GraphNodeType("chain", "chain/x") != "chain" {
|
|
t.Fatal("chain category")
|
|
}
|
|
if GraphNodeType("finding", "finding/x") != "finding" {
|
|
t.Fatal("finding category")
|
|
}
|
|
if GraphNodeType("exploit", "exploit/x") != "exploit" {
|
|
t.Fatal("exploit category")
|
|
}
|
|
if GraphNodeType("finding", "evidence/x") != "finding" {
|
|
t.Fatal("category should override evidence key prefix")
|
|
}
|
|
if GraphNodeType("note", "target/x") != "note" {
|
|
t.Fatal("category should override target key prefix")
|
|
}
|
|
if GraphNodeType("vuln", "finding/x") != "vulnerability" {
|
|
t.Fatal("vuln category maps to vulnerability node type")
|
|
}
|
|
if GraphNodeType("", "target/x") != "target" {
|
|
t.Fatal("empty category falls back to target key prefix")
|
|
}
|
|
}
|
|
|
|
func TestBuildProjectFactGraphPreservesStoredEdgeDirection(t *testing.T) {
|
|
dir := t.TempDir()
|
|
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer db.Close()
|
|
|
|
p, err := db.CreateProject(&database.Project{Name: "path-edges"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for _, spec := range []struct{ key, cat string }{
|
|
{"target/primary_domain", "target"},
|
|
{"chain/full_attack_path", "chain"},
|
|
{"finding/mysql_public", "finding"},
|
|
{"exploit/mysql_creds_extract", "exploit"},
|
|
} {
|
|
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
|
ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.key, Confidence: "confirmed",
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/mysql_public", []database.ProjectFactEdgeFromInput{
|
|
{From: "target/primary_domain", Type: "discovered_on"},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/mysql_public", []database.ProjectFactEdgeFromInput{
|
|
{From: "target/primary_domain", Type: "discovered_on"},
|
|
{From: "exploit/mysql_creds_extract", Type: "exploits"},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "chain/full_attack_path", []database.ProjectFactEdgeFromInput{
|
|
{From: "target/primary_domain", Type: "discovered_on"},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "exploit/mysql_creds_extract", []database.ProjectFactEdgeFromInput{
|
|
{From: "chain/full_attack_path", Type: "leads_to"},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
graph, err := BuildProjectFactGraph(db, p.ID, "path", true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
want := map[string]struct{}{
|
|
"target/primary_domain|discovered_on|finding/mysql_public": {},
|
|
"exploit/mysql_creds_extract|exploits|finding/mysql_public": {},
|
|
"target/primary_domain|discovered_on|chain/full_attack_path": {},
|
|
"chain/full_attack_path|leads_to|exploit/mysql_creds_extract": {},
|
|
}
|
|
for _, e := range graph.Edges {
|
|
key := e.Source + "|" + e.Type + "|" + e.Target
|
|
delete(want, key)
|
|
}
|
|
if len(want) > 0 {
|
|
t.Fatalf("missing expected stored-direction edges: %v", want)
|
|
}
|
|
countInOut := func(factKey string) (out, in int) {
|
|
for _, e := range graph.Edges {
|
|
if e.Source == factKey {
|
|
out++
|
|
}
|
|
if e.Target == factKey {
|
|
in++
|
|
}
|
|
}
|
|
return out, in
|
|
}
|
|
if out, in := countInOut("chain/full_attack_path"); out != 1 || in != 1 {
|
|
t.Fatalf("chain/full_attack_path want out=1 in=1 got out=%d in=%d", out, in)
|
|
}
|
|
if out, in := countInOut("exploit/mysql_creds_extract"); out != 1 || in != 1 {
|
|
t.Fatalf("exploit/mysql_creds_extract want out=1 in=1 got out=%d in=%d", out, in)
|
|
}
|
|
}
|
|
|
|
func TestPersistFactLinksFromUsesFromAsIncoming(t *testing.T) {
|
|
dir := t.TempDir()
|
|
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer db.Close()
|
|
|
|
p, err := db.CreateProject(&database.Project{Name: "from-links"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for _, spec := range []struct{ key, cat string }{
|
|
{"target/primary_domain", "target"},
|
|
{"finding/sqli", "finding"},
|
|
} {
|
|
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
|
ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.key, Confidence: "confirmed",
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
parsed := &ParsedFactLinks{
|
|
Incoming: []database.ProjectFactEdgeFromInput{
|
|
{From: "target/primary_domain", Type: "discovered_on"},
|
|
},
|
|
}
|
|
if err := PersistFactLinksFromParsed(db, p.ID, "finding/sqli", "", parsed, false); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
graph, err := BuildProjectFactGraph(db, p.ID, "path", true)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
want := "target/primary_domain|discovered_on|finding/sqli"
|
|
for _, e := range graph.Edges {
|
|
key := e.Source + "|" + e.Type + "|" + e.Target
|
|
if key == want {
|
|
return
|
|
}
|
|
}
|
|
t.Fatalf("expected edge %s, got %+v", want, graph.Edges)
|
|
}
|
|
|
|
func TestFormatOutgoingLinksHint(t *testing.T) {
|
|
t.Parallel()
|
|
hint := FormatOutgoingLinksHint([]*database.ProjectFactEdge{
|
|
{EdgeType: "discovered_on", TargetFactKey: "target/a"},
|
|
})
|
|
if hint == "" || hint[0] != ' ' {
|
|
t.Fatalf("unexpected hint: %q", hint)
|
|
}
|
|
}
|
|
|
|
func TestReplaceIncomingAllowsNotYetCreatedSource(t *testing.T) {
|
|
dir := t.TempDir()
|
|
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer db.Close()
|
|
|
|
p, err := db.CreateProject(&database.Project{Name: "parallel-links"})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
|
ProjectID: p.ID, FactKey: "exploit/sqli", Category: "exploit", Summary: "exploit", Confidence: "confirmed",
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "exploit/sqli", []database.ProjectFactEdgeFromInput{
|
|
{From: "finding/sqli_endpoint", Type: "exploits"},
|
|
}); err != nil {
|
|
t.Fatalf("incoming edge should not require source fact to exist yet: %v", err)
|
|
}
|
|
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
|
ProjectID: p.ID, FactKey: "finding/sqli_endpoint", Category: "finding", Summary: "finding", Confidence: "confirmed",
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
in, err := db.ListIncomingProjectFactEdges(p.ID, "exploit/sqli")
|
|
if err != nil || len(in) != 1 || in[0].SourceFactKey != "finding/sqli_endpoint" {
|
|
t.Fatalf("expected persisted edge from finding, got %+v err=%v", in, err)
|
|
}
|
|
}
|
|
|
|
func TestValidateProjectFactEdgeType(t *testing.T) {
|
|
t.Parallel()
|
|
if err := database.ValidateProjectFactEdgeType("leads_to"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := database.ValidateProjectFactEdgeType("invalid"); err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
}
|