Files
CS5446-Project/proposed_domain.ipynb
2025-11-05 14:20:01 +08:00

595 lines
21 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "6c2e3bb3-5d74-4259-ac11-6e446ff4c685",
"metadata": {},
"outputs": [],
"source": [
"import os, json, random, time\n",
"from pathlib import Path\n",
"\n",
"import gc\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm.auto import tqdm\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Local modules\n",
"from model import load_model\n",
"from dataset import load_combined_minimal, balanced_sample, SimpleTextDataset, get_seed_sets_for_steering\n",
"from validator import evaluate_minimal\n",
"\n",
"from domain_specific_proposed import build_agentic_defense, cleanup_models"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ce3d9b5-5e59-457f-ba20-ec34b1007c98",
"metadata": {},
"outputs": [],
"source": [
"SEED = 42\n",
"random.seed(SEED); np.random.seed(SEED)\n",
"try:\n",
" import torch\n",
" torch.manual_seed(SEED)\n",
"except Exception:\n",
" pass\n",
"\n",
"# Config\n",
"ALIGNED_KEY = \"aligned\" \n",
"UNALIGNED_KEY = \"unaligned\"\n",
"\n",
"MAX_UNSAFE_EVAL = 200\n",
"SAFE_SEED = 96\n",
"UNSAFE_SEED= 96\n",
"\n",
"OUTDIR = Path(\"results/asv_notebook_domain\")\n",
"OUTDIR.mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16d224b4-fe71-4992-8dbe-e19aae08924d",
"metadata": {},
"outputs": [],
"source": [
"df_all_ds = load_combined_minimal(\"xstest_prompts.csv\")\n",
"full_df = df_all_ds.df\n",
"\n",
"unsafe_df = full_df[full_df['label'] == 'unsafe'].copy()\n",
"if len(unsafe_df) == 0:\n",
" raise RuntimeError(\"No 'unsafe' rows found. Ensure your CSV/HF data is available.\")\n",
"\n",
"if len(unsafe_df) > MAX_UNSAFE_EVAL:\n",
" unsafe_eval_df = unsafe_df.sample(MAX_UNSAFE_EVAL, random_state=SEED).reset_index(drop=True)\n",
" unsafe_eval_df['id'] = unsafe_eval_df.index.astype(int)\n",
"else:\n",
" unsafe_eval_df = unsafe_df.reset_index(drop=True)\n",
" unsafe_eval_df['id'] = unsafe_eval_df.index.astype(int)\n",
"\n",
"print(f\"[INFO] Unsafe eval set size: {len(unsafe_eval_df)}\")\n",
"\n",
"\n",
"safe_prompts_seed, unsafe_prompts_seed = get_seed_sets_for_steering(\n",
" full_df, n_safe=SAFE_SEED, n_unsafe=UNSAFE_SEED, seed=SEED\n",
")\n",
"print(f\"[INFO] Steering seed sets -> safe: {len(safe_prompts_seed)}, unsafe: {len(unsafe_prompts_seed)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d9558ae-89b7-4785-b95b-957fe2027db0",
"metadata": {},
"outputs": [],
"source": [
"MAX_INPUT_TOKENS = 2048\n",
"MAX_NEW_TOKENS = 768\n",
"USE_CACHE = False\n",
"\n",
"def _asr(df):\n",
" if df is None or len(df) == 0 or \"jailbreak_success\" not in df.columns:\n",
" return float(\"nan\")\n",
" return float(df[\"jailbreak_success\"].astype(bool).mean())\n",
"\n",
"def run_generation_df(wrapper, df_prompts, *,\n",
" steerer=None,\n",
" prompt_defense=None,\n",
" agentic=None,\n",
" max_input_tokens=MAX_INPUT_TOKENS,\n",
" max_new_tokens=MAX_NEW_TOKENS,\n",
" temperature=0.0, top_p=1.0, clear_cuda_each=1):\n",
" rows = []\n",
" iterator = tqdm(df_prompts.itertuples(index=False), total=len(df_prompts), desc=\"Generate (full)\")\n",
"\n",
" hf = wrapper.hf\n",
" tok = wrapper.tokenizer\n",
"\n",
" # Respect global USE_CACHE if defined\n",
" try:\n",
" hf.config.use_cache = USE_CACHE\n",
" if hasattr(hf, \"generation_config\"):\n",
" hf.generation_config.use_cache = USE_CACHE\n",
" except Exception:\n",
" pass\n",
"\n",
" eos_id = tok.eos_token_id\n",
" if eos_id is None and getattr(tok, \"pad_token_id\", None) is not None:\n",
" eos_id = tok.pad_token_id\n",
"\n",
" for i, row in enumerate(iterator, 1):\n",
" orig_prompt = row.prompt\n",
"\n",
" if agentic is not None:\n",
" try:\n",
" final_out, verdict, info_agentic, rewritten_prompt = agentic(\n",
" wrapper, orig_prompt,\n",
" max_input_tokens=max_input_tokens,\n",
" max_new_tokens=max_new_tokens,\n",
" use_cache=USE_CACHE,\n",
" )\n",
" except Exception as e:\n",
" rewritten_prompt = None\n",
" verdict = {\"safe\": False, \"reasons\": [f\"agentic_error: {type(e).__name__}\"]} # optional\n",
" if hasattr(tok, \"apply_chat_template\"):\n",
" msgs = [{\"role\": \"user\", \"content\": orig_prompt}]\n",
" text = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)\n",
" else:\n",
" text = (\n",
" \"<|begin_of_text|>\"\n",
" \"<|start_header_id|>user<|end_header_id|>\\n\"\n",
" f\"{orig_prompt}\\n<|eot_id|>\"\n",
" \"<|start_header_id|>assistant<|end_header_id|>\\n\"\n",
" )\n",
" enc = tok(text, return_tensors=\"pt\", truncation=True, max_length=max_input_tokens).to(hf.device)\n",
" gen_kwargs = dict(\n",
" max_new_tokens=max_new_tokens,\n",
" do_sample=False if (temperature is None or temperature == 0.0) else True,\n",
" temperature=None if (temperature is None or temperature == 0.0) else float(temperature),\n",
" top_p=top_p,\n",
" use_cache=USE_CACHE,\n",
" )\n",
" if eos_id is not None:\n",
" gen_kwargs[\"eos_token_id\"] = eos_id\n",
" gen_kwargs[\"pad_token_id\"] = eos_id\n",
" with torch.inference_mode():\n",
" out_ids = hf.generate(**enc, **gen_kwargs)\n",
" gen_only = out_ids[0][enc[\"input_ids\"].shape[1]:]\n",
" final_out = tok.decode(gen_only, skip_special_tokens=True).strip()\n",
"\n",
" rows.append({\n",
" \"id\": int(row.id),\n",
" \"prompt\": orig_prompt,\n",
" \"label\": row.label,\n",
" \"model_output\": final_out,\n",
" \"rewritten_prompt\": rewritten_prompt,\n",
" \"judge_safe\": verdict.get(\"safe\") if isinstance(verdict, dict) else None,\n",
" \"judge_reasons\": \"; \".join(verdict.get(\"reasons\", [])) if isinstance(verdict, dict) else None,\n",
" })\n",
"\n",
" if (clear_cuda_each is not None) and (i % clear_cuda_each == 0):\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.ipc_collect()\n",
" import gc as _gc; _gc.collect()\n",
" continue # next row\n",
"\n",
" sys_prompt = None\n",
" prompt = orig_prompt\n",
"\n",
" if prompt_defense is not None:\n",
" try:\n",
" transformed, _sys_ignored, info = prompt_defense(orig_prompt)\n",
" prompt = transformed if transformed is not None else orig_prompt\n",
" sys_prompt = None\n",
" except Exception:\n",
" prompt = orig_prompt\n",
" sys_prompt = None\n",
"\n",
" if hasattr(tok, \"apply_chat_template\"):\n",
" msgs = [{\"role\": \"user\", \"content\": prompt}]\n",
" text = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)\n",
" else:\n",
" text = (\n",
" \"<|begin_of_text|>\"\n",
" \"<|start_header_id|>user<|end_header_id|>\\n\"\n",
" f\"{prompt}\\n<|eot_id|>\"\n",
" \"<|start_header_id|>assistant<|end_header_id|>\\n\"\n",
" )\n",
"\n",
" enc = tok(text, return_tensors=\"pt\", truncation=True, max_length=max_input_tokens).to(hf.device)\n",
"\n",
" gen_kwargs = dict(\n",
" max_new_tokens=max_new_tokens,\n",
" do_sample=False if (temperature is None or temperature == 0.0) else True,\n",
" temperature=None if (temperature is None or temperature == 0.0) else float(temperature),\n",
" top_p=top_p,\n",
" use_cache=USE_CACHE,\n",
" )\n",
" if eos_id is not None:\n",
" gen_kwargs[\"eos_token_id\"] = eos_id\n",
" gen_kwargs[\"pad_token_id\"] = eos_id\n",
"\n",
" with torch.inference_mode():\n",
" if steerer is None:\n",
" out_ids = hf.generate(**enc, **gen_kwargs)\n",
" else:\n",
" with steerer.steering_context(prompt_for_alpha=orig_prompt):\n",
" out_ids = hf.generate(**enc, **gen_kwargs)\n",
"\n",
" gen_only = out_ids[0][enc[\"input_ids\"].shape[1]:]\n",
" text_out = tok.decode(gen_only, skip_special_tokens=True).strip()\n",
"\n",
" rows.append({\n",
" \"id\": int(row.id),\n",
" \"prompt\": orig_prompt,\n",
" \"label\": row.label,\n",
" \"model_output\": text_out\n",
" })\n",
"\n",
" if (clear_cuda_each is not None) and (i % clear_cuda_each == 0):\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.ipc_collect()\n",
" import gc as _gc; _gc.collect()\n",
"\n",
" return pd.DataFrame(rows)"
]
},
{
"cell_type": "markdown",
"id": "07c90dee-645d-4fb6-a485-12500ea3f660",
"metadata": {},
"source": [
"## Aligned Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "059dac08-2621-4551-9d6c-2c72b2309138",
"metadata": {},
"outputs": [],
"source": [
"print(\"[INFO] Loading aligned model...\")\n",
"aligned = load_model(ALIGNED_KEY)\n",
"\n",
"print(\"[TEST] Aligned sample ->\", aligned.generate(\"Say 'ok'.\", max_new_tokens=5))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e387a8e-3244-4719-b9c1-e1bb48bf3509",
"metadata": {},
"outputs": [],
"source": [
"print(\"[BASELINE] Evaluating ALIGNED (no defense, FULL outputs) ...\")\n",
"df_gen_aligned_base = run_generation_df(\n",
" aligned, unsafe_eval_df, steerer=None,\n",
" max_input_tokens=MAX_INPUT_TOKENS,\n",
" max_new_tokens=MAX_NEW_TOKENS, \n",
" temperature=None, \n",
" top_p=1.0\n",
")\n",
"df_eval_aligned_base = evaluate_minimal(\n",
" df_gen_aligned_base,\n",
" use_llm_judge=True,\n",
" ensemble_with_heuristic=True\n",
")\n",
"\n",
"asr_aligned_base = _asr(df_eval_aligned_base)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5bbee730-137a-4eb5-842d-755851b3710e",
"metadata": {},
"outputs": [],
"source": [
"adef_aligned = build_agentic_defense(\n",
" model_id=\"Qwen/Qwen3-0.6B\",\n",
" refusal_text=\"I can not answer that, sorry.\",\n",
" max_new_tokens_rewriter=128,\n",
" max_new_tokens_judge=128,\n",
")\n",
"\n",
"print(\"Aligned domain defence init\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff877f13-b64c-46e4-ba3d-a97f56b14185",
"metadata": {},
"outputs": [],
"source": [
"print(\"[DEFENSE] Evaluating ALIGNED (domain-based, FULL outputs) ...\")\n",
"df_gen_aligned_def = run_generation_df(\n",
" aligned, unsafe_eval_df,\n",
" agentic=adef_aligned,\n",
" steerer=None,\n",
" prompt_defense=None,\n",
" max_input_tokens=MAX_INPUT_TOKENS,\n",
" max_new_tokens=MAX_NEW_TOKENS,\n",
" temperature=None,\n",
" top_p=1.0\n",
")\n",
"df_eval_aligned_def = evaluate_minimal(\n",
" df_gen_aligned_def,\n",
" use_llm_judge=True,\n",
" ensemble_with_heuristic=True\n",
")\n",
"asr_aligned_def = _asr(df_eval_aligned_def)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97d176e1-9e38-4cc5-b523-c14174a1a815",
"metadata": {},
"outputs": [],
"source": [
"# clean up the model\n",
"print(\"[CLEANUP] Releasing ALIGNED model from memory...\")\n",
"cleanup_models(adef_aligned)\n",
"del aligned\n",
"gc.collect()\n",
"if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.ipc_collect()"
]
},
{
"cell_type": "markdown",
"id": "47632778-b65b-4a7a-a22a-f013857de0a8",
"metadata": {},
"source": [
"## Unaligned Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32864439-2e43-49b5-a271-5b696a35a040",
"metadata": {},
"outputs": [],
"source": [
"print(\"[INFO] Loading unaligned model...\")\n",
"unaligned = load_model(UNALIGNED_KEY)\n",
"print(\"[TEST] Unaligned sample ->\", unaligned.generate(\"Say 'ok'.\", max_new_tokens=5))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2251a9e9-2093-4aee-b419-25e667c166cb",
"metadata": {},
"outputs": [],
"source": [
"print(\"[BASELINE] Evaluating UNALIGNED (no defense, FULL outputs) ...\")\n",
"df_gen_unaligned_base = run_generation_df(\n",
" unaligned, unsafe_eval_df, steerer=None,\n",
" max_input_tokens=MAX_INPUT_TOKENS,\n",
" max_new_tokens=MAX_NEW_TOKENS,\n",
" temperature=None,\n",
" top_p=1.0\n",
")\n",
"df_eval_unaligned_base = evaluate_minimal(\n",
" df_gen_unaligned_base,\n",
" use_llm_judge=True,\n",
" ensemble_with_heuristic=True\n",
")\n",
"\n",
"asr_unaligned_base = _asr(df_eval_unaligned_base)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0483cf69-bffa-4380-9eb9-2320e1570cbe",
"metadata": {},
"outputs": [],
"source": [
"adef_unaligned = build_agentic_defense(\n",
" model_id=\"Qwen/Qwen3-0.6B\",\n",
" refusal_text=\"I can not answer that, sorry.\",\n",
" max_new_tokens_rewriter=128,\n",
" max_new_tokens_judge=128,\n",
")\n",
"\n",
"print(\"Unaligned domain defence init\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf29ba37-b78d-43e2-b573-2f1a5d425d16",
"metadata": {},
"outputs": [],
"source": [
"print(\"[DEFENSE] Evaluating UNALIGNED (domain-based, FULL outputs) ...\")\n",
"df_gen_unaligned_def = run_generation_df(\n",
" unaligned,\n",
" unsafe_eval_df,\n",
" agentic=adef_unaligned, \n",
" steerer=None,\n",
" prompt_defense=None,\n",
" max_input_tokens=MAX_INPUT_TOKENS,\n",
" max_new_tokens=MAX_NEW_TOKENS,\n",
" temperature=None,\n",
" top_p=1.0\n",
")\n",
"\n",
"df_eval_unaligned_def = evaluate_minimal(\n",
" df_gen_unaligned_def,\n",
" use_llm_judge=True,\n",
" ensemble_with_heuristic=True\n",
")\n",
"asr_unaligned_def = _asr(df_eval_unaligned_def)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6177b6d4-0ee6-4ebd-8add-41079adfd9b3",
"metadata": {},
"outputs": [],
"source": [
"print(\"[CLEANUP] Releasing UNALIGNED model and steerer from memory...\")\n",
"del unaligned\n",
"cleanup_models(adef_unaligned)\n",
"gc.collect()\n",
"if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.ipc_collect()"
]
},
{
"cell_type": "markdown",
"id": "3f3e6ce1-cf12-4843-9517-0b84be75520f",
"metadata": {},
"source": [
"# Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e99f224-3059-46c9-8801-1c66782ba901",
"metadata": {},
"outputs": [],
"source": [
"print(f\"[RESULT] Baseline ASR — ALIGNED: {asr_aligned_base:.3f} | UNALIGNED: {asr_unaligned_base:.3f}\")\n",
"\n",
"OUTDIR.mkdir(parents=True, exist_ok=True)\n",
"df_gen_aligned_base.to_csv(OUTDIR / \"gen_aligned_baseline.csv\", index=False)\n",
"df_gen_unaligned_base.to_csv(OUTDIR / \"gen_unaligned_baseline.csv\", index=False)\n",
"df_eval_aligned_base.to_csv(OUTDIR / \"eval_aligned_baseline.csv\", index=False)\n",
"df_eval_unaligned_base.to_csv(OUTDIR / \"eval_unaligned_baseline.csv\", index=False)\n",
"\n",
"print(f\"[RESULT] With Defense ASR — ALIGNED: {asr_aligned_def:.3f} | UNALIGNED: {asr_unaligned_def:.3f}\")\n",
"\n",
"OUTDIR.mkdir(parents=True, exist_ok=True)\n",
"df_gen_aligned_def.to_csv(OUTDIR / \"gen_aligned_domain.csv\", index=False)\n",
"df_gen_unaligned_def.to_csv(OUTDIR / \"gen_unaligned_domain.csv\", index=False)\n",
"df_eval_aligned_def.to_csv(OUTDIR / \"eval_aligned_domain.csv\", index=False)\n",
"df_eval_unaligned_def.to_csv(OUTDIR / \"eval_unaligned_domain.csv\", index=False)\n",
"\n",
"summary = {\n",
" \"baseline\": {\"aligned\": asr_aligned_base, \"unaligned\": asr_unaligned_base},\n",
" \"defense\": {\"aligned\": asr_aligned_def, \"unaligned\": asr_unaligned_def},\n",
"}\n",
"with open(OUTDIR / \"summary.json\", \"w\") as f:\n",
" json.dump(summary, f, indent=2)\n",
"print(\"\\n[SUMMARY]\", json.dumps(summary, indent=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66d21350-1ec1-4f19-80bb-c2aa7c5d83a4",
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 4))\n",
"y_a = df_eval_aligned_base['jailbreak_success'].astype(int).values\n",
"y_u = df_eval_unaligned_base['jailbreak_success'].astype(int).values\n",
"x = np.arange(len(y_a))\n",
"\n",
"plt.plot(x, y_a, label=\"Aligned (no defense)\")\n",
"plt.plot(x, y_u, label=\"Unaligned (no defense)\")\n",
"plt.xlabel(\"Attempt index\")\n",
"plt.ylabel(\"Success (0/1)\")\n",
"plt.title(\"Jailbreak Attempts vs Success — Baseline\")\n",
"plt.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00b4072a-cc01-419d-a89b-cfddfd45ec14",
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 4))\n",
"y_a = df_eval_aligned_def['jailbreak_success'].astype(int).values\n",
"y_u = df_eval_unaligned_def['jailbreak_success'].astype(int).values\n",
"x = np.arange(len(y_a))\n",
"\n",
"plt.plot(x, y_a, label=\"Aligned (defense)\")\n",
"plt.plot(x, y_u, label=\"Unaligned (defense)\")\n",
"plt.xlabel(\"Attempt index\")\n",
"plt.ylabel(\"Success (0/1)\")\n",
"plt.title(\"Jailbreak Attempts vs Success — defense\")\n",
"plt.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7986b2a6-a0af-4301-9b5e-773ce3493dce",
"metadata": {},
"outputs": [],
"source": [
"labels = [\"Aligned\", \"Unaligned\"]\n",
"baseline = [asr_aligned_base, asr_unaligned_base]\n",
"defense = [asr_aligned_def, asr_unaligned_def]\n",
"\n",
"plt.figure(figsize=(6,4))\n",
"x = np.arange(len(labels))\n",
"width = 0.35\n",
"plt.bar(x - width/2, baseline, width, label='Baseline')\n",
"plt.bar(x + width/2, defense, width, label='With Domain Defence')\n",
"plt.xticks(x, labels)\n",
"plt.ylabel('ASR')\n",
"plt.title('Attack Success Rate (lower is better)')\n",
"plt.legend()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af7dfa1e-3bf9-4524-bc60-033247a67948",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}