Files
dongdongunique 75eb41f50b Add context length protection in access_runcontext_history
- Add loop mechanism to reduce history_depth if output exceeds max_total_chars
- Prevents context overflow when accessing run context history
- Prints reduction info and returns depth_reduced field in result
2025-12-10 03:23:53 +08:00

286 lines
12 KiB
Python

from ..base_model import BaseModel
from ...core.registry import model_registry
import openai
import time
import base64
from io import BytesIO
from typing import Any, Optional, List, Dict, Union
from PIL import Image
@model_registry.register("openai")
class OpenAIModel(BaseModel):
    """
    OpenAI API model wrapper for jailbreak toolbox.

    Supports various OpenAI models like gpt-3.5-turbo and gpt-4. The wrapper
    keeps an optional multi-turn conversation history, can attach a single
    image to a user message, and exposes an embedding helper.
    """

    # Single source of truth for the default system prompt; used both as the
    # constructor default and by reset_system_message().
    DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

    # Pillow image modes that can be saved as JPEG without conversion.
    # Anything else (RGBA, P, LA, ...) is converted to RGB before encoding.
    _JPEG_WRITABLE_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self,
                 api_key: str,
                 base_url: Optional[str] = None,
                 model_name: str = "gpt-3.5-turbo",
                 temperature: float = 0.7,
                 max_tokens: Optional[int] = None,
                 retry_attempts: int = 3,
                 retry_delay: float = 2.0,
                 system_message: str = DEFAULT_SYSTEM_MESSAGE,
                 embedding_model: str = "text-embedding-3-small",
                 **kwargs):
        """
        Initialize the OpenAI model.

        Args:
            api_key: OpenAI API key
            base_url: Optional base URL for OpenAI API
            model_name: Name of the OpenAI model to use (e.g., "gpt-3.5-turbo", "gpt-4")
            temperature: Sampling temperature (0.0-2.0); None omits it from requests
            max_tokens: Maximum tokens in the response; None omits it from requests
            retry_attempts: Number of attempts made for each chat completion call
            retry_delay: Delay between retry attempts in seconds. Stored, but
                query() does not sleep between attempts (see the note there).
            system_message: System message to set the assistant's behavior
            embedding_model: Model to use for generating embeddings
            **kwargs: Forwarded to the base class; recognised keys are also
                routed to the OpenAI client, chat completion and embedding calls.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.retry_attempts = retry_attempts
        self.retry_delay = retry_delay
        self.system_message = system_message
        self.embedding_model = embedding_model

        # Initialize conversation history
        self.conversation_history = [{"role": "system", "content": system_message}]

        # Filter kwargs for different OpenAI API functions
        self.client_kwargs = self._filter_client_kwargs(kwargs)
        self.chat_kwargs = self._filter_chat_kwargs(kwargs)
        self.embedding_kwargs = self._filter_embedding_kwargs(kwargs)

        # Configure OpenAI client. The module-level settings are kept for
        # backward compatibility with code that reads them from `openai`.
        openai.api_key = api_key
        if base_url:
            openai.base_url = base_url
        self.client = openai.OpenAI(api_key=api_key, base_url=base_url, **self.client_kwargs)
        print(f"Initialized OpenAI model: {model_name}")

    def _filter_client_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Filter kwargs suitable for OpenAI client initialization"""
        # Valid parameters for OpenAI() client
        valid_client_params = {
            'timeout', 'max_retries', 'default_headers', 'default_query',
            'http_client', 'api_key', 'base_url', 'organization', 'project'
        }
        return {k: v for k, v in kwargs.items() if k in valid_client_params}

    def _filter_chat_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Filter kwargs suitable for chat.completions.create()"""
        # Valid parameters for chat completions
        valid_chat_params = {
            'frequency_penalty', 'logit_bias', 'logprobs', 'top_logprobs',
            'max_tokens', 'n', 'presence_penalty', 'response_format',
            'seed', 'stop', 'stream', 'temperature', 'top_p', 'tools', 'tool_choice',
            'parallel_tool_calls', 'user'
        }
        return {k: v for k, v in kwargs.items() if k in valid_chat_params}

    def _filter_embedding_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Filter kwargs suitable for embeddings.create()"""
        # Valid parameters for embeddings
        valid_embedding_params = {
            'encoding_format', 'dimensions', 'user'
        }
        return {k: v for k, v in kwargs.items() if k in valid_embedding_params}

    def _encode_image_to_base64(self, image_input: Union[str, Any]) -> str:
        """Encode an image to base64 string. Supports both file paths and PIL Image objects.

        Args:
            image_input: Path to image file (str or os.PathLike) or PIL Image object

        Returns:
            Base64 encoded string of the image. File contents are encoded
            as-is; PIL images are serialised as JPEG first.

        Raises:
            ValueError: If image_input is neither a path nor a PIL Image.
            OSError: If the file cannot be read or the image cannot be saved.
        """
        import os

        # File path: encode the raw bytes without re-encoding the image
        if isinstance(image_input, (str, os.PathLike)):
            with open(image_input, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')

        # PIL Image object
        if isinstance(image_input, Image.Image):
            # JPEG has no alpha channel or palette; saving e.g. an RGBA or P
            # image directly raises, so convert such images to RGB first.
            if image_input.mode not in self._JPEG_WRITABLE_MODES:
                image_input = image_input.convert("RGB")
            buffered = BytesIO()
            image_input.save(buffered, format="JPEG")
            return base64.b64encode(buffered.getvalue()).decode('utf-8')

        raise ValueError("image_input must be either a file path (string) or a PIL Image object")

    def _prepare_messages(self,
                          text_input: Union[str, List[Dict], None],
                          image_input: Any,
                          maintain_history: bool) -> List[Dict]:
        """Build the message list for one chat completion request.

        Args:
            text_input: Prompt text, or a ready-made list of message dicts
            image_input: Optional image (path or PIL Image) to attach
            maintain_history: Whether the user turn is recorded in the history

        Returns:
            The list of messages to send. When maintain_history is True this
            is the live conversation history with the new user turn appended.
        """
        # A ready-made message list is sent as-is; image_input is not
        # attached to it and the history is not consulted.
        if isinstance(text_input, list):
            return text_input

        user_content: Any = None
        if image_input is not None:
            try:
                image_base64 = self._encode_image_to_base64(image_input)
            except Exception as e:
                # Best effort: fall back to a text-only request
                print(f"Warning: Failed to process image input: {str(e)}")
            else:
                user_content = [
                    {
                        "type": "text",
                        "text": "" if text_input is None else str(text_input)
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_base64}"
                        },
                    },
                ]

        if user_content is None:
            if text_input is None:
                # Nothing to send; the empty list makes the API call fail and
                # query() reports that as an error string.
                return []
            user_content = str(text_input)

        if maintain_history:
            # Record the user turn (text or text+image) so the request
            # carries the system prompt and all previous turns.
            self.add_user_message(user_content)
            return self.conversation_history

        # For single-turn interactions without affecting history
        return [{"role": "system", "content": self.system_message},
                {"role": "user", "content": user_content}]

    def query(self, text_input: Union[str, List[Dict]] = "", image_input: Any = None, maintain_history: bool = False) -> str:
        """
        Send a query to the OpenAI API and return the response.

        Args:
            text_input: The prompt text to send (can be string or list of message dicts).
                A list is sent unchanged and image_input is then not attached.
            image_input: Path to image file or PIL Image object for vision models
            maintain_history: Whether to add this exchange to conversation history

        Returns:
            The model's response as a string. If every attempt fails, a string
            starting with "Error: " is returned instead of raising.
        """
        messages = self._prepare_messages(text_input, image_input, maintain_history)

        # Request parameters do not change between attempts, so build them once.
        chat_params = {
            'model': self.model_name,
            'messages': messages,
            **self.chat_kwargs
        }
        # Explicit constructor settings take precedence over kwargs
        if self.temperature is not None:
            chat_params['temperature'] = self.temperature
        if self.max_tokens is not None:
            chat_params['max_tokens'] = self.max_tokens

        for attempt in range(self.retry_attempts):
            try:
                response = self.client.chat.completions.create(**chat_params)
                response_text = response.choices[0].message.content
                # Add assistant response to history if maintaining history
                if maintain_history:
                    self.add_assistant_message(response_text)
                return response_text
            except Exception as e:
                print(f"Unexpected error: {str(e)}")
                if attempt == self.retry_attempts - 1:
                    return f"Error: {str(e)}"
                # Note: time.sleep removed to prevent blocking in asyncio environments
                # For retry delays, the calling async code should handle timing

        # Only reached when retry_attempts is zero or negative
        return "Error: Failed to get response from model"

    def add_user_message(self, content: Union[str, List[Dict]]) -> None:
        """Add a user message to the conversation history."""
        self.conversation_history.append({"role": "user", "content": content})

    def add_assistant_message(self, content: Union[str, List[Dict]]) -> None:
        """Add an assistant message to the conversation history."""
        self.conversation_history.append({"role": "assistant", "content": content})

    def add_system_message(self, content: str) -> None:
        """Add a system message to the conversation history."""
        self.conversation_history.append({"role": "system", "content": content})

    def remove_last_turn(self) -> None:
        """
        Remove the last turn of conversation (last user message and its corresponding assistant message).

        Everything from the most recent user message onwards is dropped.
        The history is left unchanged if it contains no user message.
        Assumes conversation is stored sequentially as messages.
        """
        if not self.conversation_history:
            return
        # Scan backwards for the most recent user message
        for idx in range(len(self.conversation_history) - 1, -1, -1):
            if self.conversation_history[idx]["role"] == "user":
                self.conversation_history = self.conversation_history[:idx]
                break

    def reset_conversation(self) -> None:
        """Reset the conversation history to only include the initial system message."""
        self.conversation_history = [{"role": "system", "content": self.system_message}]

    def get_conversation_history(self) -> List[Dict[str, Any]]:
        """Get the current conversation history (the live list, not a copy)."""
        return self.conversation_history

    def set_system_message(self, system_message: str) -> None:
        """Set a new system message and reset the conversation."""
        self.system_message = system_message
        self.reset_conversation()

    def reset_system_message(self) -> None:
        """Reset the system message to the default and reset the conversation."""
        self.system_message = self.DEFAULT_SYSTEM_MESSAGE
        self.reset_conversation()

    def get_embedding(self, text_input: str, model: Optional[str] = None) -> Any:
        """
        Get embedding vector for the given text using OpenAI embedding model.

        Args:
            text_input: The input text to embed
            model: The embedding model to use (default: uses self.embedding_model)

        Returns:
            A torch.Tensor holding the embedding vector on success, or an
            empty list if anything fails (including torch being unavailable).
        """
        try:
            # Imported lazily so the rest of the class works without torch
            import torch

            clean_text = text_input.replace("\n", " ")
            # Use specified model or default embedding model
            embedding_model = model or self.embedding_model
            # Prepare embedding parameters
            embedding_params = {
                'input': clean_text,
                'model': embedding_model,
                **self.embedding_kwargs
            }
            response = self.client.embeddings.create(**embedding_params)
            embedding = response.data[0].embedding
            return torch.tensor(embedding)
        except Exception as e:
            print(f"Error while generating embedding: {str(e)}")
            return []