mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
104 lines
3.2 KiB
Python
104 lines
3.2 KiB
Python
import os
|
|
|
|
import httpx
|
|
from loguru import logger
|
|
from pydantic import BaseModel
|
|
|
|
from langalf.probe_actor.refusal import refusal_heuristic
|
|
from langalf.probe_data.data import prepare_prompts
|
|
|
|
# True only when the environment variable IS_VERCEL is exactly "t"; unset or
# any other value yields False. perform_scan uses this to bail out early.
IS_VERCEL = os.environ.get("IS_VERCEL", "f") == "t"
|
|
|
|
|
|
class ScanResult(BaseModel):
    """Record describing scan progress for one dataset module, or a status message.

    ``perform_scan`` yields instances serialised with ``.json()``.
    """

    # Dataset name for progress records; the message text for status records.
    module: str
    # perform_scan reports thousands of whitespace-split tokens (rounded to 0.1).
    tokens: float
    # Estimated cost; perform_scan computes it as 1.5 per million naive tokens.
    cost: float
    # Percentage of all prompts processed so far.
    progress: float
    # Percentage of the module's prompts counted as failures.
    failureRate: float = 0.0
    # True for status-only messages built by ``status_msg``.
    status: bool = False

    @classmethod
    def status_msg(cls, msg: str):
        """Return the JSON string of a status-only record whose ``module`` is *msg*.

        All numeric fields are zero and ``status`` is True.
        """
        zeroed = dict(tokens=0, cost=0, progress=0, failureRate=0)
        return cls(module=msg, status=True, **zeroed).json()
|
|
|
|
|
|
def _failure_row(dataset_name: str, prompt: str, status_code, content) -> tuple:
    """Build one 4-column failures.csv row, flattening newlines in *prompt* to ';'."""
    return (dataset_name, prompt.replace("\n", ";"), status_code, content)


async def perform_scan(
    request_factory, max_budget: int, datasets: "list[dict] | None" = None
):
    """Send every selected dataset's prompts to the target and stream progress.

    Yields JSON strings built from ``ScanResult``:
    - status messages,
    - one progress record per successfully sent prompt.
    When the scan completes, all failure rows are written to ``failures.csv``
    in the current working directory *before* the final ``"Done."`` status is
    yielded. A consumer that stops after ``"Done."`` therefore still gets the file.

    Args:
        request_factory: object whose ``fn(prompt=...)`` coroutine sends one
            prompt and returns a response exposing ``status_code``, ``text``
            and ``json()`` (an httpx response, judging by the exception caught).
        max_budget: forwarded to ``prepare_prompts`` as ``budget``.
        datasets: dicts with at least ``"dataset_name"`` and ``"selected"``
            keys; only selected ones are scanned. ``None`` means no datasets.
    """
    # A None default avoids sharing one mutable list across calls.
    if datasets is None:
        datasets = []

    yield ScanResult.status_msg("Loading datasets...")
    if IS_VERCEL:
        yield ScanResult.status_msg(
            "Vercel deployment detected. Streaming messages are not supported by serverless, plz run it locally."
        )
        return

    prompt_modules = prepare_prompts(
        dataset_names=[m["dataset_name"] for m in datasets if m["selected"]],
        budget=max_budget,
    )
    yield ScanResult.status_msg("Datasets loaded. Starting scan...")

    errors = []
    # NOTE: despite its name, this collects responses that refusal_heuristic
    # did NOT classify as refusals; they count as failures below.
    refusals = []
    size = sum(len(m.prompts) for m in prompt_modules)
    step = 0

    for module in prompt_modules:
        tokens = 0
        module_failures = 0
        logger.info(f"Scanning {module.dataset_name} {len(module.prompts)}")

        for prompt in module.prompts:
            step += 1
            # size > 0 whenever this loop body runs, so no division by zero.
            progress = 100 * step / size

            # Naive token count
            tokens += len(prompt.split())
            try:
                r = await request_factory.fn(prompt=prompt)
            except httpx.RequestError as e:
                logger.error(f"Request error: {e}")
                # Keep the same 4-column shape as every other row; a 3-tuple
                # here made the CSV export fail when all requests errored.
                errors.append(_failure_row(module.dataset_name, prompt, None, str(e)))
                module_failures += 1
                continue

            if r.status_code >= 400:
                module_failures += 1
                errors.append(
                    _failure_row(module.dataset_name, prompt, r.status_code, r.text)
                )
            else:
                try:
                    body = r.json()
                except ValueError as e:
                    # A non-JSON body cannot be judged by the heuristic; record
                    # it rather than aborting the whole scan.
                    logger.error(f"Non-JSON response: {e}")
                    module_failures += 1
                    errors.append(
                        _failure_row(module.dataset_name, prompt, r.status_code, r.text)
                    )
                else:
                    if not refusal_heuristic(body):
                        refusals.append(
                            _failure_row(
                                module.dataset_name, prompt, r.status_code, r.text
                            )
                        )
                        module_failures += 1

            # Naive token count for llm response
            tokens += len(r.text.split())
            yield ScanResult(
                module=module.dataset_name,
                tokens=round(tokens / 1000, 1),
                cost=round(tokens * 1.5 / 1000_000, 2),
                progress=round(progress, 2),
                failureRate=100 * module_failures / max(len(module.prompts), 1),
            ).json()

    # pandas is imported lazily; it is only needed for this final export.
    import pandas as pd

    df = pd.DataFrame(
        errors + refusals, columns=["module", "prompt", "status_code", "content"]
    )
    df.to_csv("failures.csv", index=False)
    # TODO: save all results

    yield ScanResult.status_msg("Done.")
|