Files
agentic_security/langalf/probe_actor/fuzzer.py
T
Alexander Myasoedov 15106c1c23 feat(Init):
2024-04-13 18:04:24 +03:00

104 lines
3.2 KiB
Python

import os
import httpx
from loguru import logger
from pydantic import BaseModel
from langalf.probe_actor.refusal import refusal_heuristic
from langalf.probe_data.data import prepare_prompts
IS_VERCEL = os.getenv("IS_VERCEL", "f") == "t"
class ScanResult(BaseModel):
module: str
tokens: float
cost: float
progress: float
failureRate: float = 0.0
status: bool = False
@classmethod
def status_msg(cls, msg: str):
return cls(
module=msg,
tokens=0,
cost=0,
progress=0,
failureRate=0,
status=True,
).json()
async def perform_scan(request_factory, max_budget: int, datasets: list[dict] = []):
yield ScanResult.status_msg("Loading datasets...")
if IS_VERCEL:
yield ScanResult.status_msg(
"Vercel deployment detected. Streaming messages are not supported by serverless, plz run it locally."
)
return
prompt_modules = prepare_prompts(
dataset_names=[m["dataset_name"] for m in datasets if m["selected"]],
budget=max_budget,
)
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
errors = []
refusals = []
size = sum(len(m.prompts) for m in prompt_modules)
step = 0
for mi, module in enumerate(prompt_modules):
tokens = 0
module_failures = 0
logger.info(f"Scanning {module.dataset_name} {len(module.prompts)}")
for i, prompt in enumerate(module.prompts):
step += 1
progress = 100 * (step) / size
# Naive token count
tokens += len(prompt.split())
try:
r = await request_factory.fn(prompt=prompt)
except httpx.RequestError as e:
logger.error(f"Request error: {e}")
errors.append((module.dataset_name, prompt.replace("\n", ";"), e))
module_failures += 1
continue
if r.status_code >= 400:
module_failures += 1
errors.append(
(
module.dataset_name,
prompt.replace("\n", ";"),
r.status_code,
r.text,
)
)
elif not refusal_heuristic(r.json()):
refusals.append(
(
module.dataset_name,
prompt.replace("\n", ";"),
r.status_code,
r.text,
)
)
module_failures += 1
# Naive token count for llm response
tokens += len(r.text.split())
yield ScanResult(
module=module.dataset_name,
tokens=round(tokens / 1000, 1),
cost=round(tokens * 1.5 / 1000_000, 2),
progress=round(progress, 2),
failureRate=100 * module_failures / max(len(module.prompts), 1),
).json()
yield ScanResult.status_msg("Done.")
import pandas as pd
df = pd.DataFrame(
errors + refusals, columns=["module", "prompt", "status_code", "content"]
)
df.to_csv("failures.csv", index=False)
# TODO: save all results