Files
EvoSynth/jailbreak_toolbox/models/implementations/multithreaded_openai_model.py
T
dongdongunique f3af94df0b first commit
2025-12-10 00:54:02 +08:00

206 lines
7.9 KiB
Python

# jailbreak_toolbox/models/implementations/multithreaded_openai_model.py
from typing import Union, List, Dict, Any, Optional
import time
import base64
import os
import io
from PIL import Image
from ..multithreaded_model import MultiThreadedModel
from ...core.registry import model_registry
@model_registry.register("multithreaded_openai")
class MultiThreadedOpenAIModel(MultiThreadedModel):
"""
Multi-threaded implementation of OpenAI API client.
Supports concurrent API calls with rate limiting and retries.
"""
def __init__(
self,
api_key: str,
model_name: str = "gpt-4",
base_url: Optional[str] = None,
system_message: str = "You are a helpful assistant.",
temperature: float = 0.7,
max_tokens: int = 1024,
max_workers: int = 5,
requests_per_minute: int = 60,
retry_attempts: int = 3,
retry_delay: float = 1.0,
**kwargs
):
"""
Initialize the multi-threaded OpenAI model.
Args:
api_key: OpenAI API key
model_name: Model name/identifier
base_url: Optional base URL for API requests
system_message: System message for conversation
temperature: Temperature parameter for generation
max_tokens: Maximum tokens to generate
max_workers: Maximum number of worker threads
requests_per_minute: Maximum number of requests per minute
retry_attempts: Number of retry attempts on failure
retry_delay: Delay between retries in seconds
"""
super().__init__(
max_workers=max_workers,
requests_per_minute=requests_per_minute,
retry_attempts=retry_attempts,
retry_delay=retry_delay,
**kwargs
)
self.api_key = api_key
self.model_name = model_name
self.base_url = base_url
self.system_message = system_message
self.temperature = temperature
self.max_tokens = max_tokens
# Initialize OpenAI client
try:
from openai import OpenAI
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
self.conversation_history = [{"role": "system", "content": self.system_message}]
except ImportError:
raise ImportError("OpenAI package is required. Install it using: pip install openai")
def query(self, text_input: Union[str, List[Dict]] = "", image_input: Any = None, maintain_history: bool = False) -> str:
"""
Send a query to the OpenAI API and return the response.
Args:
text_input: The prompt text to send (can be string or list of message dicts)
image_input: Path to image file or PIL Image object for vision models
maintain_history: Whether to add this exchange to conversation history
Returns:
The model's response as a string
"""
messages = []
# Handle image input
if image_input is not None:
try:
image_base64 = self._encode_image_to_base64(image_input)
if isinstance(text_input, list):
# If text_input is already a list of messages, use it directly
messages = text_input
elif isinstance(text_input, str):
# Create a message with both text and image
messages = [{"role": "system", "content": self.system_message}] if not maintain_history else []
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": text_input
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
},
],
})
except Exception as e:
print(f"Warning: Failed to process image input: {str(e)}")
# Fall back to text-only input
image_input = None
# Handle text-only input or fallback from image processing failure
if image_input is None:
if isinstance(text_input, list):
# If text_input is already a list of messages, use it directly
messages = text_input
elif isinstance(text_input, str):
# Add user message to history if maintaining history
if maintain_history:
self.add_user_message(text_input)
messages = self.conversation_history
else:
# For single-turn interactions without affecting history
messages = [{"role": "system", "content": self.system_message},
{"role": "user", "content": text_input}]
# Make API call with retries built into MultiThreadedModel base class
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=self.temperature,
max_tokens=self.max_tokens
)
response_text = response.choices[0].message.content
# Add assistant response to history if maintaining history
if maintain_history:
self.add_assistant_message(response_text)
return response_text
def _encode_image_to_base64(self, image_input) -> str:
"""
Convert image input (file path or PIL Image) to base64 encoding.
Args:
image_input: Path to image file or PIL Image object
Returns:
Base64 encoded image string
"""
# If input is a string, treat it as a file path
if isinstance(image_input, str):
if not os.path.exists(image_input):
raise FileNotFoundError(f"Image file not found: {image_input}")
with open(image_input, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
# If input is a PIL Image
elif isinstance(image_input, Image.Image):
buffer = io.BytesIO()
image_input.save(buffer, format="JPEG")
return base64.b64encode(buffer.getvalue()).decode('utf-8')
# If input is bytes
elif isinstance(image_input, bytes):
return base64.b64encode(image_input).decode('utf-8')
else:
raise TypeError(f"Unsupported image input type: {type(image_input)}")
def add_user_message(self, message: Union[str, List]) -> None:
"""
Add a user message to the conversation history.
Args:
message: The message content (string or list of content dicts)
"""
self.conversation_history.append({"role": "user", "content": message})
def add_assistant_message(self, message: str) -> None:
"""
Add an assistant message to the conversation history.
Args:
message: The message content
"""
self.conversation_history.append({"role": "assistant", "content": message})
def get_conversation_history(self) -> List[Dict]:
"""
Get the current conversation history.
Returns:
List of message dictionaries
"""
return self.conversation_history
def reset_conversation(self) -> None:
"""Reset the conversation history to only include the system message."""
self.conversation_history = [{"role": "system", "content": self.system_message}]