{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "6c2e3bb3-5d74-4259-ac11-6e446ff4c685", "metadata": {}, "outputs": [], "source": [ "import os, json, random, time\n", "from pathlib import Path\n", "\n", "import gc\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "from tqdm.auto import tqdm\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "# Local modules\n", "from model import load_model\n", "from dataset import load_combined_minimal, balanced_sample, SimpleTextDataset, get_seed_sets_for_steering\n", "from validator import evaluate_minimal\n", "from prompt_based import build_prompt_defense, PromptDefenseConfig" ] }, { "cell_type": "code", "execution_count": null, "id": "8ce3d9b5-5e59-457f-ba20-ec34b1007c98", "metadata": {}, "outputs": [], "source": [ "SEED = 42\n", "random.seed(SEED); np.random.seed(SEED)\n", "try:\n", " import torch\n", " torch.manual_seed(SEED)\n", "except Exception:\n", " pass\n", "\n", "# Config\n", "ALIGNED_KEY = \"aligned\" \n", "UNALIGNED_KEY = \"unaligned\"\n", "\n", "MAX_UNSAFE_EVAL = 200\n", "SAFE_SEED = 96\n", "UNSAFE_SEED= 96\n", "\n", "OUTDIR = Path(\"results/asv_notebook_prompt\")\n", "OUTDIR.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "16d224b4-fe71-4992-8dbe-e19aae08924d", "metadata": {}, "outputs": [], "source": [ "df_all_ds = load_combined_minimal(\"xstest_prompts.csv\")\n", "full_df = df_all_ds.df\n", "\n", "unsafe_df = full_df[full_df['label'] == 'unsafe'].copy()\n", "if len(unsafe_df) == 0:\n", " raise RuntimeError(\"No 'unsafe' rows found. Ensure your CSV/HF data is available.\")\n", "\n", "if len(unsafe_df) > MAX_UNSAFE_EVAL:\n", " unsafe_eval_df = unsafe_df.sample(MAX_UNSAFE_EVAL, random_state=SEED).reset_index(drop=True)\n", " unsafe_eval_df['id'] = unsafe_eval_df.index.astype(int)\n", "else:\n", " unsafe_eval_df = unsafe_df.reset_index(drop=True)\n", " unsafe_eval_df['id'] = unsafe_eval_df.index.astype(int)\n", "\n", "print(f\"[INFO] Unsafe eval set size: {len(unsafe_eval_df)}\")\n", "\n", "\n", "safe_prompts_seed, unsafe_prompts_seed = get_seed_sets_for_steering(\n", " full_df, n_safe=SAFE_SEED, n_unsafe=UNSAFE_SEED, seed=SEED\n", ")\n", "print(f\"[INFO] Steering seed sets -> safe: {len(safe_prompts_seed)}, unsafe: {len(unsafe_prompts_seed)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "1d9558ae-89b7-4785-b95b-957fe2027db0", "metadata": {}, "outputs": [], "source": [ "MAX_INPUT_TOKENS = 2048\n", "MAX_NEW_TOKENS = 768\n", "USE_CACHE = False\n", "\n", "def _asr(df):\n", " if df is None or len(df) == 0 or \"jailbreak_success\" not in df.columns:\n", " return float(\"nan\")\n", " return float(df[\"jailbreak_success\"].astype(bool).mean())\n", "\n", "def run_generation_df(wrapper, df_prompts, *,\n", " steerer=None,\n", " prompt_defense=None,\n", " max_input_tokens=MAX_INPUT_TOKENS,\n", " max_new_tokens=MAX_NEW_TOKENS,\n", " temperature=0.0, top_p=1.0, clear_cuda_each=1):\n", " \"\"\"\n", " Generate model outputs for a dataframe of prompts.\n", " - Does NOT send a system prompt.\n", " - If `prompt_defense` is provided, it must return (rewritten_user_prompt, None, info).\n", " - Works with or without a steering context.\n", " \"\"\"\n", " rows = []\n", " iterator = tqdm(df_prompts.itertuples(index=False), total=len(df_prompts), desc=\"Generate (full)\")\n", "\n", " hf = wrapper.hf\n", " tok = wrapper.tokenizer\n", "\n", " # Respect global USE_CACHE if defined\n", " try:\n", " hf.config.use_cache = USE_CACHE\n", " if hasattr(hf, 
\"generation_config\"):\n", " hf.generation_config.use_cache = USE_CACHE\n", " except Exception:\n", " pass\n", "\n", " eos_id = tok.eos_token_id\n", " if eos_id is None and getattr(tok, \"pad_token_id\", None) is not None:\n", " eos_id = tok.pad_token_id\n", "\n", " for i, row in enumerate(iterator, 1):\n", " orig_prompt = row.prompt\n", " sys_prompt = None\n", " prompt = orig_prompt\n", "\n", " if prompt_defense is not None:\n", " try:\n", " transformed, _sys_ignored, info = prompt_defense(orig_prompt)\n", " prompt = transformed if transformed is not None else orig_prompt\n", " sys_prompt = None\n", " except Exception:\n", " prompt = orig_prompt\n", " sys_prompt = None\n", "\n", " if hasattr(tok, \"apply_chat_template\"):\n", " msgs = [{\"role\": \"user\", \"content\": prompt}]\n", " text = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)\n", " else:\n", " text = (\n", " \"<|begin_of_text|>\"\n", " \"<|start_header_id|>user<|end_header_id|>\\n\"\n", " f\"{prompt}\\n<|eot_id|>\"\n", " \"<|start_header_id|>assistant<|end_header_id|>\\n\"\n", " )\n", "\n", " enc = tok(text, return_tensors=\"pt\", truncation=True, max_length=max_input_tokens).to(hf.device)\n", "\n", " gen_kwargs = dict(\n", " max_new_tokens=max_new_tokens,\n", " do_sample=False if (temperature is None or temperature == 0.0) else True,\n", " temperature=None if (temperature is None or temperature == 0.0) else float(temperature),\n", " top_p=top_p,\n", " use_cache=USE_CACHE,\n", " )\n", " if eos_id is not None:\n", " gen_kwargs[\"eos_token_id\"] = eos_id\n", " gen_kwargs[\"pad_token_id\"] = eos_id\n", "\n", " with torch.inference_mode():\n", " if steerer is None:\n", " out_ids = hf.generate(**enc, **gen_kwargs)\n", " else:\n", " with steerer.steering_context(prompt_for_alpha=orig_prompt):\n", " out_ids = hf.generate(**enc, **gen_kwargs)\n", "\n", " gen_only = out_ids[0][enc[\"input_ids\"].shape[1]:]\n", " text_out = tok.decode(gen_only, skip_special_tokens=True).strip()\n", "\n", " rows.append({\n", " \"id\": int(row.id),\n", " \"prompt\": orig_prompt,\n", " \"label\": row.label,\n", " \"model_output\": text_out\n", " })\n", "\n", " if (clear_cuda_each is not None) and (i % clear_cuda_each == 0):\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " torch.cuda.ipc_collect()\n", " import gc as _gc; _gc.collect()\n", "\n", " return pd.DataFrame(rows)" ] }, { "cell_type": "markdown", "id": "07c90dee-645d-4fb6-a485-12500ea3f660", "metadata": {}, "source": [ "## Aligned Model" ] }, { "cell_type": "code", "execution_count": null, "id": "059dac08-2621-4551-9d6c-2c72b2309138", "metadata": {}, "outputs": [], "source": [ "print(\"[INFO] Loading aligned model...\")\n", "aligned = load_model(ALIGNED_KEY)\n", "\n", "print(\"[TEST] Aligned sample ->\", aligned.generate(\"Say 'ok'.\", max_new_tokens=5))" ] }, { "cell_type": "code", "execution_count": null, "id": "9e387a8e-3244-4719-b9c1-e1bb48bf3509", "metadata": {}, "outputs": [], "source": [ "print(\"[BASELINE] Evaluating ALIGNED (no defense, FULL outputs) ...\")\n", "df_gen_aligned_base = run_generation_df(\n", " aligned, unsafe_eval_df, steerer=None,\n", " max_input_tokens=MAX_INPUT_TOKENS,\n", " max_new_tokens=MAX_NEW_TOKENS, \n", " temperature=None, \n", " top_p=1.0\n", ")\n", "df_eval_aligned_base = evaluate_minimal(\n", " df_gen_aligned_base,\n", " use_llm_judge=True,\n", " ensemble_with_heuristic=True\n", ")\n", "\n", "asr_aligned_base = _asr(df_eval_aligned_base)" ] }, { "cell_type": "code", "execution_count": null, "id": 
"5bbee730-137a-4eb5-842d-755851b3710e", "metadata": {}, "outputs": [], "source": [ "pdef_aligned = build_prompt_defense(\n", " safe_prompts_seed, unsafe_prompts_seed,\n", " PromptDefenseConfig(\n", " emb_thresh=0.05,\n", " strict_threshold=0.15,\n", " extreme_threshold=0.30,\n", " paraphrase=True,\n", " max_edits=4,\n", " strip_urls=True,\n", " strip_injections=True,\n", " ))\n", "print(\"Aligned prompt defence init\")" ] }, { "cell_type": "code", "execution_count": null, "id": "ff877f13-b64c-46e4-ba3d-a97f56b14185", "metadata": {}, "outputs": [], "source": [ "print(\"[DEFENSE] Evaluating ALIGNED (prompt-based, FULL outputs) ...\")\n", "df_gen_aligned_def = run_generation_df(\n", " aligned, unsafe_eval_df,\n", " steerer=None,\n", " prompt_defense=pdef_aligned,\n", " max_input_tokens=MAX_INPUT_TOKENS,\n", " max_new_tokens=MAX_NEW_TOKENS,\n", " temperature=None,\n", " top_p=1.0\n", ")\n", "df_eval_aligned_def = evaluate_minimal(\n", " df_gen_aligned_def,\n", " use_llm_judge=True,\n", " ensemble_with_heuristic=True\n", ")\n", "asr_aligned_def = _asr(df_eval_aligned_def)" ] }, { "cell_type": "code", "execution_count": null, "id": "97d176e1-9e38-4cc5-b523-c14174a1a815", "metadata": {}, "outputs": [], "source": [ "# clean up the model\n", "print(\"[CLEANUP] Releasing ALIGNED model from memory...\")\n", "del aligned\n", "gc.collect()\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " torch.cuda.ipc_collect()" ] }, { "cell_type": "markdown", "id": "47632778-b65b-4a7a-a22a-f013857de0a8", "metadata": {}, "source": [ "## Unaligned Model" ] }, { "cell_type": "code", "execution_count": null, "id": "32864439-2e43-49b5-a271-5b696a35a040", "metadata": {}, "outputs": [], "source": [ "print(\"[INFO] Loading unaligned model...\")\n", "unaligned = load_model(UNALIGNED_KEY)\n", "print(\"[TEST] Unaligned sample ->\", unaligned.generate(\"Say 'ok'.\", max_new_tokens=5))" ] }, { "cell_type": "code", "execution_count": null, "id": "2251a9e9-2093-4aee-b419-25e667c166cb", "metadata": {}, "outputs": [], "source": [ "print(\"[BASELINE] Evaluating UNALIGNED (no defense, FULL outputs) ...\")\n", "df_gen_unaligned_base = run_generation_df(\n", " unaligned, unsafe_eval_df, steerer=None,\n", " max_input_tokens=MAX_INPUT_TOKENS,\n", " max_new_tokens=MAX_NEW_TOKENS,\n", " temperature=None,\n", " top_p=1.0\n", ")\n", "df_eval_unaligned_base = evaluate_minimal(\n", " df_gen_unaligned_base,\n", " use_llm_judge=True,\n", " ensemble_with_heuristic=True\n", ")\n", "\n", "asr_unaligned_base = _asr(df_eval_unaligned_base)" ] }, { "cell_type": "code", "execution_count": null, "id": "0483cf69-bffa-4380-9eb9-2320e1570cbe", "metadata": {}, "outputs": [], "source": [ "pdef_unaligned = build_prompt_defense(\n", " safe_prompts_seed, unsafe_prompts_seed,\n", " PromptDefenseConfig(\n", " emb_thresh=0.05,\n", " strict_threshold=0.15,\n", " extreme_threshold=0.30,\n", " paraphrase=True,\n", " max_edits=4,\n", " strip_urls=True,\n", " strip_injections=True,\n", " )\n", ")\n", "\n", "print(\"Unaligned prompt defence init\")" ] }, { "cell_type": "code", "execution_count": null, "id": "cf29ba37-b78d-43e2-b573-2f1a5d425d16", "metadata": {}, "outputs": [], "source": [ "print(\"[DEFENSE] Evaluating UNALIGNED (prompt-based, FULL outputs) ...\")\n", "df_gen_unaligned_def = run_generation_df(\n", " unaligned, unsafe_eval_df,\n", " steerer=None,\n", " prompt_defense=pdef_unaligned,\n", " max_input_tokens=MAX_INPUT_TOKENS,\n", " max_new_tokens=MAX_NEW_TOKENS,\n", " temperature=None,\n", " top_p=1.0\n", ")\n", "df_eval_unaligned_def = 
evaluate_minimal(\n", " df_gen_unaligned_def,\n", " use_llm_judge=True,\n", " ensemble_with_heuristic=True\n", ")\n", "asr_unaligned_def = _asr(df_eval_unaligned_def)" ] }, { "cell_type": "code", "execution_count": null, "id": "6177b6d4-0ee6-4ebd-8add-41079adfd9b3", "metadata": {}, "outputs": [], "source": [ "print(\"[CLEANUP] Releasing UNALIGNED model and steerer from memory...\")\n", "del unaligned\n", "gc.collect()\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " torch.cuda.ipc_collect()" ] }, { "cell_type": "markdown", "id": "3f3e6ce1-cf12-4843-9517-0b84be75520f", "metadata": {}, "source": [ "# Results" ] }, { "cell_type": "code", "execution_count": null, "id": "2e99f224-3059-46c9-8801-1c66782ba901", "metadata": {}, "outputs": [], "source": [ "print(f\"[RESULT] Baseline ASR — ALIGNED: {asr_aligned_base:.3f} | UNALIGNED: {asr_unaligned_base:.3f}\")\n", "\n", "OUTDIR.mkdir(parents=True, exist_ok=True)\n", "df_gen_aligned_base.to_csv(OUTDIR / \"gen_aligned_baseline.csv\", index=False)\n", "df_gen_unaligned_base.to_csv(OUTDIR / \"gen_unaligned_baseline.csv\", index=False)\n", "df_eval_aligned_base.to_csv(OUTDIR / \"eval_aligned_baseline.csv\", index=False)\n", "df_eval_unaligned_base.to_csv(OUTDIR / \"eval_unaligned_baseline.csv\", index=False)\n", "\n", "print(f\"[RESULT] With Defense ASR — ALIGNED: {asr_aligned_def:.3f} | UNALIGNED: {asr_unaligned_def:.3f}\")\n", "\n", "OUTDIR.mkdir(parents=True, exist_ok=True)\n", "df_gen_aligned_def.to_csv(OUTDIR / \"gen_aligned_prompt.csv\", index=False)\n", "df_gen_unaligned_def.to_csv(OUTDIR / \"gen_unaligned_prompt.csv\", index=False)\n", "df_eval_aligned_def.to_csv(OUTDIR / \"eval_aligned_prompt.csv\", index=False)\n", "df_eval_unaligned_def.to_csv(OUTDIR / \"eval_unaligned_prompt.csv\", index=False)\n", "\n", "summary = {\n", " \"baseline\": {\"aligned\": asr_aligned_base, \"unaligned\": asr_unaligned_base},\n", " \"defense\": {\"aligned\": asr_aligned_def, \"unaligned\": asr_unaligned_def},\n", "}\n", "with open(OUTDIR / \"summary.json\", \"w\") as f:\n", " json.dump(summary, f, indent=2)\n", "print(\"\\n[SUMMARY]\", json.dumps(summary, indent=2))" ] }, { "cell_type": "code", "execution_count": null, "id": "66d21350-1ec1-4f19-80bb-c2aa7c5d83a4", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10, 4))\n", "y_a = df_eval_aligned_base['jailbreak_success'].astype(int).values\n", "y_u = df_eval_unaligned_base['jailbreak_success'].astype(int).values\n", "x = np.arange(len(y_a))\n", "\n", "plt.plot(x, y_a, label=\"Aligned (no defense)\")\n", "plt.plot(x, y_u, label=\"Unaligned (no defense)\")\n", "plt.xlabel(\"Attempt index\")\n", "plt.ylabel(\"Success (0/1)\")\n", "plt.title(\"Jailbreak Attempts vs Success — Baseline\")\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "00b4072a-cc01-419d-a89b-cfddfd45ec14", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(10, 4))\n", "y_a = df_eval_aligned_def['jailbreak_success'].astype(int).values\n", "y_u = df_eval_unaligned_def['jailbreak_success'].astype(int).values\n", "x = np.arange(len(y_a))\n", "\n", "plt.plot(x, y_a, label=\"Aligned (defense)\")\n", "plt.plot(x, y_u, label=\"Unaligned (defense)\")\n", "plt.xlabel(\"Attempt index\")\n", "plt.ylabel(\"Success (0/1)\")\n", "plt.title(\"Jailbreak Attempts vs Success — defense\")\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "7986b2a6-a0af-4301-9b5e-773ce3493dce", 
"metadata": {}, "outputs": [], "source": [ "labels = [\"Aligned\", \"Unaligned\"]\n", "baseline = [asr_aligned_base, asr_unaligned_base]\n", "defense = [asr_aligned_def, asr_unaligned_def]\n", "\n", "plt.figure(figsize=(6,4))\n", "x = np.arange(len(labels))\n", "width = 0.35\n", "plt.bar(x - width/2, baseline, width, label='Baseline')\n", "plt.bar(x + width/2, defense, width, label='With Prompt Defence')\n", "plt.xticks(x, labels)\n", "plt.ylabel('ASR')\n", "plt.title('Attack Success Rate (lower is better)')\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "af7dfa1e-3bf9-4524-bc60-033247a67948", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 5 }