mirror of
https://github.com/Kuro0911/CS5446-Project.git
synced 2026-02-12 21:12:54 +00:00
229 lines
7.8 KiB
Python
229 lines
7.8 KiB
Python
# prompt_based.py
|
|
from __future__ import annotations
|
|
import re, time, random
|
|
from dataclasses import dataclass
|
|
from typing import Callable, Dict, List, Optional, Tuple
|
|
|
|
# -----------------------------
|
|
# Optional deps (robust fallback)
|
|
# -----------------------------
|
|
try:
|
|
import torch
|
|
import torch.nn.functional as F
|
|
except Exception:
|
|
torch, F = None, None
|
|
|
|
try:
|
|
from sentence_transformers import SentenceTransformer
|
|
except Exception:
|
|
SentenceTransformer = None
|
|
|
|
# -----------------------------
|
|
# Config
|
|
# -----------------------------
|
|
@dataclass
|
|
class PromptDefenseConfig:
|
|
# Embeddings
|
|
emb_model: str = "sentence-transformers/all-MiniLM-L6-v2"
|
|
emb_thresh: float = 0.15 # clamp tiny cosine sims to 0 (denoise)
|
|
# Risk thresholds in [0,1]
|
|
strict_threshold: float = 0.30
|
|
extreme_threshold: float = 0.50
|
|
# Transform toggles
|
|
paraphrase: bool = False # generally better OFF (can help attackers)
|
|
max_edits: int = 0
|
|
strip_urls: bool = True
|
|
strip_injections: bool = True
|
|
|
|
# -----------------------------
|
|
# Patterns & text blocks
|
|
# -----------------------------
|
|
_JAILBREAK_MARKERS = [
|
|
r"ignore\s+previous\s+instructions",
|
|
r"override\s+system\s+prompt",
|
|
r"as\s+an\s+uncensored\s+model",
|
|
r"without\s+ethical\s+constraints",
|
|
r"developer\s+mode",
|
|
r"\{\{.*?\}\}",
|
|
r"#\s*system\s*:",
|
|
r"#\s*assistant\s*:",
|
|
r"simulate\s+(?:being|as)",
|
|
r"role\s*play\s+as",
|
|
r"\b(bypass|evade|circumvent)\b",
|
|
r"\bjailbreak\b",
|
|
]
|
|
|
|
MILD_PREFIX = (
|
|
"Safety note: I will keep this discussion high-level and avoid sensitive or harmful details. "
|
|
"If your intent is benign, please clarify as needed.\n\n"
|
|
)
|
|
STRICT_PREFIX = (
|
|
"Safety notice: The request appears potentially risky. I will NOT provide harmful, illegal, or dangerous instructions. "
|
|
"I can offer safe, high-level context and alternatives instead.\n\n"
|
|
)
|
|
EXTREME_PREFIX = (
|
|
"Safety override: This request appears unsafe. I will refuse to provide any operational or step-by-step guidance. "
|
|
"If you have a lawful, safe learning goal, please reframe the question accordingly.\n\n"
|
|
)
|
|
|
|
# -----------------------------
|
|
# Light paraphraser (optional)
|
|
# -----------------------------
|
|
_SYNONYM_MAP = {
|
|
"give": ["provide", "offer"],
|
|
"tell": ["share", "explain"],
|
|
"hack": ["compromise"],
|
|
"bypass": ["avoid", "circumvent"],
|
|
}
|
|
|
|
def _paraphrase_light(s: str, k: int) -> str:
|
|
if k <= 0:
|
|
return s
|
|
words = s.split()
|
|
idxs = list(range(len(words)))
|
|
random.shuffle(idxs)
|
|
edits = 0
|
|
for i in idxs:
|
|
token = words[i]
|
|
base = re.sub(r"\W+", "", token).lower()
|
|
if base in _SYNONYM_MAP:
|
|
repl = random.choice(_SYNONYM_MAP[base])
|
|
# replace base (case-insensitive) but preserve punctuation around it
|
|
words[i] = re.sub(base, repl, token, flags=re.IGNORECASE)
|
|
edits += 1
|
|
if edits >= k:
|
|
break
|
|
return " ".join(words)
|
|
|
|
# -----------------------------
|
|
# Sanitizer
|
|
# -----------------------------
|
|
def _sanitize(s: str, *, strip_urls: bool, strip_injections: bool) -> str:
|
|
x = s
|
|
if strip_urls:
|
|
x = re.sub(r"https?://\S+", "[link]", x)
|
|
if strip_injections:
|
|
for pat in _JAILBREAK_MARKERS:
|
|
x = re.sub(pat, "[removed]", x, flags=re.IGNORECASE)
|
|
x = re.sub(r"\s+", " ", x).strip()
|
|
return x
|
|
|
|
# -----------------------------
|
|
# Embedder with offline fallback
|
|
# -----------------------------
|
|
class _Embedder:
|
|
"""Use ST embeddings if available; otherwise fall back to a keyword risk."""
|
|
def __init__(self, model_name: str, unsafe_vocab: Optional[List[str]] = None):
|
|
self.ok = False
|
|
self.mdl = None
|
|
self.unsafe_vocab = set((unsafe_vocab or []) + [
|
|
"bypass","evade","circumvent","hack","exploit","malware","backdoor",
|
|
"bomb","weapon","poison","deepfake","impersonate","doxx","illegal",
|
|
"keylogger","phishing","ransomware","ddos","trojan"
|
|
])
|
|
if SentenceTransformer:
|
|
try:
|
|
self.mdl = SentenceTransformer(model_name)
|
|
self.ok = True
|
|
except Exception:
|
|
self.ok = False
|
|
|
|
def encode(self, texts: List[str]):
|
|
if not self.ok:
|
|
return None
|
|
v = self.mdl.encode(texts, convert_to_tensor=True, normalize_embeddings=True)
|
|
if torch is not None and not isinstance(v, torch.Tensor):
|
|
v = torch.tensor(v)
|
|
return v
|
|
|
|
def kw_risk(self, s: str) -> float:
|
|
s_low = s.lower()
|
|
hits = sum(1 for w in self.unsafe_vocab if re.search(rf"\b{re.escape(w)}\b", s_low))
|
|
# Map count → [0,1] with a gentle cap
|
|
return max(0.0, min(1.0, hits / 5.0))
|
|
|
|
# -----------------------------
|
|
# Builder
|
|
# -----------------------------
|
|
def build_prompt_defense(
|
|
safe_seed_prompts: List[str],
|
|
unsafe_seed_prompts: List[str],
|
|
cfg: Optional[PromptDefenseConfig] = None,
|
|
) -> Callable[[str], Tuple[str, Optional[str], Dict]]:
|
|
"""
|
|
Returns a callable that takes the raw user prompt and returns:
|
|
(rewritten_user_prompt, None, info)
|
|
We intentionally return None for the system prompt so the caller sends only the user message.
|
|
"""
|
|
cfg = cfg or PromptDefenseConfig()
|
|
emb = _Embedder(cfg.emb_model)
|
|
|
|
# Compute centroids (if embeddings available)
|
|
safe_centroid = None
|
|
unsafe_centroid = None
|
|
|
|
if emb.ok and unsafe_seed_prompts:
|
|
E_unsafe = emb.encode(unsafe_seed_prompts)
|
|
if torch is not None and E_unsafe is not None:
|
|
unsafe_centroid = F.normalize(E_unsafe.mean(dim=0, keepdim=True), dim=-1)
|
|
|
|
if emb.ok and safe_seed_prompts:
|
|
E_safe = emb.encode(safe_seed_prompts)
|
|
if torch is not None and E_safe is not None:
|
|
safe_centroid = F.normalize(E_safe.mean(dim=0, keepdim=True), dim=-1)
|
|
|
|
def _risk(prompt: str) -> float:
|
|
# Comparative risk = max(0, sim_unsafe - sim_safe)
|
|
if emb.ok and (unsafe_centroid is not None or safe_centroid is not None):
|
|
e = emb.encode([prompt])
|
|
if e is None:
|
|
return 0.0
|
|
sim_u = float((e @ unsafe_centroid.T).squeeze().item()) if unsafe_centroid is not None else 0.0
|
|
sim_s = float((e @ safe_centroid.T).squeeze().item()) if safe_centroid is not None else 0.0
|
|
# Denoise tiny sims
|
|
if sim_u < cfg.emb_thresh: sim_u = 0.0
|
|
if sim_s < cfg.emb_thresh: sim_s = 0.0
|
|
return max(0.0, min(1.0, sim_u - sim_s))
|
|
# Fallback: keyword risk
|
|
return emb.kw_risk(prompt)
|
|
|
|
def defense(user_prompt: str) -> Tuple[str, Optional[str], Dict]:
|
|
"""
|
|
Only rewrites the USER prompt (no system prompt).
|
|
The rewrite prepends a safety wrapper + sanitized content.
|
|
"""
|
|
t0 = time.time()
|
|
r0 = _risk(user_prompt)
|
|
|
|
# Sanitize (strip URLs/jailbreak markers)
|
|
x = _sanitize(
|
|
user_prompt,
|
|
strip_urls=cfg.strip_urls,
|
|
strip_injections=cfg.strip_injections,
|
|
)
|
|
|
|
# Optional paraphrase (usually keep off)
|
|
if cfg.paraphrase and cfg.max_edits > 0:
|
|
x = _paraphrase_light(x, cfg.max_edits)
|
|
|
|
# Choose prepend wrapper by risk (no system message is used)
|
|
if r0 >= cfg.extreme_threshold:
|
|
wrapped = EXTREME_PREFIX + f"User request:\n{x}"
|
|
action = "PREPEND_EXTREME"
|
|
elif r0 >= cfg.strict_threshold:
|
|
wrapped = STRICT_PREFIX + f"User request:\n{x}"
|
|
action = "PREPEND_STRICT"
|
|
else:
|
|
wrapped = MILD_PREFIX + x
|
|
action = "PREPEND_MILD"
|
|
|
|
r1 = _risk(wrapped)
|
|
info = {
|
|
"risk_before": r0,
|
|
"risk_after": r1,
|
|
"action": action,
|
|
"latency_ms": int((time.time() - t0) * 1000),
|
|
}
|
|
return wrapped, None, info # <-- No system prompt
|
|
|
|
return defense |