from ..base_model import BaseModel from ...core.registry import model_registry import openai import time import base64 from io import BytesIO from typing import Any, Optional, List, Dict, Union from PIL import Image @model_registry.register("openai") class OpenAIModel(BaseModel): """ OpenAI API model wrapper for jailbreak toolbox. Supports various OpenAI models like gpt-3.5-turbo and gpt-4. """ def __init__(self, api_key: str, base_url: Optional[str] = None, model_name: str = "gpt-3.5-turbo", temperature: float = 0.7, max_tokens: int = None, retry_attempts: int = 3, retry_delay: float = 2.0, system_message: str = "You are a helpful assistant.", embedding_model: str = "text-embedding-3-small", **kwargs): """ Initialize the OpenAI model. Args: api_key: OpenAI API key base_url: Optional base URL for OpenAI API model_name: Name of the OpenAI model to use (e.g., "gpt-3.5-turbo", "gpt-4") temperature: Sampling temperature (0.0-2.0) max_tokens: Maximum tokens in the response retry_attempts: Number of retry attempts for API calls retry_delay: Delay between retry attempts in seconds system_message: System message to set the assistant's behavior embedding_model: Model to use for generating embeddings """ super().__init__(**kwargs) self.model_name = model_name self.temperature = temperature self.max_tokens = max_tokens self.retry_attempts = retry_attempts self.retry_delay = retry_delay self.system_message = system_message self.embedding_model = embedding_model # Initialize conversation history self.conversation_history = [{"role": "system", "content": system_message}] # Filter kwargs for different OpenAI API functions self.client_kwargs = self._filter_client_kwargs(kwargs) self.chat_kwargs = self._filter_chat_kwargs(kwargs) self.embedding_kwargs = self._filter_embedding_kwargs(kwargs) # Configure OpenAI client openai.api_key = api_key if base_url: openai.base_url = base_url self.client = openai.OpenAI(api_key=api_key, base_url=base_url, **self.client_kwargs) print(f"Initialized OpenAI model: {model_name}") def _filter_client_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Filter kwargs suitable for OpenAI client initialization""" # Valid parameters for OpenAI() client valid_client_params = { 'timeout', 'max_retries', 'default_headers', 'default_query', 'http_client', 'api_key', 'base_url', 'organization', 'project' } return {k: v for k, v in kwargs.items() if k in valid_client_params} def _filter_chat_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Filter kwargs suitable for chat.completions.create()""" # Valid parameters for chat completions valid_chat_params = { 'frequency_penalty', 'logit_bias', 'logprobs', 'top_logprobs', 'max_tokens', 'n', 'presence_penalty', 'response_format', 'seed', 'stop', 'stream', 'temperature', 'top_p', 'tools', 'tool_choice', 'parallel_tool_calls', 'user' } return {k: v for k, v in kwargs.items() if k in valid_chat_params} def _filter_embedding_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Filter kwargs suitable for embeddings.create()""" # Valid parameters for embeddings valid_embedding_params = { 'encoding_format', 'dimensions', 'user' } return {k: v for k, v in kwargs.items() if k in valid_embedding_params} def _encode_image_to_base64(self, image_input: Union[str, Any]) -> str: """Encode an image to base64 string. Supports both file paths and PIL Image objects. Args: image_input: Path to image file or PIL Image object Returns: Base64 encoded string of the image """ # Check if it's a file path (string) if isinstance(image_input, str): with open(image_input, "rb") as image_file: return base64.b64encode(image_file.read()).decode('utf-8') # Check if it's a PIL Image object if isinstance(image_input, Image.Image): buffered = BytesIO() image_input.save(buffered, format="JPEG") return base64.b64encode(buffered.getvalue()).decode('utf-8') raise ValueError("image_input must be either a file path (string) or a PIL Image object") def query(self, text_input: Union[str, List[Dict]] = "", image_input: Any = None, maintain_history: bool = False) -> str: """ Send a query to the OpenAI API and return the response. Args: text_input: The prompt text to send (can be string or list of message dicts) image_input: Path to image file or PIL Image object for vision models maintain_history: Whether to add this exchange to conversation history Returns: The model's response as a string """ #print("text input: ",text_input) messages = [] # Handle image input if image_input is not None: try: image_base64 = self._encode_image_to_base64(image_input) if isinstance(text_input, list): # If text_input is already a list of messages, use it directly messages = text_input elif isinstance(text_input, str): # Create a message with both text and image messages = [{"role": "system", "content": self.system_message}] if not maintain_history else [] messages.append({ "role": "user", "content": [ { "type": "text", "text": text_input }, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{image_base64}" }, }, ], }) except Exception as e: print(f"Warning: Failed to process image input: {str(e)}") # Fall back to text-only input image_input = None # Handle text-only input or fallback from image processing failure if image_input is None: if isinstance(text_input, list): # If text_input is already a list of messages, use it directly messages = text_input elif text_input is not None: text_input = str(text_input) # Add user message to history if maintaining history if maintain_history: self.add_user_message(text_input) messages = self.conversation_history else: # For single-turn interactions without affecting history messages = [{"role": "system", "content": self.system_message}, {"role": "user", "content": text_input}] #print(type(text_input)) for attempt in range(self.retry_attempts): try: #print(messages) # Prepare chat completion parameters chat_params = { 'model': self.model_name, 'messages': messages, **self.chat_kwargs } # Add explicit parameters if not in kwargs if self.temperature is not None: chat_params['temperature'] = self.temperature if self.max_tokens is not None: chat_params['max_tokens'] = self.max_tokens response = self.client.chat.completions.create(**chat_params) response_text = response.choices[0].message.content # Add assistant response to history if maintaining history if maintain_history: self.add_assistant_message(response_text) return response_text except Exception as e: print(f"Unexpected error: {str(e)}") if attempt == self.retry_attempts - 1: return f"Error: {str(e)}" # Note: time.sleep removed to prevent blocking in asyncio environments # For retry delays, the calling async code should handle timing return "Error: Failed to get response from model" def add_user_message(self, content: Union[str, List[Dict]]) -> None: """Add a user message to the conversation history.""" self.conversation_history.append({"role": "user", "content": content}) def add_assistant_message(self, content: Union[str, List[Dict]]) -> None: """Add an assistant message to the conversation history.""" self.conversation_history.append({"role": "assistant", "content": content}) def add_system_message(self, content: str) -> None: """Add a system message to the conversation history.""" self.conversation_history.append({"role": "system", "content": content}) def remove_last_turn(self) -> None: """ Remove the last turn of conversation (last user message and its corresponding assistant message). Assumes conversation is stored sequentially as messages. """ if not self.conversation_history: return for idx in range(len(self.conversation_history) - 1, -1, -1): if self.conversation_history[idx]["role"] == "user": self.conversation_history = self.conversation_history[:idx] break def reset_conversation(self) -> None: """Reset the conversation history to only include the initial system message.""" self.conversation_history = [{"role": "system", "content": self.system_message}] def get_conversation_history(self) -> List[Dict[str, str]]: """Get the current conversation history.""" return self.conversation_history def set_system_message(self, system_message: str) -> None: """Set a new system message and reset the conversation.""" self.system_message = system_message self.reset_conversation() def reset_system_message(self) -> None: """Reset the system message to the default.""" self.system_message = "You are a helpful assistant." self.reset_conversation() def get_embedding(self, text_input: str, model: str = None) -> list[float]: """ Get embedding vector for the given text using OpenAI embedding model. Args: text_input: The input text to embed model: The embedding model to use (default: uses self.embedding_model) Returns: A list of floats representing the embedding vector """ try: import torch clean_text = text_input.replace("\n", " ") # Use specified model or default embedding model embedding_model = model or self.embedding_model # Prepare embedding parameters embedding_params = { 'input': clean_text, 'model': embedding_model, **self.embedding_kwargs } response = self.client.embeddings.create(**embedding_params) embedding = response.data[0].embedding return torch.tensor(embedding) except Exception as e: print(f"Error while generating embedding: {str(e)}") return []