Files
CS5446-Project/prompt_based.py
2025-11-03 21:06:47 +08:00

229 lines
7.8 KiB
Python

# prompt_based.py
from __future__ import annotations
import re, time, random
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple
# -----------------------------
# Optional deps (robust fallback)
# -----------------------------
try:
import torch
import torch.nn.functional as F
except Exception:
torch, F = None, None
try:
from sentence_transformers import SentenceTransformer
except Exception:
SentenceTransformer = None
# -----------------------------
# Config
# -----------------------------
@dataclass
class PromptDefenseConfig:
# Embeddings
emb_model: str = "sentence-transformers/all-MiniLM-L6-v2"
emb_thresh: float = 0.15 # clamp tiny cosine sims to 0 (denoise)
# Risk thresholds in [0,1]
strict_threshold: float = 0.30
extreme_threshold: float = 0.50
# Transform toggles
paraphrase: bool = False # generally better OFF (can help attackers)
max_edits: int = 0
strip_urls: bool = True
strip_injections: bool = True
# -----------------------------
# Patterns & text blocks
# -----------------------------
_JAILBREAK_MARKERS = [
r"ignore\s+previous\s+instructions",
r"override\s+system\s+prompt",
r"as\s+an\s+uncensored\s+model",
r"without\s+ethical\s+constraints",
r"developer\s+mode",
r"\{\{.*?\}\}",
r"#\s*system\s*:",
r"#\s*assistant\s*:",
r"simulate\s+(?:being|as)",
r"role\s*play\s+as",
r"\b(bypass|evade|circumvent)\b",
r"\bjailbreak\b",
]
MILD_PREFIX = (
"Safety note: I will keep this discussion high-level and avoid sensitive or harmful details. "
"If your intent is benign, please clarify as needed.\n\n"
)
STRICT_PREFIX = (
"Safety notice: The request appears potentially risky. I will NOT provide harmful, illegal, or dangerous instructions. "
"I can offer safe, high-level context and alternatives instead.\n\n"
)
EXTREME_PREFIX = (
"Safety override: This request appears unsafe. I will refuse to provide any operational or step-by-step guidance. "
"If you have a lawful, safe learning goal, please reframe the question accordingly.\n\n"
)
# -----------------------------
# Light paraphraser (optional)
# -----------------------------
_SYNONYM_MAP = {
"give": ["provide", "offer"],
"tell": ["share", "explain"],
"hack": ["compromise"],
"bypass": ["avoid", "circumvent"],
}
def _paraphrase_light(s: str, k: int) -> str:
if k <= 0:
return s
words = s.split()
idxs = list(range(len(words)))
random.shuffle(idxs)
edits = 0
for i in idxs:
token = words[i]
base = re.sub(r"\W+", "", token).lower()
if base in _SYNONYM_MAP:
repl = random.choice(_SYNONYM_MAP[base])
# replace base (case-insensitive) but preserve punctuation around it
words[i] = re.sub(base, repl, token, flags=re.IGNORECASE)
edits += 1
if edits >= k:
break
return " ".join(words)
# -----------------------------
# Sanitizer
# -----------------------------
def _sanitize(s: str, *, strip_urls: bool, strip_injections: bool) -> str:
x = s
if strip_urls:
x = re.sub(r"https?://\S+", "[link]", x)
if strip_injections:
for pat in _JAILBREAK_MARKERS:
x = re.sub(pat, "[removed]", x, flags=re.IGNORECASE)
x = re.sub(r"\s+", " ", x).strip()
return x
# -----------------------------
# Embedder with offline fallback
# -----------------------------
class _Embedder:
"""Use ST embeddings if available; otherwise fall back to a keyword risk."""
def __init__(self, model_name: str, unsafe_vocab: Optional[List[str]] = None):
self.ok = False
self.mdl = None
self.unsafe_vocab = set((unsafe_vocab or []) + [
"bypass","evade","circumvent","hack","exploit","malware","backdoor",
"bomb","weapon","poison","deepfake","impersonate","doxx","illegal",
"keylogger","phishing","ransomware","ddos","trojan"
])
if SentenceTransformer:
try:
self.mdl = SentenceTransformer(model_name)
self.ok = True
except Exception:
self.ok = False
def encode(self, texts: List[str]):
if not self.ok:
return None
v = self.mdl.encode(texts, convert_to_tensor=True, normalize_embeddings=True)
if torch is not None and not isinstance(v, torch.Tensor):
v = torch.tensor(v)
return v
def kw_risk(self, s: str) -> float:
s_low = s.lower()
hits = sum(1 for w in self.unsafe_vocab if re.search(rf"\b{re.escape(w)}\b", s_low))
# Map count → [0,1] with a gentle cap
return max(0.0, min(1.0, hits / 5.0))
# -----------------------------
# Builder
# -----------------------------
def build_prompt_defense(
safe_seed_prompts: List[str],
unsafe_seed_prompts: List[str],
cfg: Optional[PromptDefenseConfig] = None,
) -> Callable[[str], Tuple[str, Optional[str], Dict]]:
"""
Returns a callable that takes the raw user prompt and returns:
(rewritten_user_prompt, None, info)
We intentionally return None for the system prompt so the caller sends only the user message.
"""
cfg = cfg or PromptDefenseConfig()
emb = _Embedder(cfg.emb_model)
# Compute centroids (if embeddings available)
safe_centroid = None
unsafe_centroid = None
if emb.ok and unsafe_seed_prompts:
E_unsafe = emb.encode(unsafe_seed_prompts)
if torch is not None and E_unsafe is not None:
unsafe_centroid = F.normalize(E_unsafe.mean(dim=0, keepdim=True), dim=-1)
if emb.ok and safe_seed_prompts:
E_safe = emb.encode(safe_seed_prompts)
if torch is not None and E_safe is not None:
safe_centroid = F.normalize(E_safe.mean(dim=0, keepdim=True), dim=-1)
def _risk(prompt: str) -> float:
# Comparative risk = max(0, sim_unsafe - sim_safe)
if emb.ok and (unsafe_centroid is not None or safe_centroid is not None):
e = emb.encode([prompt])
if e is None:
return 0.0
sim_u = float((e @ unsafe_centroid.T).squeeze().item()) if unsafe_centroid is not None else 0.0
sim_s = float((e @ safe_centroid.T).squeeze().item()) if safe_centroid is not None else 0.0
# Denoise tiny sims
if sim_u < cfg.emb_thresh: sim_u = 0.0
if sim_s < cfg.emb_thresh: sim_s = 0.0
return max(0.0, min(1.0, sim_u - sim_s))
# Fallback: keyword risk
return emb.kw_risk(prompt)
def defense(user_prompt: str) -> Tuple[str, Optional[str], Dict]:
"""
Only rewrites the USER prompt (no system prompt).
The rewrite prepends a safety wrapper + sanitized content.
"""
t0 = time.time()
r0 = _risk(user_prompt)
# Sanitize (strip URLs/jailbreak markers)
x = _sanitize(
user_prompt,
strip_urls=cfg.strip_urls,
strip_injections=cfg.strip_injections,
)
# Optional paraphrase (usually keep off)
if cfg.paraphrase and cfg.max_edits > 0:
x = _paraphrase_light(x, cfg.max_edits)
# Choose prepend wrapper by risk (no system message is used)
if r0 >= cfg.extreme_threshold:
wrapped = EXTREME_PREFIX + f"User request:\n{x}"
action = "PREPEND_EXTREME"
elif r0 >= cfg.strict_threshold:
wrapped = STRICT_PREFIX + f"User request:\n{x}"
action = "PREPEND_STRICT"
else:
wrapped = MILD_PREFIX + x
action = "PREPEND_MILD"
r1 = _risk(wrapped)
info = {
"risk_before": r0,
"risk_after": r1,
"action": action,
"latency_ms": int((time.time() - t0) * 1000),
}
return wrapped, None, info # <-- No system prompt
return defense