diff --git a/agentic_security/fuzz_chain/__init__.py b/agentic_security/fuzz_chain/__init__.py new file mode 100644 index 0000000..c516b11 --- /dev/null +++ b/agentic_security/fuzz_chain/__init__.py @@ -0,0 +1,13 @@ +from agentic_security.fuzz_chain.chain import ( + FuzzChain, + FuzzNode, + FuzzRunnable, +) +from agentic_security.fuzz_chain.provider import LLMProvider + +__all__ = [ + "FuzzChain", + "FuzzNode", + "FuzzRunnable", + "LLMProvider", +] diff --git a/agentic_security/fuzz_chain/chain.py b/agentic_security/fuzz_chain/chain.py new file mode 100644 index 0000000..caadd81 --- /dev/null +++ b/agentic_security/fuzz_chain/chain.py @@ -0,0 +1,78 @@ +from __future__ import annotations +import logging +from typing import Any, Protocol + +logger = logging.getLogger(__name__) + + +class FuzzRunnable(Protocol): + """Protocol for objects that can be run in a fuzzing chain.""" + + async def run(self, **kwargs: Any) -> str: + ... + + +class FuzzNode: + """A single node in a fuzzing chain that executes an LLM call with template variables.""" + + def __init__(self, llm: Any, prompt: str) -> None: + self._llm = llm + self._prompt = prompt + + async def run(self, **kwargs: Any) -> str: + full_prompt = self._render_prompt(kwargs) + response = await self._llm.generate(full_prompt) + return response if response else "" + + def _render_prompt(self, kwargs: dict[str, Any]) -> str: + if not kwargs: + return self._prompt + result = self._prompt + for key, value in kwargs.items(): + result = result.replace(f"{{{key}}}", str(value)) + return result + + def __or__(self, other: Any) -> FuzzChain: + if isinstance(other, FuzzChain): + return FuzzChain([self, *other._nodes]) + if isinstance(other, FuzzNode): + return FuzzChain([self, other]) + # Assume LLMProvider-like object + return FuzzChain([self, FuzzNode(other, "{input}")]) + + def __repr__(self) -> str: + return f"FuzzNode(prompt={self._prompt!r})" + + +class FuzzChain: + """A chain of FuzzNodes that execute sequentially, passing output as input.""" + + def __init__(self, nodes: list[FuzzNode] | None = None) -> None: + self._nodes: list[FuzzNode] = [] + if nodes: + self._nodes.extend(nodes) + + async def run(self, **kwargs: Any) -> str: + if not self._nodes: + return "" + result = "" + for i, node in enumerate(self._nodes): + logger.debug(f"Running node {i}: {node} with kwargs {kwargs}") + result = await node.run(**kwargs) + logger.debug(f"Node {i} result: {result[:100]}...") + kwargs = {"input": result} + return result + + def __or__(self, other: Any) -> FuzzChain: + if isinstance(other, FuzzChain): + return FuzzChain([*self._nodes, *other._nodes]) + if isinstance(other, FuzzNode): + return FuzzChain([*self._nodes, other]) + # Assume LLMProvider-like object + return FuzzChain([*self._nodes, FuzzNode(other, "{input}")]) + + def __len__(self) -> int: + return len(self._nodes) + + def __repr__(self) -> str: + return f"FuzzChain({self._nodes!r})" diff --git a/agentic_security/fuzz_chain/provider.py b/agentic_security/fuzz_chain/provider.py new file mode 100644 index 0000000..ff54dae --- /dev/null +++ b/agentic_security/fuzz_chain/provider.py @@ -0,0 +1,9 @@ +from typing import Protocol, Any + + +class LLMProvider(Protocol): + """Protocol for LLM providers that can be used in FuzzChain.""" + + async def generate(self, prompt: str, **kwargs: Any) -> str: + """Generate response from LLM. Returns the response text.""" + ... diff --git a/tests/unit/fuzz_chain/__init__.py b/tests/unit/fuzz_chain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/fuzz_chain/test_chain.py b/tests/unit/fuzz_chain/test_chain.py new file mode 100644 index 0000000..3cbe457 --- /dev/null +++ b/tests/unit/fuzz_chain/test_chain.py @@ -0,0 +1,223 @@ +import pytest +from inline_snapshot import snapshot +from typing import Any + +from agentic_security.fuzz_chain import FuzzNode, FuzzChain, LLMProvider + + +class MockLLMProvider: + """Mock LLM provider for testing.""" + + def __init__(self, responses: list[str] | str = "mock response"): + self._responses = responses if isinstance(responses, list) else [responses] + self._call_count = 0 + self.prompts: list[str] = [] + + async def generate(self, prompt: str, **kwargs: Any) -> str: + self.prompts.append(prompt) + response = self._responses[min(self._call_count, len(self._responses) - 1)] + self._call_count += 1 + return response + + +class TestFuzzNode: + @pytest.mark.asyncio + async def test_simple_prompt(self): + llm = MockLLMProvider("test response") + node = FuzzNode(llm, "Hello world") + result = await node.run() + assert result == "test response" + assert llm.prompts == ["Hello world"] + + @pytest.mark.asyncio + async def test_template_variable_substitution(self): + llm = MockLLMProvider("response") + node = FuzzNode(llm, "Hello {name}!") + result = await node.run(name="Alice") + assert result == "response" + assert llm.prompts == ["Hello Alice!"] + + @pytest.mark.asyncio + async def test_multiple_template_variables(self): + llm = MockLLMProvider("response") + node = FuzzNode(llm, "{greeting} {name}, welcome to {place}!") + await node.run(greeting="Hello", name="Bob", place="Wonderland") + assert llm.prompts == ["Hello Bob, welcome to Wonderland!"] + + @pytest.mark.asyncio + async def test_missing_variable_preserved(self): + llm = MockLLMProvider("response") + node = FuzzNode(llm, "Hello {name} and {other}!") + await node.run(name="Alice") + # Only replaces variables that are provided + assert llm.prompts == ["Hello Alice and {other}!"] + + @pytest.mark.asyncio + async def test_input_variable(self): + llm = MockLLMProvider("response") + node = FuzzNode(llm, "Process: {input}") + await node.run(input="some data") + assert llm.prompts == ["Process: some data"] + + @pytest.mark.asyncio + async def test_empty_response_handling(self): + llm = MockLLMProvider("") + node = FuzzNode(llm, "Test") + result = await node.run() + assert result == "" + + def test_repr(self): + llm = MockLLMProvider() + node = FuzzNode(llm, "Test prompt") + assert repr(node) == snapshot("FuzzNode(prompt='Test prompt')") + + def test_pipe_two_nodes(self): + llm = MockLLMProvider() + node1 = FuzzNode(llm, "First") + node2 = FuzzNode(llm, "Second") + chain = node1 | node2 + assert isinstance(chain, FuzzChain) + assert len(chain) == 2 + + def test_pipe_node_to_chain(self): + llm = MockLLMProvider() + node1 = FuzzNode(llm, "First") + node2 = FuzzNode(llm, "Second") + node3 = FuzzNode(llm, "Third") + chain = node1 | node2 + extended = node3 | chain + # node3 followed by chain nodes + assert len(extended) == 3 + + +class TestFuzzChain: + @pytest.mark.asyncio + async def test_empty_chain(self): + chain = FuzzChain() + result = await chain.run(input="test") + assert result == "" + + @pytest.mark.asyncio + async def test_single_node_chain(self): + llm = MockLLMProvider("output") + node = FuzzNode(llm, "Prompt: {input}") + chain = FuzzChain([node]) + result = await chain.run(input="test data") + assert result == "output" + assert llm.prompts == ["Prompt: test data"] + + @pytest.mark.asyncio + async def test_multi_node_chain_passes_output_as_input(self): + llm = MockLLMProvider(["step1 result", "step2 result", "final result"]) + node1 = FuzzNode(llm, "First: {input}") + node2 = FuzzNode(llm, "Second: {input}") + node3 = FuzzNode(llm, "Third: {input}") + chain = FuzzChain([node1, node2, node3]) + + result = await chain.run(input="initial") + assert result == "final result" + assert llm.prompts == snapshot([ + "First: initial", + "Second: step1 result", + "Third: step2 result", + ]) + + @pytest.mark.asyncio + async def test_chain_with_custom_variables(self): + llm = MockLLMProvider(["analyzed", "evaluated"]) + node1 = FuzzNode(llm, "Analyze {topic}: {input}") + node2 = FuzzNode(llm, "Evaluate: {input}") + chain = FuzzChain([node1, node2]) + + result = await chain.run(topic="security", input="test prompt") + assert result == "evaluated" + assert llm.prompts == snapshot([ + "Analyze security: test prompt", + "Evaluate: analyzed", + ]) + + def test_pipe_chain_to_node(self): + llm = MockLLMProvider() + node1 = FuzzNode(llm, "First") + node2 = FuzzNode(llm, "Second") + node3 = FuzzNode(llm, "Third") + chain = FuzzChain([node1, node2]) + extended = chain | node3 + assert len(extended) == 3 + + def test_pipe_chain_to_chain(self): + llm = MockLLMProvider() + node1 = FuzzNode(llm, "A") + node2 = FuzzNode(llm, "B") + node3 = FuzzNode(llm, "C") + node4 = FuzzNode(llm, "D") + chain1 = FuzzChain([node1, node2]) + chain2 = FuzzChain([node3, node4]) + combined = chain1 | chain2 + assert len(combined) == 4 + + def test_len(self): + llm = MockLLMProvider() + chain = FuzzChain([ + FuzzNode(llm, "A"), + FuzzNode(llm, "B"), + FuzzNode(llm, "C"), + ]) + assert len(chain) == 3 + + def test_repr(self): + llm = MockLLMProvider() + chain = FuzzChain([FuzzNode(llm, "Test")]) + repr_str = repr(chain) + assert "FuzzChain" in repr_str + assert "FuzzNode" in repr_str + + +class TestPipeOperatorChaining: + @pytest.mark.asyncio + async def test_triple_pipe_chain(self): + llm = MockLLMProvider(["a", "b", "c"]) + node1 = FuzzNode(llm, "Step1: {input}") + node2 = FuzzNode(llm, "Step2: {input}") + node3 = FuzzNode(llm, "Step3: {input}") + + chain = node1 | node2 | node3 + result = await chain.run(input="start") + + assert result == "c" + assert llm.prompts == snapshot([ + "Step1: start", + "Step2: a", + "Step3: b", + ]) + + @pytest.mark.asyncio + async def test_chain_with_different_providers(self): + llm1 = MockLLMProvider("from llm1") + llm2 = MockLLMProvider("from llm2") + + node1 = FuzzNode(llm1, "Provider1: {input}") + node2 = FuzzNode(llm2, "Provider2: {input}") + + chain = node1 | node2 + result = await chain.run(input="test") + + assert result == "from llm2" + assert llm1.prompts == ["Provider1: test"] + assert llm2.prompts == ["Provider2: from llm1"] + + +class TestProtocolCompliance: + def test_llm_provider_protocol_mock(self): + provider = MockLLMProvider() + # Should have generate method that accepts prompt and kwargs + assert hasattr(provider, "generate") + + def test_fuzz_node_has_run(self): + llm = MockLLMProvider() + node = FuzzNode(llm, "Test") + assert hasattr(node, "run") + + def test_fuzz_chain_has_run(self): + chain = FuzzChain() + assert hasattr(chain, "run")