# CVEs-PoC/scripts/poc_pipeline.py
from __future__ import annotations

import re
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple

from github_client import GitHubClient, SearchResult, build_client
from poc_scoring import match_score, score_repo
from utils import (
    API_DIR,
    EVIDENCE_DIR,
    chunked,
    cve_year,
    ensure_dirs,
    isoformat,
    load_blacklist,
    load_json,
    save_json,
    today_str,
)

# Languages used to partition code-search queries (see build_query_pack).
LANG_PARTITIONS = ("python", "go", "c", "shell", "powershell", "java", "ruby", "js")
# Matches CVE identifiers such as CVE-2024-12345 (case-insensitive).
CVE_RE = re.compile(r"CVE-\d{4}-\d{4,}", re.IGNORECASE)


@dataclass
class MatchEvidence:
    """A single piece of evidence tying a repo to a CVE (one search hit)."""

    path: str
    match_type: str
    query: str
    score: float | None = None


@dataclass
class RepoCandidate:
    """A repository that matched at least one discovery query for a CVE."""

    cve_id: str
    repo_full_name: str
    repo_url: str
    matches: List[MatchEvidence] = field(default_factory=list)
    metadata: Dict[str, object] = field(default_factory=dict)

    def add_match(self, path: str, match_type: str, query: str) -> None:
        # Deduplicate on (path, match_type); the first query to produce a
        # given match wins.
        key = (path, match_type)
        existing = {(m.path, m.match_type) for m in self.matches}
        if key in existing:
            return
        self.matches.append(MatchEvidence(path=path, match_type=match_type, query=query))
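
# Illustrative only (hypothetical repo name): repeated hits for the same file
# collapse to one evidence entry.
#
#   c = RepoCandidate("CVE-2024-0001", "octocat/poc", "https://github.com/octocat/poc")
#   c.add_match("exploit.py", "code", "CVE-2024-0001 in:file language:python")
#   c.add_match("exploit.py", "code", "CVE-2024-0001 in:file")  # ignored: same key
#   assert len(c.matches) == 1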


def build_created_ranges(days: int, *, window: int = 7) -> List[Tuple[str, str]]:
    """Split the last `days` days into consecutive windows of `window` days."""
    end = date.today()
    start = end - timedelta(days=max(days, 1))
    ranges: List[Tuple[str, str]] = []
    cursor = start
    while cursor <= end:
        window_end = min(cursor + timedelta(days=window - 1), end)
        ranges.append((cursor.isoformat(), window_end.isoformat()))
        cursor = window_end + timedelta(days=1)
    return ranges or [(start.isoformat(), end.isoformat())]
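
# Illustrative only (actual dates depend on date.today()): with days=10 and
# the default 7-day window, the lookback splits into two ISO-date ranges, e.g.
#
#   build_created_ranges(10)
#   -> [("2024-05-01", "2024-05-07"), ("2024-05-08", "2024-05-11")]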


def build_query_pack(cve_id: str, created_range: Tuple[str, str] | None = None) -> List[Dict[str, str]]:
    """Build the repository-, topic-, and code-search queries for one CVE."""
    base_repo = f"{cve_id} in:name,description,readme fork:false"
    enriched_repo = f"{cve_id} (poc OR exploit) in:name,description,readme fork:false"
    topic_query = f"topic:{cve_id.lower()} fork:false"
    created_suffix = ""
    if created_range:
        created_suffix = f" created:{created_range[0]}..{created_range[1]}"
    queries = [
        {"kind": "repositories", "query": base_repo + created_suffix, "match_type": "name"},
        {"kind": "repositories", "query": enriched_repo + created_suffix, "match_type": "description"},
        {"kind": "repositories", "query": topic_query + created_suffix, "match_type": "topic"},
    ]
    for lang in LANG_PARTITIONS:
        base_code = f"{cve_id} in:file language:{lang}{created_suffix}"
        queries.append({"kind": "code", "query": base_code, "match_type": "code"})
    # Generic code search without a language partition, to catch files whose
    # language GitHub does not detect or that fall outside LANG_PARTITIONS.
    queries.append({"kind": "code", "query": f"{cve_id} in:file{created_suffix}", "match_type": "code"})
    return queries
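
# Illustrative only: the first entry of the pack for one CVE and one window.
#
#   build_query_pack("CVE-2024-0001", ("2024-05-01", "2024-05-07"))[0]
#   -> {"kind": "repositories",
#       "query": "CVE-2024-0001 in:name,description,readme fork:false"
#                " created:2024-05-01..2024-05-07",
#       "match_type": "name"}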


def parse_repo_from_item(item: Dict) -> Tuple[str | None, str | None]:
    """Extract (full_name, html_url) from a repo- or code-search result item."""
    # Repository-search items carry these fields at the top level; code-search
    # items nest them under "repository".
    repo_full_name = item.get("full_name") or item.get("repository", {}).get("full_name")
    # Prefer the nested repository URL: code-search items expose the *file's*
    # html_url at the top level, which is not the repo URL we want.
    repo_url = item.get("repository", {}).get("html_url") or item.get("html_url")
    if not repo_full_name and "repository" in item:
        # Last resort: reassemble "owner/name" from the nested repository object.
        repo_full_name = item["repository"].get("owner", {}).get("login", "")
        if repo_full_name:
            repo_full_name = f"{repo_full_name}/{item['repository'].get('name', '')}"
    return repo_full_name, repo_url
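
# Illustrative only (hypothetical repo): the two item shapes this normalises.
#
#   repo search: {"full_name": "octocat/poc",
#                 "html_url": "https://github.com/octocat/poc", ...}
#   code search: {"path": "exploit.py", "html_url": "<url of the file blob>",
#                 "repository": {"full_name": "octocat/poc",
#                                "html_url": "https://github.com/octocat/poc",
#                                "owner": {"login": "octocat"}, ...}}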


def extract_matches(item: Dict, default_type: str, query: str) -> List[MatchEvidence]:
    """Turn a search item's text_matches (if any) into MatchEvidence records."""
    matches: List[MatchEvidence] = []
    for text_match in item.get("text_matches", []) or []:
        prop = text_match.get("property") or text_match.get("object_type") or ""
        fragment = text_match.get("fragment") or text_match.get("path") or prop or ""
        match_type = prop if prop else default_type
        matches.append(MatchEvidence(path=str(fragment), match_type=str(match_type), query=query))
    if not matches:
        # No highlighted fragments: fall back to the file path (code search)
        # or the query's match type (repo search).
        path = item.get("path") or default_type
        matches.append(MatchEvidence(path=str(path), match_type=default_type, query=query))
    return matches
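
# text_matches only appear when the search request opts into GitHub's
# text-match media type (Accept: application/vnd.github.text-match+json);
# presumably github_client does so. Each entry looks roughly like:
#
#   {"property": "description", "fragment": "PoC for CVE-2024-0001 ...", ...}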


def normalise_metadata(meta: Dict, fallback_full_name: str, fallback_url: str) -> Dict:
    """Flatten a GraphQL repository object into the fields the scorer expects."""
    topics: List[str] = []
    if meta.get("repositoryTopics"):
        for node in meta["repositoryTopics"].get("nodes", []):
            topic = (node.get("topic") or {}).get("name")
            if topic:
                topics.append(topic)
    primary_language = None
    if meta.get("primaryLanguage"):
        primary_language = meta["primaryLanguage"].get("name")
    parent = meta.get("parent") or {}
    return {
        "repo_full_name": meta.get("nameWithOwner") or fallback_full_name,
        "repo_url": meta.get("url") or fallback_url,
        "description": meta.get("description") or "",
        "is_fork": bool(meta.get("isFork")),
        "parent_repo_url": parent.get("url"),
        "stars": meta.get("stargazerCount") or 0,
        "forks": meta.get("forkCount") or 0,
        "archived": bool(meta.get("isArchived")),
        "pushed_at": meta.get("pushedAt"),
        "updated_at": meta.get("updatedAt"),
        "topics": topics,
        "primary_language": primary_language,
    }
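
# Illustrative only: a trimmed GraphQL payload and one flattened field.
#
#   normalise_metadata(
#       {"nameWithOwner": "octocat/poc", "stargazerCount": 3, "isFork": False,
#        "primaryLanguage": {"name": "Python"},
#        "repositoryTopics": {"nodes": [{"topic": {"name": "cve-2024-0001"}}]}},
#       "octocat/poc", "https://github.com/octocat/poc",
#   )["topics"]
#   -> ["cve-2024-0001"]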


class PoCPipeline:
    def __init__(
        self,
        client: GitHubClient | None = None,
        *,
        blacklist_path: Path | None = None,
        search_ttl: int = 3 * 3600,
    ) -> None:
        self.client = client or build_client()
        self.blacklist = load_blacklist(blacklist_path)
        # Search responses are cached for this many seconds (default: 3 hours).
        self.search_ttl = search_ttl

    def _run_query(self, query: Dict, page: int) -> SearchResult:
        if query["kind"] == "repositories":
            return self.client.search_repositories(query["query"], page=page, per_page=50, ttl=self.search_ttl)
        if query["kind"] == "code":
            return self.client.search_code(query["query"], page=page, per_page=50, ttl=self.search_ttl)
        return self.client.search_topics(query["query"], page=page, per_page=50, ttl=self.search_ttl)
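
    # Note: build_query_pack currently only emits "repositories" and "code"
    # kinds, so the search_topics branch is defensive. per_page=50 here also
    # drives the short-page early exit in discover_for_cve below.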

    def discover_for_cve(self, cve_id: str, *, days: int, max_pages_repo: int = 2, max_pages_code: int = 2) -> Dict:
        """Run the full query pack for one CVE and score every candidate repo."""
        ranges = build_created_ranges(days)
        candidates: Dict[str, RepoCandidate] = {}
        query_log: List[Dict] = []
        for created_range in ranges:
            query_pack = build_query_pack(cve_id, created_range)
            for query in query_pack:
                query_log.append({"query": query["query"], "kind": query["kind"], "window": created_range})
                page_limit = max_pages_code if query["kind"] == "code" else max_pages_repo
                for page in range(1, page_limit + 1):
                    result = self._run_query(query, page)
                    items = result.payload.get("items", [])
                    for item in items:
                        repo_full_name, repo_url = parse_repo_from_item(item)
                        if not repo_full_name or not repo_url:
                            continue
                        candidate = candidates.setdefault(
                            repo_full_name,
                            RepoCandidate(cve_id=cve_id, repo_full_name=repo_full_name, repo_url=repo_url),
                        )
                        for match in extract_matches(item, query["match_type"], query["query"]):
                            candidate.add_match(match.path, match.match_type, match.query)
                    # A short page means this query's results are drained.
                    if len(items) < 50:
                        break
        metadata = self.client.fetch_repo_metadata(candidates.keys())
        for repo_full_name, candidate in candidates.items():
            meta = metadata.get(repo_full_name, {})
            candidate.metadata = normalise_metadata(meta, repo_full_name, candidate.repo_url)
        repos: List[Dict] = []
        for candidate in candidates.values():
            matches_dicts = []
            for m in candidate.matches:
                m.score = match_score({"path": m.path, "match_type": m.match_type})
                matches_dicts.append({"path": m.path, "match_type": m.match_type, "query": m.query, "score": m.score})
            score, tier = score_repo(candidate.metadata, matches_dicts, self.blacklist)
            repo_entry = {
                **candidate.metadata,
                "matches": matches_dicts,
                "confidence_score": score,
                "confidence_tier": tier,
                "cve_id": cve_id,
            }
            repos.append(repo_entry)
        # Highest confidence first; stars break ties.
        repos.sort(key=lambda r: (-r["confidence_score"], -r.get("stars", 0)))
        evidence = {
            "queries": query_log,
            "candidates": [
                {
                    "repo_full_name": r["repo_full_name"],
                    "matches": r["matches"],
                    "match_count": len(r["matches"]),
                    "score": r["confidence_score"],
                    "tier": r["confidence_tier"],
                }
                for r in repos
            ],
        }
        return {"cve_id": cve_id, "last_updated": isoformat(), "pocs": repos, "evidence": evidence}

    def discover_many(self, cve_ids: Iterable[str], *, days: int, limit: Optional[int] = None) -> List[Dict]:
        """Run discover_for_cve over an iterable, stopping after `limit` CVEs."""
        results: List[Dict] = []
        for idx, cve_id in enumerate(cve_ids):
            if limit and idx >= limit:
                break
            results.append(self.discover_for_cve(cve_id, days=days))
        return results


def persist_evidence(results: List[Dict]) -> None:
    """Write each result's evidence payload to EVIDENCE_DIR/<CVE-ID>.json."""
    ensure_dirs(EVIDENCE_DIR)
    for result in results:
        cve_id = result["cve_id"]
        evidence_path = EVIDENCE_DIR / f"{cve_id}.json"
        save_json(evidence_path, result.get("evidence", {}))


def discover_from_github_list(path: Path) -> List[str]:
    """Pull unique, upper-cased CVE IDs out of a free-form text file."""
    if not path.exists():
        return []
    ids: List[str] = []
    for line in path.read_text(encoding="utf-8").splitlines():
        for match in CVE_RE.findall(line):
            if match.upper() not in ids:
                ids.append(match.upper())
    return ids
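
# Illustrative only: the file can be any text that mentions CVE IDs, e.g. a
# scraped list line such as
#
#   - [octocat/poc](https://github.com/octocat/poc) cve-2024-0001 PoC
#
# which yields ["CVE-2024-0001"]; order of first appearance is preserved.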


def load_existing_cves(api_dir: Path = API_DIR / "cve") -> List[str]:
    """List CVE IDs that already have a JSON payload under the API directory."""
    if not api_dir.exists():
        return []
    return sorted({p.stem.upper() for p in api_dir.glob("CVE-*.json") if CVE_RE.match(p.stem)})


def build_scope(
    days: int,
    *,
    github_list: Path,
    existing_api: Path,
    prefer_recent_years: bool = True,
    max_cves: int | None = None,
    low_conf_threshold: int = 1,
) -> List[str]:
    """Decide which CVE IDs to (re)discover on this run."""
    seeds = discover_from_github_list(github_list)
    existing = load_existing_cves(existing_api)
    candidates = seeds or existing
    if prefer_recent_years:
        # Keep CVEs from the last three calendar years, unless that would
        # empty the scope entirely.
        current_year = date.today().year
        candidates = [cve for cve in candidates if cve_year(cve) and cve_year(cve) >= current_year - 2] or candidates
    index_path = API_DIR / "index.json"
    low_conf: List[str] = []
    if index_path.exists():
        # Re-queue CVEs whose index entries have few high/medium-confidence PoCs.
        index_payload = load_json(index_path, default={}) or {}
        for item in index_payload.get("items", []):
            conf_count = (item.get("high_confidence", 0) or 0) + (item.get("medium_confidence", 0) or 0)
            if conf_count <= low_conf_threshold:
                low_conf.append(item.get("cve_id"))
    scoped = candidates + [cve for cve in low_conf if cve and cve not in candidates]
    if max_cves:
        scoped = scoped[:max_cves]
    return scoped
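
# A minimal end-to-end sketch (hypothetical list path; assumes build_client
# picks up GitHub credentials, e.g. from the environment):
#
#   pipeline = PoCPipeline()
#   scope = build_scope(
#       30,
#       github_list=Path("data/github_list.md"),
#       existing_api=API_DIR / "cve",
#       max_cves=25,
#   )
#   results = pipeline.discover_many(scope, days=30, limit=25)
#   persist_evidence(results)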