feat(add Inspect AI):

This commit is contained in:
Alexander Myasoedov
2024-05-09 10:07:17 +03:00
parent 9c4828f259
commit cc5ea04205
8 changed files with 154 additions and 21 deletions
+1
View File
@@ -5,3 +5,4 @@ __pycache__/
failures.csv
runs/
*.todo
logs/
+48 -18
View File
@@ -1,14 +1,15 @@
import logging
import logging.config
import random
import sys
from asyncio import Event, Queue
from datetime import datetime
from pathlib import Path
from typing import Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException, Response
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware

from .http_spec import LLMSpec
from .probe_actor import fuzzer
@@ -16,15 +17,6 @@ from .probe_actor.refusal import REFUSAL_MARKS
from .probe_data import REGISTRY
from .report_chart import plot_security_report
# Drop the handler registered under id 0 (presumably loguru's default sink)
# before installing the colourised stderr sink below.
logger.remove(0)

# Layout of each log record: level, timestamp, call site, then the message.
_STDERR_LOG_FORMAT = "<green>[{level}]</green> <blue>{time:YYYY-MM-DD HH:mm:ss.SS}</blue> | <cyan>{module}:{function}:{line}</cyan> | <white>{message}</white>"

logger.add(sys.stderr, level="INFO", colorize=True, format=_STDERR_LOG_FORMAT)

# Create the FastAPI app instance
app = FastAPI()
origins = [
@@ -164,13 +156,13 @@ class Message(BaseModel):
class CompletionRequest(BaseModel):
    """Request body accepted by the OpenAI proxy completions endpoint.

    Only ``model`` and ``messages`` are required. Every sampling field falls
    back to the default declared here when the client leaves it out.
    """

    model: str
    messages: list[Message]
    temperature: float = 0.7
    top_p: float = 1.0
    n: int = 1
    # Declared Optional so that both an omitted field and an explicit null
    # are accepted; a bare ``list[str]`` annotation rejects null.
    stop: Optional[list[str]] = None
    max_tokens: int = 100
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
# OpenAI proxy endpoint
@@ -206,3 +198,41 @@ async def proxy_completions(request: CompletionRequest):
}
],
}
# Standard-library logging setup applied at import time.
# NOTE: ``logging.config`` is a submodule and is imported explicitly at the
# top of the file; ``import logging`` alone does not guarantee it is loaded.
LOGGING_CONFIG = {
    "version": 1,
    # Loggers created before this call and not listed below are disabled.
    "disable_existing_loggers": True,
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
        },
    },
    "root": {
        "handlers": ["console"],
        "level": "INFO",
    },
    "loggers": {
        "uvicorn.access": {
            # Raised to ERROR so per-request INFO lines are suppressed.
            "level": "ERROR",
            "handlers": ["console"],
            # Keep records from reaching the root handler a second time.
            "propagate": False,
        }
    },
}

logging.config.dictConfig(LOGGING_CONFIG)
class LogNon200ResponsesMiddleware(BaseHTTPMiddleware):
    """Middleware that reports every response whose status code is not 200.

    The comparison is against exactly 200, so other success codes
    (201, 204, ...) and redirects are logged as errors too.
    """

    async def dispatch(self, request: Request, call_next):
        """Forward the request downstream and log the outcome if it is not 200.

        The response is always returned unchanged.
        """
        response = await call_next(request)
        if response.status_code == 200:
            return response

        logger.error(
            f"{request.method} {request.url} - Status code: {response.status_code}"
        )
        return response
# Register LogNon200ResponsesMiddleware on the application so its dispatch()
# hook wraps request handling.
app.add_middleware(LogNon200ResponsesMiddleware)
+10
View File
@@ -141,6 +141,16 @@ REGISTRY = [
"url": "https://github.com/leondz/garak2",
"dynamic": True,
},
{
"dataset_name": "InspectAI",
"num_prompts": 0,
"tokens": 0,
"approx_cost": 0.0,
"source": "Github: https://github.com/UKGovernmentBEIS/inspect_ai",
"selected": False,
"url": "https://github.com/UKGovernmentBEIS/inspect_ai",
"dynamic": True,
},
{
"dataset_name": "Custom CSV",
"num_prompts": len(load_local_csv().prompts),
+10 -1
View File
@@ -7,7 +7,11 @@ import pandas as pd
from loguru import logger
from agentic_security.probe_data import stenography_fn
from agentic_security.probe_data.modules import adaptive_attacks, garak_tool
from agentic_security.probe_data.modules import (
adaptive_attacks,
garak_tool,
inspect_ai_tool,
)
# True only when the IS_VERCEL environment variable is exactly "t";
# an unset variable falls back to "f", which yields False.
IS_VERCEL = os.getenv("IS_VERCEL", "f") == "t"
@@ -206,6 +210,11 @@ def prepare_prompts(dataset_names, budget, tools_inbox=None):
garak_tool.Module(group, tools_inbox=tools_inbox).apply(),
lazy=True,
),
"InspectAI": lambda: dataset_from_iterator(
"InspectAI",
inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
lazy=True,
),
"GPT fuzzer": lambda: [],
}
@@ -9,7 +9,6 @@ from loguru import logger
class Module:
def __init__(self, prompt_groups: [], tools_inbox: asyncio.Queue):
self.tools_inbox = tools_inbox
if not self.is_garak_installed():
@@ -0,0 +1,13 @@
from inspect_ai import Task, eval, task
from inspect_ai.dataset import example_dataset
from inspect_ai.scorer import model_graded_fact
from inspect_ai.solver import chain_of_thought, generate, self_critique
@task
def theory_of_mind():
    """Build the Task that runs the "theory_of_mind" example dataset.

    The solver plan chains chain_of_thought(), generate() and self_critique();
    answers are scored with model_graded_fact().
    """
    samples = example_dataset("theory_of_mind")
    solver_plan = [chain_of_thought(), generate(), self_critique()]
    grader = model_graded_fact()
    return Task(dataset=samples, plan=solver_plan, scorer=grader)
@@ -0,0 +1,71 @@
import asyncio
import importlib.util
import os
from loguru import logger
# Location of the sibling inspect_ai_task.py, passed to `inspect eval`.
# It is built with os.path instead of string replace/strip so that it stays
# valid when this module does not live under the current working directory
# (the old code then produced an absolute path with its leading "/" removed).
_task_file = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "inspect_ai_task.py"
)
try:
    # Relative to the working directory, as before.
    inspect_ai_task = os.path.relpath(_task_file, os.getcwd())
except ValueError:
    # os.path.relpath cannot relate paths on different Windows drives.
    inspect_ai_task = _task_file
class Module:
    """Run the Inspect AI CLI against the local proxy and relay its prompts.

    ``apply`` starts ``inspect eval`` as a subprocess pointed at ``proxy_url``
    and yields the ``"message"`` of every entry found in ``tools_inbox``.
    Each entry is a mapping with at least ``"message"`` and ``"ready"`` keys;
    ``ready.set()`` is called once the consumer resumes the generator.
    NOTE(review): the inbox is presumably filled by the proxy endpoint and
    ``ready`` is presumably an ``asyncio.Event`` -- confirm against the caller.
    """

    name = "Inspect AI"
    # Base URL of the local OpenAI-compatible proxy the evaluation talks to.
    proxy_url = "http://0.0.0.0:8718/proxy"

    def __init__(self, prompt_groups: list, tools_inbox: asyncio.Queue):
        """Store the inbox and log an error if ``inspect_ai`` is not importable.

        ``prompt_groups`` is accepted for signature parity but is not used.
        """
        self.tools_inbox = tools_inbox
        if not self.is_tool_installed():
            logger.error(
                "inspect_ai module is not installed. Please install it using 'pip install inspect_ai'"
            )

    def is_tool_installed(self) -> bool:
        """Return True when the ``inspect_ai`` package can be located."""
        return importlib.util.find_spec("inspect_ai") is not None

    async def _log_stdout(self, stream) -> None:
        """Log each line of *stream* at INFO level as it becomes available."""
        async for line in stream:
            # Undecodable bytes are replaced rather than aborting the reader.
            logger.info(line.decode(errors="replace").strip())

    async def _proc(self, command):
        """Run *command* through the shell and mirror its output to the log.

        stdout is logged line by line at INFO level; stderr is logged once,
        at ERROR level, if it produced any output.
        """
        env = os.environ.copy()
        env["OPENAI_API_BASE"] = self.proxy_url
        # create_subprocess_shell already runs the command through the shell,
        # so no explicit shell=True is needed.
        process = await asyncio.create_subprocess_shell(
            command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=env,
        )
        logger.info(f"Started {command}")

        # Drain both pipes concurrently. Reading stdout to EOF before touching
        # stderr can deadlock once the child fills the stderr pipe buffer.
        _, err = await asyncio.gather(
            self._log_stdout(process.stdout),
            process.stderr.read(),
        )
        if err:
            logger.error(err.decode(errors="replace").strip())

        await process.wait()
        logger.info(f"Command {command} {process}finished.")

    async def apply(self):
        """Async generator yielding the messages found in ``tools_inbox``.

        The evaluation subprocess runs in a background task. After a short
        pause the inbox is drained until it is empty, then the subprocess
        task is awaited.
        """
        command = f"inspect eval {inspect_ai_task} --model openai/gpt-4 --model-base-url={self.proxy_url}"
        logger.info(f"Executing command: {command}")
        proc = asyncio.create_task(self._proc(command))

        # Presumably gives the subprocess time to start issuing requests.
        await asyncio.sleep(2)
        # Report the inbox state as it is now, not as it was before the pause.
        logger.info(f"Is inbox empty? {self.tools_inbox.empty()}")

        # NOTE(review): draining stops as soon as the inbox is momentarily
        # empty, even if the subprocess is still running -- confirm intended.
        while not self.tools_inbox.empty():
            ref = self.tools_inbox.get_nowait()
            message, ready = ref["message"], ref["ready"]
            yield message
            # Signal that the consumer has finished with this message.
            ready.set()

        logger.info(f"{self.name} tool finished.")
        await proc
+1 -1
View File
@@ -14,7 +14,7 @@ Content-Type: application/json
###
POST http://0.0.0.0:3008/v1/self-probe
POST http://0.0.0.0:8718/v1/self-probe
Authorization: Bearer XXXXX
Content-Type: application/json