diff --git a/agentic_security/http_spec.py b/agentic_security/http_spec.py index d59a398..7aa5645 100644 --- a/agentic_security/http_spec.py +++ b/agentic_security/http_spec.py @@ -69,9 +69,7 @@ class LLMSpec(BaseModel): return response - def validate( - self, prompt: str, encoded_image: str, encoded_audio: str, files: dict | None - ) -> None: + def validate(self, prompt: str, encoded_image: str, encoded_audio: str, files: dict | None) -> None: if self.has_files and not files: raise ValueError("Files are required for this request.") @@ -107,12 +105,17 @@ class LLMSpec(BaseModel): content = content.replace("<>", encoded_image) content = content.replace("<>", encoded_audio) + # Remove Content-Length from headers to avoid mismatch when + # placeholder replacement changes body size. httpx will set + # the correct Content-Length based on the actual content. + clean_headers = {k: v for k, v in self.headers.items() if k.lower() != "content-length"} + transport = httpx.AsyncHTTPTransport(retries=settings_var("network.retry", 3)) async with httpx.AsyncClient(transport=transport) as client: response = await client.request( method=self.method, url=self.url, - headers=self.headers, + headers=clean_headers, content=content, timeout=self.timeout(), ) @@ -127,9 +130,7 @@ class LLMSpec(BaseModel): return await self.probe( "test", # TODO: fix url for mp3 - encoded_audio=encode_audio_base64_by_url( - "https://www.example.com/audio.mp3" - ), + encoded_audio=encode_audio_base64_by_url("https://www.example.com/audio.mp3"), ) case LLMSpec(has_files=True): return await self._probe_with_files({}) @@ -168,18 +169,14 @@ def parse_http_spec(http_spec: str) -> LLMSpec: # Extract the method and URL from the first line request_line_parts = lines[0].split() if len(request_line_parts) < 2: - raise InvalidHTTPSpecError( - "First line of HTTP spec must include the method and URL." - ) + raise InvalidHTTPSpecError("First line of HTTP spec must include the method and URL.") method, url = request_line_parts[0], request_line_parts[1] # Check url validity valid_url = urlparse(url) # if missing the correct formatting ://, urlparse.netloc will be empty if valid_url.scheme not in ("http", "https") or not valid_url.netloc: - raise InvalidHTTPSpecError( - f"Invalid URL: {url}. Ensure it starts with 'http://' or 'https://'" - ) + raise InvalidHTTPSpecError(f"Invalid URL: {url}. Ensure it starts with 'http://' or 'https://'") # Initialize headers and body headers = {} diff --git a/agentic_security/probe_actor/fuzzer.py b/agentic_security/probe_actor/fuzzer.py index 7254fc8..5056efa 100644 --- a/agentic_security/probe_actor/fuzzer.py +++ b/agentic_security/probe_actor/fuzzer.py @@ -24,6 +24,9 @@ BUDGET_MULTIPLIER = settings_var("fuzzer.budget_multiplier", 100000000) INITIAL_OPTIMIZER_POINTS = settings_var("fuzzer.initial_optimizer_points", 25) MIN_FAILURE_SAMPLES = settings_var("fuzzer.min_failure_samples", 5) FAILURE_RATE_THRESHOLD = settings_var("fuzzer.failure_rate_threshold", 0.5) +FAILURES_CSV_PATH = settings_var("fuzzer.failures_csv_path", "failures.csv") +FULL_LOG_CSV_PATH = settings_var("fuzzer.full_log_csv_path", "full_scan_log.csv") +MAX_INJECTION_ATTEMPTS = settings_var("fuzzer.max_injection_attempts", 20) async def generate_prompts( @@ -111,9 +114,7 @@ async def process_prompt( if response.status_code >= 400: logger.error(f"HTTP {response.status_code} {response.content=}") - fuzzer_state.add_error( - module_name, prompt, response.status_code, response.text - ) + fuzzer_state.add_error(module_name, prompt, response.status_code, response.text) return tokens, True # Process successful response @@ -123,9 +124,7 @@ async def process_prompt( # Check if the response indicates a refusal refused = refusal_heuristic(response.json()) if refused: - fuzzer_state.add_refusal( - module_name, prompt, response.status_code, response_text - ) + fuzzer_state.add_refusal(module_name, prompt, response.status_code, response_text) fuzzer_state.add_output(module_name, prompt, response_text, refused) return tokens, refused @@ -169,10 +168,7 @@ async def process_prompt_batch( - Total number of tokens processed. - Number of failed prompts. """ - tasks = [ - process_prompt(request_factory, p, tokens, module_name, fuzzer_state) - for p in prompts - ] + tasks = [process_prompt(request_factory, p, tokens, module_name, fuzzer_state) for p in prompts] results = await asyncio.gather(*tasks) total_tokens = sum(r[0] for r in results) failures = sum(1 for r in results if r[1]) @@ -216,11 +212,7 @@ async def scan_module( # Initialize optimizer if optimization is enabled optimizer = ( - Optimizer( - [Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS - ) - if optimize - else None + Optimizer([Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS) if optimize else None ) module_size = 0 if module.lazy else len(module.prompts) @@ -422,8 +414,8 @@ async def perform_single_shot_scan( processed_prompts += module_size yield ScanResult.status_msg("Scan completed.") - fuzzer_state.export_failures("failures.csv") - fuzzer_state.export_full_log("full_scan_log.csv") + fuzzer_state.export_failures(FAILURES_CSV_PATH) + fuzzer_state.export_full_log(FULL_LOG_CSV_PATH) async def perform_many_shot_scan( @@ -515,7 +507,7 @@ async def perform_many_shot_scan( tokens += prompt_tokens injected = False - for _ in range(20): + for _ in range(MAX_INJECTION_ATTEMPTS): if injected: break @@ -552,14 +544,12 @@ async def perform_many_shot_scan( ).model_dump_json() if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES: - yield ScanResult.status_msg( - f"High failure rate detected ({failure_rate:.2%}). Stopping this module..." - ) + yield ScanResult.status_msg(f"High failure rate detected ({failure_rate:.2%}). Stopping this module...") break yield ScanResult.status_msg("Scan completed.") - fuzzer_state.export_failures("failures.csv") - fuzzer_state.export_full_log("full_scan_log.csv") + fuzzer_state.export_failures(FAILURES_CSV_PATH) + fuzzer_state.export_full_log(FULL_LOG_CSV_PATH) def scan_router(