mirror of
https://github.com/Tokfinity/InfCode.git
synced 2026-02-13 05:32:44 +00:00
initial
This commit is contained in:
16
src/managers/__init__.py
Normal file
16
src/managers/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Export all modules
|
||||
from .image_builder import *
|
||||
from .log import *
|
||||
from .llm_api import *
|
||||
from .prompts import *
|
||||
from .loop import *
|
||||
from .decorators import *
|
||||
|
||||
__all__ = [
|
||||
"image_builder",
|
||||
"log",
|
||||
"llm_api",
|
||||
"prompts",
|
||||
"loop",
|
||||
"decorators"
|
||||
]
|
||||
13
src/managers/decorators/singleton.py
Normal file
13
src/managers/decorators/singleton.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import threading
|
||||
|
||||
def singleton(cls):
|
||||
instances = {}
|
||||
lock = threading.Lock()
|
||||
|
||||
def get_instance(*args, **kwargs):
|
||||
if cls not in instances:
|
||||
with lock:
|
||||
if cls not in instances:
|
||||
instances[cls] = cls(*args, **kwargs)
|
||||
return instances[cls]
|
||||
return get_instance
|
||||
8
src/managers/image_builder/__init__.py
Normal file
8
src/managers/image_builder/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# 导出全部模块
|
||||
from .build_image import *
|
||||
from .dockerfiles import *
|
||||
|
||||
__all__ = [
|
||||
"build_image",
|
||||
"dockerfiles"
|
||||
]
|
||||
1009
src/managers/image_builder/build_image.py
Normal file
1009
src/managers/image_builder/build_image.py
Normal file
File diff suppressed because it is too large
Load Diff
8
src/managers/image_builder/dockerfiles.py
Normal file
8
src/managers/image_builder/dockerfiles.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# User image Dockerfile template
|
||||
# image_key: instance_image_key corresponding to instance_id
|
||||
# apt_packages: split by ' ', load multi packages needed to be installed in ubuntu
|
||||
_DOCKERFILE_USER_IMAGE_PY = r"""
|
||||
FROM {image_key}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends {apt_packages}
|
||||
"""
|
||||
128
src/managers/image_builder/logger_patch.py
Normal file
128
src/managers/image_builder/logger_patch.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
User defined logger for swebench
|
||||
"""
|
||||
import logging
|
||||
from typing import Union, Optional
|
||||
|
||||
from swebench.harness.docker_utils import remove_image
|
||||
from src.managers.log.logger import Logger as CustomLogger
|
||||
|
||||
# 导入原始的 swebench 函数
|
||||
from swebench.harness.docker_build import build_instance_image, run_threadpool
|
||||
|
||||
|
||||
def patched_build_instance_images(
|
||||
client,
|
||||
dataset: list,
|
||||
force_rebuild: bool = False,
|
||||
max_workers: int = 4,
|
||||
namespace: str = None,
|
||||
tag: str = None,
|
||||
env_image_tag: str = None,
|
||||
custom_logger: Optional[Union[logging.Logger, CustomLogger]] = None, # 新增参数
|
||||
):
|
||||
"""
|
||||
Monkey patched version of build_instance_images that supports custom logger
|
||||
|
||||
Args:
|
||||
client: Docker client
|
||||
dataset: List of test specs or dataset to build images for
|
||||
force_rebuild: Whether to force rebuild the images even if they already exist
|
||||
max_workers: Maximum number of worker threads
|
||||
namespace: Namespace for images
|
||||
tag: Tag for images
|
||||
env_image_tag: Environment image tag
|
||||
custom_logger: Custom logger to use (instead of creating new ones)
|
||||
"""
|
||||
from swebench.harness.docker_build import make_test_spec, build_env_images
|
||||
|
||||
test_specs = []
|
||||
for instance in dataset:
|
||||
spec = make_test_spec(
|
||||
instance,
|
||||
namespace=namespace,
|
||||
instance_image_tag=tag,
|
||||
env_image_tag=env_image_tag,
|
||||
)
|
||||
test_specs.append(spec)
|
||||
|
||||
if force_rebuild:
|
||||
for spec in test_specs:
|
||||
remove_image(client, spec.instance_image_key, "quiet")
|
||||
|
||||
_, env_failed = build_env_images(client, test_specs, force_rebuild, max_workers)
|
||||
|
||||
if len(env_failed) > 0:
|
||||
# Don't build images for instances that depend on failed-to-build env images
|
||||
dont_run_specs = [
|
||||
spec for spec in test_specs if spec.env_image_key in env_failed
|
||||
]
|
||||
test_specs = [
|
||||
spec for spec in test_specs if spec.env_image_key not in env_failed
|
||||
]
|
||||
if custom_logger:
|
||||
custom_logger.info(
|
||||
f"Skipping {len(dont_run_specs)} instances - due to failed env image builds"
|
||||
)
|
||||
else:
|
||||
print(f"Skipping {len(dont_run_specs)} instances - due to failed env image builds")
|
||||
|
||||
if custom_logger:
|
||||
custom_logger.info(f"Building instance images for {len(test_specs)} instances")
|
||||
else:
|
||||
print(f"Building instance images for {len(test_specs)} instances")
|
||||
|
||||
successful, failed = [], []
|
||||
|
||||
if custom_logger:
|
||||
payloads = [(spec, client, custom_logger, False) for spec in test_specs]
|
||||
else:
|
||||
payloads = [(spec, client, None, False) for spec in test_specs]
|
||||
|
||||
successful, failed = run_threadpool(build_instance_image, payloads, max_workers)
|
||||
|
||||
if len(failed) == 0:
|
||||
if custom_logger:
|
||||
custom_logger.info("All instance images built successfully.")
|
||||
else:
|
||||
print("All instance images built successfully.")
|
||||
else:
|
||||
if custom_logger:
|
||||
custom_logger.warning(f"{len(failed)} instance images failed to build.")
|
||||
else:
|
||||
print(f"{len(failed)} instance images failed to build.")
|
||||
return successful, failed
|
||||
|
||||
|
||||
def apply_logger_patch():
|
||||
"""应用 logger patch"""
|
||||
import swebench.harness.docker_build as docker_build_module
|
||||
|
||||
original_build_instance_images = docker_build_module.build_instance_images
|
||||
|
||||
docker_build_module.build_instance_images = patched_build_instance_images
|
||||
|
||||
return original_build_instance_images
|
||||
|
||||
|
||||
def restore_logger_patch(original_function):
|
||||
"""Recover original logger actions"""
|
||||
import swebench.harness.docker_build as docker_build_module
|
||||
docker_build_module.build_instance_images = original_function
|
||||
|
||||
|
||||
# 上下文管理器版本
|
||||
class LoggerPatch:
|
||||
"""Context manager"""
|
||||
|
||||
def __init__(self, logger: Optional[Union[logging.Logger, CustomLogger]] = None):
|
||||
self.logger = logger
|
||||
self.original_function = None
|
||||
|
||||
def __enter__(self):
|
||||
self.original_function = apply_logger_patch()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.original_function:
|
||||
restore_logger_patch(self.original_function)
|
||||
132
src/managers/image_builder/print_redirect.py
Normal file
132
src/managers/image_builder/print_redirect.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Redirect print output of a third party repo to a specific log
|
||||
"""
|
||||
import sys
|
||||
import logging
|
||||
import re
|
||||
from typing import Union, Set, Pattern
|
||||
from typing import Optional, TextIO
|
||||
from pathlib import Path
|
||||
from src.managers.log.logger import Logger as CustomLogger
|
||||
|
||||
|
||||
class StreamWrapper:
|
||||
"""Simple stream wrapper,for redirecting stdout/stderr"""
|
||||
|
||||
def __init__(self, original_stream, write_func):
|
||||
self.original_stream = original_stream
|
||||
self.write_func = write_func
|
||||
self.buffer = ""
|
||||
|
||||
def write(self, text):
|
||||
self.buffer += text
|
||||
|
||||
while '\n' in self.buffer:
|
||||
line, self.buffer = self.buffer.split('\n', 1)
|
||||
if line.strip():
|
||||
self.write_func(line + '\n')
|
||||
|
||||
return len(text)
|
||||
|
||||
def flush(self):
|
||||
if self.buffer:
|
||||
if self.buffer.strip():
|
||||
self.write_func(self.buffer)
|
||||
self.buffer = ""
|
||||
if hasattr(self.original_stream, 'flush'):
|
||||
self.original_stream.flush()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.original_stream, name)
|
||||
|
||||
|
||||
class PrintRedirector:
|
||||
"""Redirect print output of a third party repo to a specific log"""
|
||||
|
||||
TRACEBACK_START_PATTERN = re.compile(r'Traceback\s*\(most recent call last\):', re.IGNORECASE)
|
||||
EXCEPTION_END_PATTERN = re.compile(
|
||||
r'\b(Error|Exception|KeyError|ValueError|TypeError|AttributeError|'
|
||||
r'ImportError|BuildError|DockerError|IndentationError|SyntaxError|'
|
||||
r'RuntimeError|OSError|FileNotFoundError|PermissionError)\s*:',
|
||||
re.IGNORECASE
|
||||
)
|
||||
|
||||
ERROR_KEYWORDS = {'error', 'failed', 'exception', 'fatal', 'critical'}
|
||||
WARNING_KEYWORDS = {'warning', 'warn', 'deprecated'}
|
||||
SKIP_KEYWORDS = {'skipping', 'skip', 'ignoring', 'ignore'}
|
||||
|
||||
def __init__(self, logger: Union[logging.Logger, CustomLogger]):
|
||||
self.logger = logger
|
||||
self.original_print = None
|
||||
self.original_stdout = None
|
||||
self.original_stderr = None
|
||||
|
||||
self.traceback_buffer = []
|
||||
self.in_traceback = False
|
||||
|
||||
def __enter__(self):
|
||||
self.start_redirect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop_redirect()
|
||||
|
||||
def start_redirect(self):
|
||||
if isinstance(__builtins__, dict):
|
||||
self.original_print = __builtins__['print']
|
||||
else:
|
||||
self.original_print = getattr(__builtins__, 'print')
|
||||
if isinstance(__builtins__, dict):
|
||||
__builtins__['print'] = self._redirected_print
|
||||
else:
|
||||
setattr(__builtins__, 'print', self._redirected_print)
|
||||
|
||||
def stop_redirect(self):
|
||||
if self.original_print:
|
||||
if isinstance(__builtins__, dict):
|
||||
__builtins__['print'] = self.original_print
|
||||
else:
|
||||
setattr(__builtins__, 'print', self.original_print)
|
||||
|
||||
def _redirected_print(self, *args, **kwargs):
|
||||
if not args:
|
||||
return
|
||||
output = ' '.join(str(arg) for arg in args)
|
||||
|
||||
if self.TRACEBACK_START_PATTERN.search(output):
|
||||
self.in_traceback = True
|
||||
self.traceback_buffer = [output]
|
||||
return
|
||||
|
||||
if self.in_traceback:
|
||||
self.traceback_buffer.append(output)
|
||||
|
||||
if self.EXCEPTION_END_PATTERN.search(output):
|
||||
self._log_traceback_and_reset()
|
||||
return
|
||||
|
||||
self._log_by_level(output)
|
||||
|
||||
def _log_traceback_and_reset(self):
|
||||
full_traceback = '\n'.join(self.traceback_buffer)
|
||||
self.logger.error(f"[Third party repo exception stack]\n{full_traceback}")
|
||||
|
||||
self.in_traceback = False
|
||||
self.traceback_buffer.clear()
|
||||
|
||||
def _log_by_level(self, output: str):
|
||||
output_lower = output.lower()
|
||||
|
||||
if any(keyword in output_lower for keyword in self.ERROR_KEYWORDS):
|
||||
self.logger.error(f"[Third party] {output}")
|
||||
elif any(keyword in output_lower for keyword in self.WARNING_KEYWORDS):
|
||||
self.logger.warning(f"[Third party] {output}")
|
||||
elif any(keyword in output_lower for keyword in self.SKIP_KEYWORDS):
|
||||
self.logger.info(f"[Third party] {output}")
|
||||
else:
|
||||
self.logger.info(f"[Third party] {output}")
|
||||
|
||||
|
||||
def redirect_swebench_prints(logger: Union[logging.Logger, CustomLogger]):
|
||||
return PrintRedirector(logger)
|
||||
|
||||
98
src/managers/llm_api/__init__.py
Normal file
98
src/managers/llm_api/__init__.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
LLM API Management Module
|
||||
|
||||
Provide a standard OpenAI format LLM API interface that supports:
|
||||
- Unified Chat Completions endpoint
|
||||
- Synchronous and asynchronous operations
|
||||
- Streaming responses
|
||||
- Tool Calling
|
||||
- Error handling and retry mechanism(s)
|
||||
|
||||
Supported providers:
|
||||
- OpenAI: OpenAI API and compatible services
|
||||
- Anthropic: Claude models
|
||||
- DeepSeek: DeepSeek models
|
||||
- Private: Private models(vLLM、TGI、Ollama etc.)
|
||||
|
||||
Usage example::
|
||||
# 1: use general manager (suggested)
|
||||
from llm_api import LLMAPIManager
|
||||
|
||||
# create manager
|
||||
manager = LLMAPIManager(
|
||||
client_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
stream=False
|
||||
)
|
||||
|
||||
# sendf messages
|
||||
response = manager.chat("Hello world!")
|
||||
print(response)
|
||||
|
||||
# 2: Use client directly
|
||||
from llm_api import OpenAIClient, ChatMessage, MessageRole
|
||||
|
||||
# creat client
|
||||
client = OpenAIClient(api_key="your-api-key")
|
||||
|
||||
# create4 message
|
||||
messages = [
|
||||
ChatMessage(role=MessageRole.USER, content="你好,世界!")
|
||||
]
|
||||
|
||||
# send request
|
||||
request = client.create_request(messages=messages, model="gpt-3.5-turbo")
|
||||
response = client.chat_completions_create(request)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
# 客户端实现
|
||||
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
|
||||
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
|
||||
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
|
||||
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
|
||||
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
|
||||
|
||||
# API 管理器
|
||||
from src.managers.llm_api.api_manager import (
|
||||
LLMAPIManager,
|
||||
create_manager,
|
||||
create_common_manager,
|
||||
COMMON_CONFIGS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMAPI",
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatCompletionResponse",
|
||||
"ChatCompletionChunk",
|
||||
"MessageRole",
|
||||
"Choice",
|
||||
"Usage",
|
||||
"OpenAIClient",
|
||||
"AnthropicClient",
|
||||
"DeepSeekClient",
|
||||
"OpenRouterClient",
|
||||
"PrivateModelClient",
|
||||
"LLMAPIManager",
|
||||
"create_manager",
|
||||
"create_common_manager",
|
||||
"COMMON_CONFIGS",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Tokfinity Team"
|
||||
__description__ = "标准 OpenAI 格式的 LLM API 基类库"
|
||||
479
src/managers/llm_api/api_manager.py
Normal file
479
src/managers/llm_api/api_manager.py
Normal file
@@ -0,0 +1,479 @@
|
||||
"""
|
||||
LLM API manager
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Union, Generator
|
||||
from dotenv import load_dotenv
|
||||
import yaml
|
||||
from traceback import format_exc
|
||||
from .base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
)
|
||||
|
||||
# 导入所有客户端
|
||||
from .clients.openai.openai_client import OpenAIClient
|
||||
from .clients.anthropic.anthropic_client import AnthropicClient
|
||||
from .clients.deepseek.deepseek_client import DeepSeekClient
|
||||
from .clients.openrouter.openrouter_client import OpenRouterClient
|
||||
from .clients.private.private_client import PrivateModelClient
|
||||
|
||||
|
||||
class LLMAPIManager:
|
||||
SUPPORTED_CLIENTS = {
|
||||
"openai": OpenAIClient,
|
||||
"anthropic": AnthropicClient,
|
||||
"deepseek": DeepSeekClient,
|
||||
"openrouter": OpenRouterClient,
|
||||
"private": PrivateModelClient,
|
||||
}
|
||||
|
||||
DEFAULT_CONFIGS = {
|
||||
"openai": {"base_url": None, "api_key_env": "OPENAI_API_KEY"},
|
||||
"anthropic": {"base_url": None, "api_key_env": "ANTHROPIC_API_KEY"},
|
||||
"deepseek": {"base_url": None, "api_key_env": "DEEPSEEK_API_KEY"},
|
||||
"openrouter": {
|
||||
"base_url": None,
|
||||
"api_key_env": "OPENROUTER_API_KEY",
|
||||
"extra_config": {
|
||||
"app_name": "tokfinity-llm-client",
|
||||
"site_url": "https://github.com/your-repo",
|
||||
},
|
||||
},
|
||||
"private": {
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key_env": "PRIVATE_API_KEY",
|
||||
"extra_config": {"deployment_type": "vllm"},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_name: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
auto_load_env: bool = True,
|
||||
logger: Optional[Any] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# if client not provided, get first provider from default providers in config
|
||||
if client_name is None:
|
||||
default_client, default_model = (
|
||||
self._load_default_client_and_model_from_config()
|
||||
)
|
||||
self.client_name = default_client
|
||||
self.default_model = default_model
|
||||
else:
|
||||
self.client_name = client_name.lower()
|
||||
self.default_model = None
|
||||
|
||||
self.stream = stream
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.logger = logger
|
||||
self.logger.info(f"[LLMAPIManager]: Using client: {self.client_name}, model: {self.default_model}.")
|
||||
self.config = config
|
||||
|
||||
if auto_load_env:
|
||||
self._load_environment()
|
||||
|
||||
if self.client_name not in self.SUPPORTED_CLIENTS:
|
||||
raise ValueError(
|
||||
f"Unsupported client: {client_name}。"
|
||||
f"Support client: {list(self.SUPPORTED_CLIENTS.keys())}"
|
||||
)
|
||||
|
||||
self.client = self._create_client(api_key, base_url, logger, **kwargs)
|
||||
|
||||
|
||||
def _load_environment(self) -> None:
|
||||
"""load environment variables"""
|
||||
env_paths = [".env", "../.env", "../../.env", "../../../.env"]
|
||||
|
||||
for env_path in env_paths:
|
||||
if os.path.exists(env_path):
|
||||
load_dotenv(env_path)
|
||||
break
|
||||
|
||||
def _create_client(
|
||||
self, api_key: Optional[str] = None, base_url: Optional[str] = None, logger: Optional[Any] = None, **kwargs
|
||||
) -> BaseLLMAPI:
|
||||
client_class = self.SUPPORTED_CLIENTS[self.client_name]
|
||||
config = self.DEFAULT_CONFIGS[self.client_name]
|
||||
|
||||
if api_key is None:
|
||||
api_key = os.getenv(config["api_key_env"])
|
||||
if not api_key:
|
||||
if self.client_name == "private":
|
||||
api_key = "EMPTY" # private mode may not need a key
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Fail to find env variable, please set: {config['api_key_env']} "
|
||||
f"or upload ai_key parameter when initialize"
|
||||
)
|
||||
|
||||
if base_url is None:
|
||||
env_key = f"{self.client_name.upper()}_BASE_URL"
|
||||
if self.client_name == "private":
|
||||
env_key = "PRIVATE_URL"
|
||||
|
||||
base_url = os.getenv(env_key)
|
||||
if base_url is None:
|
||||
base_url = config.get("base_url")
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": api_key,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
if logger is not None:
|
||||
client_kwargs["logger"] = logger
|
||||
|
||||
extra_config = config.get("extra_config", {})
|
||||
client_kwargs.update(extra_config)
|
||||
client_kwargs.update(kwargs)
|
||||
|
||||
if self.client_name == "openrouter":
|
||||
client_kwargs.setdefault("app_name", "tokfinity-llm-client")
|
||||
elif self.client_name == "private":
|
||||
client_kwargs.setdefault("deployment_type", "vllm")
|
||||
|
||||
return client_class(**client_kwargs)
|
||||
|
||||
def _load_default_client_and_model_from_config(self) -> (str, Optional[str]):
|
||||
"""
|
||||
Get first item from providers from config/config.yaml as the default client
|
||||
And take the first model as default model
|
||||
"""
|
||||
# Parse the root of config file (relative to the project root)
|
||||
# This file is in src/managers/llm_api/api_manager.py
|
||||
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
|
||||
config_path = os.path.join(base_dir, "config", "config.yaml")
|
||||
if not os.path.exists(config_path):
|
||||
raise FileNotFoundError(f"No config file: {config_path}")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
providers = cfg.get("providers")
|
||||
if not isinstance(providers, dict) or len(providers) == 0:
|
||||
raise ValueError("config.yaml lack of providers config or format error")
|
||||
|
||||
first_provider_name = next(iter(providers.keys()))
|
||||
models = providers.get(first_provider_name) or []
|
||||
first_model = (
|
||||
models[0] if isinstance(models, list) and len(models) > 0 else None
|
||||
)
|
||||
|
||||
client_key = first_provider_name.strip().lower()
|
||||
if client_key not in self.SUPPORTED_CLIENTS:
|
||||
raise ValueError(
|
||||
f"Default provider '{first_provider_name}' in config not registered in SUPPORTED_CLIENTS"
|
||||
)
|
||||
|
||||
return client_key, first_model
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Union[Dict[str, Any], ChatMessage]],
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
retry: Optional[int] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None
|
||||
]:
|
||||
"""
|
||||
Send chat message and get response
|
||||
|
||||
Args:
|
||||
model: model name (optional, default self.default_model)
|
||||
messages: Concatenated message list.
|
||||
- If list of dict: [{"role": "system|user|assistant|tool", "content": "..."}]
|
||||
- Or `ChatMessage` list
|
||||
temperature: Temperature
|
||||
max_tokens: Max tokens
|
||||
timeout: Request timeout, use the value from initialization if not specified
|
||||
retry: Max retries, use the value from initialization if not specified
|
||||
tools: Tools description list (OpenAI tools format)
|
||||
tool_choice: Tool choice strategy("auto" | "none" | {"type":..., ...})
|
||||
**kwargs: Other request args
|
||||
|
||||
Returns:
|
||||
Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None]:
|
||||
- Non-streaming: Return the complete ChatCompletionResponse object
|
||||
- Stream: Return a streaming response generator
|
||||
- Fail: Return None
|
||||
"""
|
||||
actual_timeout = timeout if timeout is not None else self.timeout
|
||||
actual_retry = retry if retry is not None else self.max_retries
|
||||
|
||||
actual_model = model if model is not None else self.default_model
|
||||
if not actual_model:
|
||||
raise ValueError("Model unprovided, and cannot get default model from config")
|
||||
|
||||
normalized_messages: List[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
if msg.content is None:
|
||||
msg.content = ""
|
||||
normalized_messages.append(msg)
|
||||
else:
|
||||
role_value = msg.get("role")
|
||||
content_value = msg.get("content")
|
||||
|
||||
if content_value is None:
|
||||
content_value = ""
|
||||
|
||||
role_enum = MessageRole(role_value)
|
||||
normalized_messages.append(
|
||||
ChatMessage(
|
||||
role=role_enum,
|
||||
content=content_value,
|
||||
name=msg.get("name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
)
|
||||
)
|
||||
|
||||
last_exception = None
|
||||
for attempt in range(actual_retry + 1):
|
||||
try:
|
||||
request = self.client.create_request(
|
||||
messages=normalized_messages,
|
||||
model=actual_model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=self.stream,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
config=self.config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
response = self.client.chat_completions_create(request)
|
||||
|
||||
if self.stream:
|
||||
# Streaming response: has not processed
|
||||
return response
|
||||
else:
|
||||
# Non-streaming response: return complete ChatCompletionResponse
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < actual_retry:
|
||||
delay = min(2**attempt, 30)
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"{attempt + 1}th try failed,retry after {delay} seconds: {str(e)}, traceback: {format_exc()}."
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"All {actual_retry + 1} tries failed: {str(e)}"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def chat_simple(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Union[Dict[str, Any], ChatMessage]],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
retry: Optional[int] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
) -> Optional[str]:
|
||||
# Ban stream
|
||||
original_stream = self.stream
|
||||
self.stream = False
|
||||
|
||||
try:
|
||||
response = self.chat(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
timeout=timeout,
|
||||
retry=retry,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if response and hasattr(response, "choices") and response.choices:
|
||||
return response.choices[0].message.content
|
||||
return None
|
||||
|
||||
finally:
|
||||
# Recover original stream settings
|
||||
self.stream = original_stream
|
||||
|
||||
def create_embeddings(
|
||||
self,
|
||||
input_text: Union[str, List[str]],
|
||||
model: str,
|
||||
encoding_format: str = "float",
|
||||
dimensions: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
retry: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Optional[EmbeddingResponse]:
|
||||
actual_timeout = timeout if timeout is not None else self.timeout
|
||||
actual_retry = retry if retry is not None else self.max_retries
|
||||
|
||||
if isinstance(input_text, str):
|
||||
text_list = [input_text]
|
||||
single_input = True
|
||||
else:
|
||||
text_list = input_text
|
||||
single_input = False
|
||||
|
||||
if self.logger:
|
||||
self.logger.debug(
|
||||
f"Start embedding request - client: {self.client_name}, model: {model}, text num: {len(text_list)}"
|
||||
)
|
||||
|
||||
if not hasattr(self.client, "create_embeddings"):
|
||||
error_msg = f"Client {self.client_name} not support embedding"
|
||||
if self.logger:
|
||||
self.logger.error(error_msg)
|
||||
return None
|
||||
|
||||
last_exception = None
|
||||
for attempt in range(actual_retry + 1):
|
||||
try:
|
||||
if self.logger:
|
||||
self.logger.debug(f"{attempt + 1}th try to create embedding vector")
|
||||
|
||||
response = self.client.create_embeddings(
|
||||
input_text=text_list,
|
||||
model=model,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
timeout=actual_timeout,
|
||||
max_retries=1,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if response:
|
||||
return response
|
||||
else:
|
||||
raise Exception("Client return empty response")
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < actual_retry:
|
||||
delay = min(2**attempt, 30)
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"{attempt + 1}th embedding request failed,retry after {delay}s: {str(e)}, traceback: {format_exc()}."
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"All {actual_retry + 1} embedding tries failed: {str(e)}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def get_client_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"client_name": self.client_name,
|
||||
"stream": self.stream,
|
||||
"client_info": (
|
||||
self.client.get_model_info()
|
||||
if hasattr(self.client, "get_model_info")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
return self.default_model or "unknown"
|
||||
|
||||
def close(self) -> None:
|
||||
if hasattr(self.client, "close"):
|
||||
self.client.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"LLMAPIManager(client={self.client_name}, stream={self.stream})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"LLMAPIManager("
|
||||
f"client_name='{self.client_name}', "
|
||||
f"stream={self.stream})"
|
||||
)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def create_manager(
|
||||
client_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
|
||||
) -> LLMAPIManager:
|
||||
return LLMAPIManager(
|
||||
client_name=client_name, stream=stream, logger=logger, **kwargs
|
||||
)
|
||||
|
||||
|
||||
COMMON_CONFIGS = {
|
||||
"openai_gpt4": {"client_name": "openai", "model_name": "gpt-4o"},
|
||||
"openai_gpt35": {"client_name": "openai", "model_name": "gpt-3.5-turbo"},
|
||||
"claude_sonnet": {
|
||||
"client_name": "anthropic",
|
||||
"model_name": "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
"claude_haiku": {
|
||||
"client_name": "anthropic",
|
||||
"model_name": "claude-3-haiku-20240307",
|
||||
},
|
||||
"deepseek_chat": {"client_name": "deepseek", "model_name": "deepseek-chat"},
|
||||
"deepseek_coder": {"client_name": "deepseek", "model_name": "deepseek-coder"},
|
||||
}
|
||||
|
||||
|
||||
def create_common_manager(
|
||||
config_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
|
||||
) -> LLMAPIManager:
|
||||
if config_name not in COMMON_CONFIGS:
|
||||
raise ValueError(
|
||||
f"Unknown config: {config_name}. Available configs: {list(COMMON_CONFIGS.keys())}"
|
||||
)
|
||||
|
||||
config = COMMON_CONFIGS[config_name]
|
||||
return LLMAPIManager(
|
||||
client_name=config["client_name"], stream=stream, logger=logger, **kwargs
|
||||
)
|
||||
594
src/managers/llm_api/base_client.py
Normal file
594
src/managers/llm_api/base_client.py
Normal file
@@ -0,0 +1,594 @@
|
||||
"""
|
||||
LLM API base class - standard OpenAI format
|
||||
Multi LLM providers' universal API supported
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import requests
|
||||
import asyncio
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
|
||||
class MessageRole(Enum):
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
role: MessageRole
|
||||
content: Union[str, List[Dict[str, Any]]]
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionRequest:
|
||||
messages: List[ChatMessage]
|
||||
model: str
|
||||
temperature: float = 0.7
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: float = 1.0
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: bool = False
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
|
||||
response_format: Optional[Dict[str, Any]] = None
|
||||
seed: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Usage:
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Choice:
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Optional[str] = None
|
||||
logprobs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionResponse:
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
choices: List[Choice]
|
||||
usage: Optional[Usage] = None
|
||||
system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionChunk:
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
choices: List[Dict[str, Any]]
|
||||
system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingRequest:
|
||||
input: Union[str, List[str]]
|
||||
model: str
|
||||
encoding_format: str = "float"
|
||||
dimensions: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingData:
|
||||
object: str
|
||||
embedding: List[float]
|
||||
index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingUsage:
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingResponse:
|
||||
object: str
|
||||
data: List[EmbeddingData]
|
||||
model: str
|
||||
usage: EmbeddingUsage
|
||||
|
||||
|
||||
class BaseLLMAPI(ABC):
|
||||
"""
|
||||
LLM API base class
|
||||
|
||||
Provide standard OpenAI format API, support:
|
||||
- Synchronous/Asynchronous Chat Completions
|
||||
- Streaming Responses
|
||||
- Tool Calling
|
||||
- Error handling and Retry
|
||||
- Usage Statics
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.extra_config = kwargs
|
||||
|
||||
self.logger = logger
|
||||
|
||||
self.session: Optional[requests.Session] = None
|
||||
|
||||
self._initialize_client()
|
||||
|
||||
def _create_http_clients(self, headers: Dict[str, str]) -> None:
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(headers)
|
||||
self.session.timeout = self.timeout
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_client(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
pass
|
||||
|
||||
def chat_completions_create(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None]]:
|
||||
self.validate_request(request)
|
||||
|
||||
def _make_request():
|
||||
payload = self._build_request_payload(request)
|
||||
endpoint = self._get_chat_endpoint()
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(payload, endpoint)
|
||||
else:
|
||||
|
||||
full_url = self.base_url.rstrip("/") + endpoint
|
||||
|
||||
headers = {}
|
||||
if self.session and self.session.headers:
|
||||
headers.update(self.session.headers)
|
||||
else:
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key and self.api_key != "EMPTY":
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
response = requests.post(
|
||||
full_url, json=payload, headers=headers, timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"API Error {response.status_code}: {response.text}"
|
||||
print(error_msg)
|
||||
raise requests.exceptions.HTTPError(error_msg, response=response)
|
||||
response.raise_for_status()
|
||||
return self._parse_response(response.json())
|
||||
|
||||
return self._retry_with_backoff(_make_request)
|
||||
|
||||
async def achat_completions_create(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> Union[ChatCompletionResponse, AsyncGenerator[ChatCompletionChunk, None]]:
|
||||
self.validate_request(request)
|
||||
|
||||
def _make_async_request():
|
||||
payload = self._build_request_payload(request)
|
||||
endpoint = self._get_chat_endpoint()
|
||||
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(
|
||||
payload, endpoint
|
||||
)
|
||||
else:
|
||||
full_url = self.base_url.rstrip("/") + endpoint
|
||||
|
||||
headers = {}
|
||||
if self.session and self.session.headers:
|
||||
headers.update(self.session.headers)
|
||||
else:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key and self.api_key != "EMPTY":
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
response = requests.post(
|
||||
full_url, json=payload, headers=headers, timeout=self.timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
return self._parse_response(response.json())
|
||||
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, _make_async_request)
|
||||
|
||||
def create_message(self, role: MessageRole, content: str, config: Optional[Dict[str, Any]] = None, **kwargs) -> ChatMessage:
|
||||
return ChatMessage(role=role, content=content, **kwargs)
|
||||
|
||||
def _make_token_cache_request(self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None) -> list[ChatMessage]:
|
||||
"""
|
||||
Insert cache_control block to Claude/Anthropic model according to config, forward compatibility
|
||||
|
||||
- Only activate when model is "Claude/Anthropic"
|
||||
- Preserve the content as is for other providers to avoid breaking compatibility.
|
||||
- If token_cache.role_turns is configured, inject it matching the turn and role; otherwise, only inject it into the first system message.
|
||||
"""
|
||||
role_turns = self._load_role_turns(config)
|
||||
#if self.logger:
|
||||
# self.logger.debug(f"In _make_token_cache_request, got role_turns: {role_turns}.")
|
||||
if not self._is_claude_like(model):
|
||||
return messages
|
||||
|
||||
turn_to_roles: Dict[int, List[str]] = {}
|
||||
try:
|
||||
for rt in role_turns or []:
|
||||
try:
|
||||
turn = int(rt.get("turn"))
|
||||
role = (rt.get("role") or "").strip().lower()
|
||||
if role:
|
||||
if turn not in turn_to_roles:
|
||||
turn_to_roles[turn] = []
|
||||
turn_to_roles[turn].append(role)
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
turn_to_roles = {}
|
||||
|
||||
current_turn = 0
|
||||
result_messages: List[ChatMessage] = []
|
||||
|
||||
for msg in messages:
|
||||
msg_blocks = self._to_claude_blocks(msg.content)
|
||||
|
||||
# 0th turn rule: If met user, only add cache to its own content and put it into result.
|
||||
if current_turn == 0 and msg.role == MessageRole.USER:
|
||||
cached_blocks = self._add_cache_flag(msg_blocks)
|
||||
result_messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=cached_blocks,
|
||||
name=msg.name,
|
||||
tool_calls=msg.tool_calls,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
)
|
||||
)
|
||||
#if self.logger:
|
||||
# self.logger.debug(
|
||||
# f"Applied cache to initial user message at turn {current_turn}."
|
||||
# )
|
||||
continue
|
||||
|
||||
# Other messages: just add it to result
|
||||
result_messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg_blocks,
|
||||
name=msg.name,
|
||||
tool_calls=msg.tool_calls,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Hit anchor point: When the current turn's configuration includes the assistant and the current message is from the assistant
|
||||
# add an extra system turn.
|
||||
roles_for_turn = turn_to_roles.get(current_turn, [])
|
||||
if msg.role == MessageRole.ASSISTANT and roles_for_turn and ("assistant" in roles_for_turn):
|
||||
refresh_text = f"refresh cache tokens."
|
||||
refresh_blocks = self._to_claude_blocks(refresh_text)
|
||||
refresh_blocks = self._add_cache_flag(refresh_blocks)
|
||||
result_messages.append(
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content=refresh_blocks,
|
||||
)
|
||||
)
|
||||
#if self.logger:
|
||||
# self.logger.debug(
|
||||
# f"Appended system refresh after assistant at turn {current_turn}, refresh_blocks: {refresh_blocks}."
|
||||
# )
|
||||
|
||||
# assistant message make the round +1
|
||||
if msg.role == MessageRole.ASSISTANT:
|
||||
current_turn += 1
|
||||
|
||||
return result_messages
|
||||
|
||||
def _to_claude_blocks(self, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||
if isinstance(content, list):
|
||||
return content
|
||||
return [{"type": "text", "text": content}]
|
||||
|
||||
def _add_cache_flag(self, blocks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
out: List[Dict[str, Any]] = []
|
||||
added = False
|
||||
for blk in blocks:
|
||||
if not added and isinstance(blk, dict) and blk.get("type") == "text":
|
||||
out.append({**blk, "cache_control": {"type": "ephemeral"}})
|
||||
added = True
|
||||
else:
|
||||
out.append(blk)
|
||||
return out
|
||||
|
||||
def _is_claude_like(self, model: str) -> bool:
|
||||
try:
|
||||
model_lc = (model or "").lower()
|
||||
return ("claude" in model_lc) or ("anthropic" in model_lc)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _load_role_turns(self, config: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
"""Load token_cache.role_turns,Priority:Invoker config > default value"""
|
||||
try:
|
||||
role_turns = None
|
||||
|
||||
# 1) Read from invoker's config
|
||||
if config and isinstance(config, dict):
|
||||
token_cache_cfg = config.get("token_cache")
|
||||
if isinstance(token_cache_cfg, dict):
|
||||
role_turns = token_cache_cfg.get("role_turns")
|
||||
|
||||
# 2) If still lack, use default value in current example config
|
||||
if not role_turns:
|
||||
role_turns = [
|
||||
{"turn": 0, "role": "user"},
|
||||
{"turn": 25, "role": "assistant"},
|
||||
{"turn": 50, "role": "assistant"},
|
||||
{"turn": 75, "role": "assistant"},
|
||||
]
|
||||
|
||||
return role_turns
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.warning(f"Read token_cache.role_turns fail,use default value: {e}")
|
||||
return [
|
||||
{"turn": 0, "role": "user"},
|
||||
{"turn": 25, "role": "assistant"},
|
||||
{"turn": 50, "role": "assistant"},
|
||||
{"turn": 75, "role": "assistant"},
|
||||
]
|
||||
|
||||
def create_request(
|
||||
self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None, **kwargs
|
||||
) -> ChatCompletionRequest:
|
||||
|
||||
messages = self._make_token_cache_request(messages, model, config)
|
||||
|
||||
return ChatCompletionRequest(messages=messages, model=model, **kwargs)
|
||||
|
||||
def _handle_error(self, error: Exception) -> None:
|
||||
if self.logger:
|
||||
self.logger.error(f"API request failed: {error}")
|
||||
else:
|
||||
print(f"API request failed: {error}")
|
||||
raise error
|
||||
|
||||
def _retry_with_backoff(self, func, *args, **kwargs):
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < self.max_retries:
|
||||
delay = self.retry_delay * (2**attempt) # 指数退避
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"{attempt + 1}th try failed,retry after {delay}s: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.error(f"All retries failed: {e}, traceback: {format_exc()}.")
|
||||
|
||||
raise last_exception
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"provider": self.__class__.__name__,
|
||||
"base_url": self.base_url,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
|
||||
def validate_request(self, request: ChatCompletionRequest) -> bool:
|
||||
if not request.messages:
|
||||
raise ValueError("Message list is empty")
|
||||
|
||||
if not request.model:
|
||||
raise ValueError("Model name is empty")
|
||||
|
||||
for idx, message in enumerate(request.messages):
|
||||
if not message.content and not message.tool_calls:
|
||||
try:
|
||||
msg_info = {
|
||||
"index": idx,
|
||||
"role": getattr(message.role, "value", str(message.role)),
|
||||
"content_len": (
|
||||
len(message.content)
|
||||
if isinstance(message.content, str)
|
||||
else (len(message.content) if isinstance(message.content, list) else 0)
|
||||
),
|
||||
"has_tool_calls": bool(message.tool_calls),
|
||||
"tool_calls": message.tool_calls,
|
||||
}
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"Request validation failed: Invalid message exists (lacking both content and tool_calls): {json.dumps(msg_info, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"Request validation failed: Invalid message exists (lacking both content and tool_calls), index={idx}, error: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
raise ValueError("Cannot lacking both content and tool_calls")
|
||||
|
||||
return True
|
||||
|
||||
def format_messages_for_api(
|
||||
self, messages: List[ChatMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
msg_dict = {"role": message.role.value, "content": message.content}
|
||||
|
||||
if message.name:
|
||||
msg_dict["name"] = message.name
|
||||
|
||||
if message.tool_calls:
|
||||
msg_dict["tool_calls"] = message.tool_calls
|
||||
|
||||
if message.tool_call_id:
|
||||
msg_dict["tool_call_id"] = message.tool_call_id
|
||||
|
||||
formatted_messages.append(msg_dict)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _stream_chat_completion(
|
||||
self, payload: Dict[str, Any], endpoint: str
|
||||
) -> Generator[ChatCompletionChunk, None, None]:
|
||||
full_url = self.base_url.rstrip("/") + endpoint
|
||||
|
||||
headers = {}
|
||||
if self.session and self.session.headers:
|
||||
headers.update(self.session.headers)
|
||||
else:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key and self.api_key != "EMPTY":
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
response = requests.post(
|
||||
full_url, json=payload, headers=headers, timeout=self.timeout, stream=True
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
line_count = 0
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_count += 1
|
||||
line_str = line.decode("utf-8")
|
||||
chunk = self._parse_stream_line(line_str)
|
||||
if chunk:
|
||||
yield chunk
|
||||
else:
|
||||
print(f"Jump invalid line")
|
||||
print(f"Streaming request process done, {line_count} processed")
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
async def _astream_chat_completion(
|
||||
self, payload: Dict[str, Any], endpoint: str
|
||||
) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||
import asyncio
|
||||
|
||||
def _sync_stream():
|
||||
return list(self._stream_chat_completion(payload, endpoint))
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
chunks = await loop.run_in_executor(None, _sync_stream)
|
||||
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
|
||||
# Process standard SSE format
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
|
||||
if data.strip() == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
chunk_data = json.loads(data)
|
||||
return self._parse_stream_chunk(chunk_data)
|
||||
except json.JSONDecodeError:
|
||||
self.logger.warning(f"Unable to parse streaming data: {data}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
if self.session:
|
||||
self.session.close()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self.session:
|
||||
self.session.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.aclose()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}(base_url={self.base_url})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}("
|
||||
f"base_url={self.base_url}, "
|
||||
f"timeout={self.timeout}, "
|
||||
f"max_retries={self.max_retries})"
|
||||
)
|
||||
23
src/managers/llm_api/clients/__init__.py
Normal file
23
src/managers/llm_api/clients/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
LLM Client implement module
|
||||
|
||||
Diverse LLM provider:
|
||||
- OpenAI: OpenAI API and compatible service
|
||||
- Anthropic: Claude models
|
||||
- DeepSeek: DeepSeek models
|
||||
- Private: Private models(vLLM、TGI、Ollama etc.)
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
|
||||
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
|
||||
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
|
||||
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
|
||||
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
|
||||
|
||||
__all__ = [
|
||||
"OpenAIClient",
|
||||
"AnthropicClient",
|
||||
"DeepSeekClient",
|
||||
"OpenRouterClient",
|
||||
"PrivateModelClient",
|
||||
]
|
||||
7
src/managers/llm_api/clients/anthropic/__init__.py
Normal file
7
src/managers/llm_api/clients/anthropic/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Anthropic Claude Client Module
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
|
||||
|
||||
__all__ = ["AnthropicClient"]
|
||||
288
src/managers/llm_api/clients/anthropic/anthropic_client.py
Normal file
288
src/managers/llm_api/clients/anthropic/anthropic_client.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Anthropic Claude Client Module
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicClient(BaseLLMAPI):
|
||||
"""
|
||||
Anthropic Claude API Client
|
||||
Anthropic Claude models supported
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
anthropic_version: str = "2023-06-01",
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize Anthropic client
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key
|
||||
base_url: API basic URL,default Anthropic API
|
||||
timeout: timeout in seconds
|
||||
max_retries: max retries
|
||||
retry_delay: retry delay in seconds
|
||||
anthropic_version: API version
|
||||
**kwargs: other args
|
||||
"""
|
||||
self.anthropic_version = anthropic_version
|
||||
|
||||
if base_url is None:
|
||||
base_url = "https://api.anthropic.com"
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
"""Initialize HTTP client"""
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
"anthropic-version": self.anthropic_version,
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
}
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
return "/v1/messages"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
"""Build Anthropic API request payload"""
|
||||
messages, system_prompt = self._convert_messages_to_anthropic_format(
|
||||
request.messages
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": messages,
|
||||
"max_tokens": request.max_tokens or 1024,
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
if system_prompt:
|
||||
payload["system"] = system_prompt
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop_sequences"] = (
|
||||
request.stop if isinstance(request.stop, list) else [request.stop]
|
||||
)
|
||||
|
||||
if request.tools is not None:
|
||||
payload["tools"] = self._convert_tools_to_anthropic_format(request.tools)
|
||||
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = self._convert_tool_choice_to_anthropic_format(
|
||||
request.tool_choice
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def _convert_messages_to_anthropic_format(
|
||||
self, messages: List[ChatMessage]
|
||||
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""Convert messages to anthropic format"""
|
||||
anthropic_messages = []
|
||||
system_prompt = None
|
||||
|
||||
for message in messages:
|
||||
if message.role == MessageRole.SYSTEM:
|
||||
system_prompt = message.content
|
||||
elif message.role in [MessageRole.USER, MessageRole.ASSISTANT]:
|
||||
anthropic_messages.append(
|
||||
{"role": message.role.value, "content": message.content}
|
||||
)
|
||||
|
||||
return anthropic_messages, system_prompt
|
||||
|
||||
def _convert_tools_to_anthropic_format(
|
||||
self, tools: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert OpenAI format tools to anthropic format"""
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
function_def = tool.get("function", {})
|
||||
anthropic_tool = {
|
||||
"name": function_def.get("name", ""),
|
||||
"description": function_def.get("description", ""),
|
||||
"input_schema": function_def.get("parameters", {}),
|
||||
}
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
def _convert_tool_choice_to_anthropic_format(
|
||||
self, tool_choice: Union[str, Dict[str, Any]]
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""Convert OpenAI format tool choice to anthropic format"""
|
||||
if isinstance(tool_choice, str):
|
||||
if tool_choice == "auto":
|
||||
return "auto"
|
||||
elif tool_choice == "none":
|
||||
return "none"
|
||||
else:
|
||||
return {"type": "tool", "name": tool_choice}
|
||||
elif isinstance(tool_choice, dict):
|
||||
if tool_choice.get("type") == "function":
|
||||
return {
|
||||
"type": "tool",
|
||||
"name": tool_choice.get("function", {}).get("name", ""),
|
||||
}
|
||||
elif tool_choice.get("type") == "tool":
|
||||
return tool_choice
|
||||
|
||||
return "auto"
|
||||
|
||||
def _convert_anthropic_tool_calls(
|
||||
self, content_list: List[Dict[str, Any]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Convert anthropic tool calls to OpenAI format"""
|
||||
tool_calls = []
|
||||
|
||||
for item in content_list:
|
||||
if item.get("type") == "tool_use":
|
||||
tool_call = {
|
||||
"id": item.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item.get("name", ""),
|
||||
"arguments": item.get("input", {}),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls if tool_calls else None
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
"""Convert anthropic API response to OpenAI format"""
|
||||
content = ""
|
||||
tool_calls = None
|
||||
|
||||
if response_data.get("content"):
|
||||
content_data = (
|
||||
response_data["content"][0] if response_data["content"] else {}
|
||||
)
|
||||
content = content_data.get("text", "")
|
||||
|
||||
if content_data.get("type") == "tool_use":
|
||||
tool_calls = self._convert_anthropic_tool_calls(
|
||||
response_data.get("content", [])
|
||||
)
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole.ASSISTANT, content=content, tool_calls=tool_calls
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason=self._convert_stop_reason(response_data.get("stop_reason")),
|
||||
)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("input_tokens", 0),
|
||||
completion_tokens=usage_data.get("output_tokens", 0),
|
||||
total_tokens=usage_data.get("input_tokens", 0)
|
||||
+ usage_data.get("output_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model=response_data.get("model", ""),
|
||||
choices=[choice],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _convert_stop_reason(self, stop_reason: Optional[str]) -> Optional[str]:
|
||||
"""Convert anthropic stop reason to OpenAI format"""
|
||||
if stop_reason == "end_turn":
|
||||
return "stop"
|
||||
elif stop_reason == "max_tokens":
|
||||
return "length"
|
||||
elif stop_reason == "stop_sequence":
|
||||
return "stop"
|
||||
else:
|
||||
return stop_reason
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
"""Convert anthropic steam response to OpenAI format"""
|
||||
event_type = chunk_data.get("type")
|
||||
|
||||
if event_type == "content_block_delta":
|
||||
delta = chunk_data.get("delta", {})
|
||||
text = delta.get("text", "")
|
||||
|
||||
if text:
|
||||
choices = [
|
||||
{"index": 0, "delta": {"content": text}, "finish_reason": None}
|
||||
]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("message", {}).get("id", ""),
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=chunk_data.get("message", {}).get("model", ""),
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
elif event_type == "message_stop":
|
||||
choices = [{"index": 0, "delta": {}, "finish_reason": "stop"}]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("message", {}).get("id", ""),
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=chunk_data.get("message", {}).get("model", ""),
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
|
||||
try:
|
||||
chunk_data = json.loads(line)
|
||||
return self._parse_stream_chunk(chunk_data)
|
||||
except json.JSONDecodeError:
|
||||
# if not json, try SSE
|
||||
return super()._parse_stream_line(line)
|
||||
7
src/managers/llm_api/clients/deepseek/__init__.py
Normal file
7
src/managers/llm_api/clients/deepseek/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
DeepSeek Client Module
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
|
||||
|
||||
__all__ = ["DeepSeekClient"]
|
||||
164
src/managers/llm_api/clients/deepseek/deepseek_client.py
Normal file
164
src/managers/llm_api/clients/deepseek/deepseek_client.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
DeepSeek Client Module
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class DeepSeekClient(BaseLLMAPI):
|
||||
"""
|
||||
DeepSeek API Client
|
||||
|
||||
DeepSeek models supported, including DeepSeek-Coder、DeepSeek-Chat etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if base_url is None:
|
||||
base_url = "https://api.deepseek.com/v1"
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
}
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
return "/chat/completions"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
# optional
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
if request.frequency_penalty != 0.0:
|
||||
payload["frequency_penalty"] = request.frequency_penalty
|
||||
|
||||
if request.presence_penalty != 0.0:
|
||||
payload["presence_penalty"] = request.presence_penalty
|
||||
|
||||
if request.tools is not None:
|
||||
payload["tools"] = request.tools
|
||||
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = request.tool_choice
|
||||
|
||||
if request.response_format is not None:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
if request.seed is not None:
|
||||
payload["seed"] = request.seed
|
||||
|
||||
if request.user is not None:
|
||||
payload["user"] = request.user
|
||||
|
||||
return payload
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
choices = []
|
||||
for choice_data in response_data.get("choices", []):
|
||||
message_data = choice_data.get("message", {})
|
||||
|
||||
tool_calls = message_data.get("tool_calls")
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole(message_data.get("role", "assistant")),
|
||||
content=message_data.get("content", ""),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=choice_data.get("index", 0),
|
||||
message=message,
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
logprobs=choice_data.get("logprobs"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object=response_data.get("object", "chat.completion"),
|
||||
created=response_data.get("created", int(time.time())),
|
||||
model=response_data.get("model", ""),
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
system_fingerprint=response_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
"""Parse stream chunk"""
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("id", ""),
|
||||
object=chunk_data.get("object", "chat.completion.chunk"),
|
||||
created=chunk_data.get("created", int(time.time())),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=chunk_data.get("choices", []),
|
||||
system_fingerprint=chunk_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def list_models(self) -> Dict[str, Any]:
|
||||
"""Get available models"""
|
||||
response = self.client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def alist_models(self) -> Dict[str, Any]:
|
||||
response = await self.async_client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
7
src/managers/llm_api/clients/openai/__init__.py
Normal file
7
src/managers/llm_api/clients/openai/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
OpenAI Client Module
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
|
||||
|
||||
__all__ = ["OpenAIClient"]
|
||||
279
src/managers/llm_api/clients/openai/openai_client.py
Normal file
279
src/managers/llm_api/clients/openai/openai_client.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
OpenAI Client
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingData,
|
||||
EmbeddingUsage,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMAPI):
|
||||
"""
|
||||
OpenAI API Client
|
||||
|
||||
OpenAI API supported, and other API services compatible with the OpenAI format
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.organization = organization
|
||||
|
||||
if base_url is None:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
}
|
||||
|
||||
if self.organization:
|
||||
headers["OpenAI-Organization"] = self.organization
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
return "/chat/completions"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
# Optional
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
if request.tools is not None:
|
||||
payload["tools"] = request.tools
|
||||
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = request.tool_choice
|
||||
|
||||
if request.response_format is not None:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
if request.seed is not None:
|
||||
payload["seed"] = request.seed
|
||||
|
||||
if request.user is not None:
|
||||
payload["user"] = request.user
|
||||
|
||||
return payload
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
choices = []
|
||||
for choice_data in response_data.get("choices", []):
|
||||
message_data = choice_data.get("message", {})
|
||||
|
||||
tool_calls = message_data.get("tool_calls")
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole(message_data.get("role", "assistant")),
|
||||
content=message_data.get("content", ""),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=choice_data.get("index", 0),
|
||||
message=message,
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
logprobs=choice_data.get("logprobs"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object=response_data.get("object", "chat.completion"),
|
||||
created=response_data.get("created", int(time.time())),
|
||||
model=response_data.get("model", ""),
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
system_fingerprint=response_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("id", ""),
|
||||
object=chunk_data.get("object", "chat.completion.chunk"),
|
||||
created=chunk_data.get("created", int(time.time())),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=chunk_data.get("choices", []),
|
||||
system_fingerprint=chunk_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def list_models(self) -> Dict[str, Any]:
|
||||
response = self.client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def alist_models(self) -> Dict[str, Any]:
|
||||
response = await self.async_client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def create_embeddings(
|
||||
self,
|
||||
input_text: Union[str, List[str]],
|
||||
model: str,
|
||||
encoding_format: str = "float",
|
||||
dimensions: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
retry_delay: Optional[float] = None,
|
||||
) -> EmbeddingResponse:
|
||||
|
||||
actual_timeout = timeout if timeout is not None else self.timeout
|
||||
actual_max_retries = (
|
||||
max_retries if max_retries is not None else self.max_retries
|
||||
)
|
||||
actual_retry_delay = (
|
||||
retry_delay if retry_delay is not None else self.retry_delay
|
||||
)
|
||||
|
||||
request = EmbeddingRequest(
|
||||
input=input_text,
|
||||
model=model,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
)
|
||||
|
||||
payload = self._build_embedding_request_payload(request)
|
||||
|
||||
for attempt in range(actual_max_retries + 1):
|
||||
try:
|
||||
print(f"Debug: Payload: {payload}")
|
||||
|
||||
response = self.session.post(
|
||||
f"{self.base_url}/embeddings", json=payload, timeout=actual_timeout
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
return self._parse_embedding_response(response_data)
|
||||
else:
|
||||
error_msg = f"embedding failed (try {attempt + 1}): HTTP {response.status_code}"
|
||||
if hasattr(response, "text"):
|
||||
error_msg += f" - {response.text}"
|
||||
print(f"Debug: {error_msg}")
|
||||
|
||||
if attempt < actual_max_retries:
|
||||
print(f"Debug: wait {actual_retry_delay} and retry...")
|
||||
time.sleep(actual_retry_delay)
|
||||
continue
|
||||
else:
|
||||
raise Exception(f"All retries failed: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Embedding request failed (try {attempt + 1}): {str(e)}, traceback: {format_exc()}."
|
||||
print(f"Debug: {error_msg}")
|
||||
|
||||
if attempt < actual_max_retries:
|
||||
print(f"Debug: wait {actual_retry_delay} and retry...")
|
||||
time.sleep(actual_retry_delay)
|
||||
continue
|
||||
else:
|
||||
raise Exception(f"All retries failed: {str(e)}, traceback: {format_exc()}.")
|
||||
|
||||
raise Exception("Unknown error")
|
||||
|
||||
def _build_embedding_request_payload(
|
||||
self, request: EmbeddingRequest
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
payload = {
|
||||
"input": request.input,
|
||||
"model": request.model,
|
||||
"encoding_format": request.encoding_format,
|
||||
}
|
||||
|
||||
if request.dimensions is not None:
|
||||
payload["dimensions"] = request.dimensions
|
||||
|
||||
if request.user is not None:
|
||||
payload["user"] = request.user
|
||||
|
||||
return payload
|
||||
|
||||
def _parse_embedding_response(
|
||||
self, response_data: Dict[str, Any]
|
||||
) -> EmbeddingResponse:
|
||||
embedding_data_list = []
|
||||
for data_item in response_data.get("data", []):
|
||||
embedding_data = EmbeddingData(
|
||||
object=data_item.get("object", "embedding"),
|
||||
embedding=data_item.get("embedding", []),
|
||||
index=data_item.get("index", 0),
|
||||
)
|
||||
embedding_data_list.append(embedding_data)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = EmbeddingUsage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
object=response_data.get("object", "list"),
|
||||
data=embedding_data_list,
|
||||
model=response_data.get("model", ""),
|
||||
usage=usage,
|
||||
)
|
||||
8
src/managers/llm_api/clients/openrouter/__init__.py
Normal file
8
src/managers/llm_api/clients/openrouter/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
OpenRouter Client
|
||||
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
|
||||
|
||||
__all__ = ["OpenRouterClient"]
|
||||
329
src/managers/llm_api/clients/openrouter/openrouter_client.py
Normal file
329
src/managers/llm_api/clients/openrouter/openrouter_client.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
OpenRouter Client
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from src.managers.log.logger import Logger
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class OpenRouterClient(BaseLLMAPI):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
app_name: Optional[str] = None,
|
||||
site_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize OpenRouter Client
|
||||
|
||||
Args:
|
||||
api_key: OpenRouter API key
|
||||
base_url: API base URL, default to OpenRouter official API
|
||||
app_name: Application name (optional, for statistics)
|
||||
site_url: Website URL (optional, for statistics)
|
||||
timeout: Request timeout
|
||||
max_retries: Maximum number of retries
|
||||
retry_delay: Retry delay
|
||||
**kwargs: Other configuration parameters
|
||||
"""
|
||||
self.app_name = app_name or "tokfinity-llm-client"
|
||||
self.site_url = site_url
|
||||
|
||||
if base_url is None:
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
"X-Title": self.app_name,
|
||||
}
|
||||
|
||||
if self.site_url:
|
||||
headers["HTTP-Referer"] = self.site_url
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
|
||||
return "/chat/completions"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
if request.frequency_penalty != 0.0:
|
||||
payload["frequency_penalty"] = request.frequency_penalty
|
||||
|
||||
if request.presence_penalty != 0.0:
|
||||
payload["presence_penalty"] = request.presence_penalty
|
||||
|
||||
if request.tools is not None:
|
||||
payload["tools"] = request.tools
|
||||
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = request.tool_choice
|
||||
|
||||
if request.response_format is not None:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
if request.seed is not None:
|
||||
payload["seed"] = request.seed
|
||||
|
||||
if request.user is not None:
|
||||
payload["user"] = request.user
|
||||
|
||||
return payload
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
"""
|
||||
Parse OpenRouter API response
|
||||
|
||||
Args:
|
||||
response_data: API response data
|
||||
|
||||
Returns:
|
||||
ChatCompletionResponse: Parsed response object
|
||||
"""
|
||||
choices = []
|
||||
for choice_data in response_data.get("choices", []):
|
||||
message_data = choice_data.get("message", {})
|
||||
|
||||
tool_calls = message_data.get("tool_calls")
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole(message_data.get("role", "assistant")),
|
||||
content=message_data.get("content", ""),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=choice_data.get("index", 0),
|
||||
message=message,
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
logprobs=choice_data.get("logprobs"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object=response_data.get("object", "chat.completion"),
|
||||
created=response_data.get("created", int(time.time())),
|
||||
model=response_data.get("model", ""),
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
system_fingerprint=response_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
"""
|
||||
Parse stream chunk
|
||||
Args:
|
||||
chunk_data: Raw chunk data
|
||||
|
||||
Returns:
|
||||
Optional[ChatCompletionChunk]: Parsed chunk
|
||||
"""
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("id", ""),
|
||||
object=chunk_data.get("object", "chat.completion.chunk"),
|
||||
created=chunk_data.get("created", int(time.time())),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=chunk_data.get("choices", []),
|
||||
system_fingerprint=chunk_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def list_models(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get available models
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model list response
|
||||
"""
|
||||
response = self.client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def alist_models(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Async get available models
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model list response
|
||||
"""
|
||||
response = await self.async_client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def get_generation_info(self, generation_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get specific generation request details
|
||||
|
||||
Args:
|
||||
generation_id: Generation request ID
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Generation information
|
||||
"""
|
||||
response = self.client.get(f"/generation?id={generation_id}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def aget_generation_info(self, generation_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Async get specific generation request details
|
||||
|
||||
Args:
|
||||
generation_id: Generation request ID
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Generation information
|
||||
"""
|
||||
response = await self.async_client.get(f"/generation?id={generation_id}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def get_account_credits(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get account credits
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Account credits information
|
||||
"""
|
||||
response = self.client.get("/auth/key")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def aget_account_credits(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Async get account credits
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Account credits information
|
||||
"""
|
||||
response = await self.async_client.get("/auth/key")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_popular_models() -> List[str]:
|
||||
"""
|
||||
Get popular models
|
||||
|
||||
Returns:
|
||||
List[str]: Popular model names
|
||||
"""
|
||||
return [
|
||||
# OpenAI
|
||||
"openai/gpt-4-turbo-preview",
|
||||
"openai/gpt-4",
|
||||
"openai/gpt-3.5-turbo",
|
||||
# Anthropic
|
||||
"anthropic/claude-3-opus",
|
||||
"anthropic/claude-3-sonnet",
|
||||
"anthropic/claude-3-haiku",
|
||||
# Google
|
||||
"google/gemini-pro",
|
||||
"google/gemini-pro-vision",
|
||||
# Meta
|
||||
"meta-llama/llama-2-70b-chat",
|
||||
"meta-llama/llama-2-13b-chat",
|
||||
# Mistral
|
||||
"mistralai/mixtral-8x7b-instruct",
|
||||
"mistralai/mistral-7b-instruct",
|
||||
# Open Source
|
||||
"microsoft/wizardlm-2-8x22b",
|
||||
"databricks/dbrx-instruct",
|
||||
"cohere/command-r-plus",
|
||||
]
|
||||
|
||||
def get_model_info(self, model_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get specific model details
|
||||
|
||||
Args:
|
||||
model_name: Model name
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model information
|
||||
"""
|
||||
models_response = self.list_models()
|
||||
models = models_response.get("data", [])
|
||||
|
||||
for model in models:
|
||||
if model.get("id") == model_name:
|
||||
return model
|
||||
|
||||
raise ValueError(f"Model {model_name} not found")
|
||||
|
||||
async def aget_model_info(self, model_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Async get specific model details
|
||||
|
||||
Args:
|
||||
model_name: Model name
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model information
|
||||
"""
|
||||
models_response = await self.alist_models()
|
||||
models = models_response.get("data", [])
|
||||
|
||||
for model in models:
|
||||
if model.get("id") == model_name:
|
||||
return model
|
||||
|
||||
raise ValueError(f"Model {model_name} not found")
|
||||
7
src/managers/llm_api/clients/private/__init__.py
Normal file
7
src/managers/llm_api/clients/private/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Private model client
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
|
||||
|
||||
__all__ = ["PrivateModelClient"]
|
||||
321
src/managers/llm_api/clients/private/private_client.py
Normal file
321
src/managers/llm_api/clients/private/private_client.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Private model client
|
||||
Implement of private model API based on the BaseLLMAPI
|
||||
vLLM, Text Generation Inference, Ollama supported
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class PrivateModelClient(BaseLLMAPI):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
deployment_type: str = "vllm",
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
supports_tools: bool = True,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.deployment_type = deployment_type.lower()
|
||||
self.custom_headers = custom_headers or {}
|
||||
self.supports_tools = supports_tools
|
||||
super().__init__(
|
||||
api_key=api_key or "EMPTY",
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
}
|
||||
|
||||
if self.api_key and self.api_key != "EMPTY":
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
headers.update(self.custom_headers)
|
||||
|
||||
|
||||
if self.deployment_type == "ollama":
|
||||
pass
|
||||
elif self.deployment_type == "tgi":
|
||||
pass
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
if self.deployment_type == "ollama":
|
||||
return "/api/chat"
|
||||
elif self.deployment_type == "tgi":
|
||||
return "/generate"
|
||||
else:
|
||||
return "/chat/completions"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
if self.deployment_type == "ollama":
|
||||
return self._build_ollama_payload(request)
|
||||
elif self.deployment_type == "tgi":
|
||||
return self._build_tgi_payload(request)
|
||||
else:
|
||||
return self._build_openai_payload(request)
|
||||
|
||||
def _build_openai_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
# Optional
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
if request.frequency_penalty != 0.0:
|
||||
payload["frequency_penalty"] = request.frequency_penalty
|
||||
|
||||
if request.presence_penalty != 0.0:
|
||||
payload["presence_penalty"] = request.presence_penalty
|
||||
|
||||
if self.supports_tools and request.tools is not None:
|
||||
payload["tools"] = request.tools
|
||||
|
||||
if self.supports_tools and request.tool_choice is not None:
|
||||
payload["tool_choice"] = request.tool_choice
|
||||
|
||||
if request.response_format is not None:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
return payload
|
||||
|
||||
def _build_ollama_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"stream": request.stream,
|
||||
"options": {
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
},
|
||||
}
|
||||
|
||||
if request.max_tokens is not None:
|
||||
payload["options"]["num_predict"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["options"]["stop"] = (
|
||||
request.stop if isinstance(request.stop, list) else [request.stop]
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def _build_tgi_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
prompt = self._messages_to_prompt(request.messages)
|
||||
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"do_sample": True,
|
||||
"stream": request.stream,
|
||||
},
|
||||
}
|
||||
|
||||
if request.max_tokens is not None:
|
||||
payload["parameters"]["max_new_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["parameters"]["stop_sequences"] = (
|
||||
request.stop if isinstance(request.stop, list) else [request.stop]
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def _messages_to_prompt(self, messages: List[ChatMessage]) -> str:
|
||||
prompt_parts = []
|
||||
|
||||
for message in messages:
|
||||
if message.role == MessageRole.SYSTEM:
|
||||
prompt_parts.append(f"System: {message.content}")
|
||||
elif message.role == MessageRole.USER:
|
||||
prompt_parts.append(f"User: {message.content}")
|
||||
elif message.role == MessageRole.ASSISTANT:
|
||||
prompt_parts.append(f"Assistant: {message.content}")
|
||||
|
||||
prompt_parts.append("Assistant:")
|
||||
return "\n\n".join(prompt_parts)
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
if self.deployment_type == "ollama":
|
||||
return self._parse_ollama_response(response_data)
|
||||
elif self.deployment_type == "tgi":
|
||||
return self._parse_tgi_response(response_data)
|
||||
else:
|
||||
return self._parse_openai_response(response_data)
|
||||
|
||||
def _parse_openai_response(
|
||||
self, response_data: Dict[str, Any]
|
||||
) -> ChatCompletionResponse:
|
||||
choices = []
|
||||
for choice_data in response_data.get("choices", []):
|
||||
message_data = choice_data.get("message", {})
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole(message_data.get("role", "assistant")),
|
||||
content=message_data.get("content", ""),
|
||||
tool_calls=message_data.get("tool_calls"),
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=choice_data.get("index", 0),
|
||||
message=message,
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object=response_data.get("object", "chat.completion"),
|
||||
created=response_data.get("created", int(time.time())),
|
||||
model=response_data.get("model", ""),
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _parse_ollama_response(
|
||||
self, response_data: Dict[str, Any]
|
||||
) -> ChatCompletionResponse:
|
||||
content = response_data.get("message", {}).get("content", "")
|
||||
|
||||
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
|
||||
|
||||
choice = Choice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason="stop" if response_data.get("done", False) else None,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=f"ollama-{int(time.time())}",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model=response_data.get("model", ""),
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
def _parse_tgi_response(
|
||||
self, response_data: Dict[str, Any]
|
||||
) -> ChatCompletionResponse:
|
||||
if isinstance(response_data, list) and len(response_data) > 0:
|
||||
response_data = response_data[0]
|
||||
|
||||
content = response_data.get("generated_text", "")
|
||||
|
||||
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
|
||||
|
||||
choice = Choice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason=response_data.get("finish_reason", "stop"),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=f"tgi-{int(time.time())}",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model="tgi-model",
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
|
||||
if self.deployment_type == "ollama":
|
||||
try:
|
||||
chunk_data = json.loads(line)
|
||||
return self._parse_ollama_chunk(chunk_data)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
else:
|
||||
# Standard SSE format
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
if data.strip() == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
chunk_data = json.loads(data)
|
||||
return self._parse_stream_chunk(chunk_data)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _parse_ollama_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
content = chunk_data.get("message", {}).get("content", "")
|
||||
done = chunk_data.get("done", False)
|
||||
|
||||
choices = [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": content} if content else {},
|
||||
"finish_reason": "stop" if done else None,
|
||||
}
|
||||
]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id=f"ollama-{int(time.time())}",
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("id", ""),
|
||||
object=chunk_data.get("object", "chat.completion.chunk"),
|
||||
created=chunk_data.get("created", int(time.time())),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=chunk_data.get("choices", []),
|
||||
)
|
||||
43
src/managers/log/__init__.py
Normal file
43
src/managers/log/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Log manage Module
|
||||
|
||||
Provide log management functionality categorized by timestamp and level
|
||||
|
||||
Usage example:
|
||||
from src.managers.log import Logger, create_logger
|
||||
|
||||
# Basic Usage
|
||||
logger = Logger("logs", "my_app")
|
||||
logger.info("This is an info")
|
||||
logger.error("This is an error")
|
||||
logger.close()
|
||||
|
||||
# Use the context manager
|
||||
with Logger("logs", "my_app") as logger:
|
||||
logger.info("Automatically manage resources")
|
||||
|
||||
# Convenient function
|
||||
logger = create_logger("logs", "my_app")
|
||||
logger.info("Convenient create")
|
||||
logger.close()
|
||||
"""
|
||||
|
||||
from .logger import (
|
||||
Logger,
|
||||
create_logger,
|
||||
init_global_logger,
|
||||
get_global_logger,
|
||||
set_global_logger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Logger",
|
||||
"create_logger",
|
||||
"init_global_logger",
|
||||
"get_global_logger",
|
||||
"set_global_logger",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Tokfinity Team"
|
||||
__description__ = "Timestamp:Directory and Level-Based Log Manager "
|
||||
357
src/managers/log/logger.py
Normal file
357
src/managers/log/logger.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
|
||||
Log Manager
|
||||
Create dir according to timestamp, store level-based log files
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
"""
|
||||
Module-level Self-defined NOTICE level log (between INFO and WARNING)
|
||||
"""
|
||||
NOTICE_LEVEL = 25
|
||||
if not hasattr(logging, "NOTICE"):
|
||||
logging.addLevelName(NOTICE_LEVEL, "NOTICE")
|
||||
|
||||
def notice(self, message, *args, **kwargs):
|
||||
if self.isEnabledFor(NOTICE_LEVEL):
|
||||
self._log(NOTICE_LEVEL, message, args, **kwargs)
|
||||
|
||||
logging.Logger.notice = notice # type: ignore[attr-defined]
|
||||
|
||||
class ImageBuilderLogger:
|
||||
"""
|
||||
ImageBuilder log manager
|
||||
|
||||
Function:
|
||||
- Create sub dir based on timestamp
|
||||
- Create debug.log, info.log and error.log
|
||||
- Provide standard log format, able to print logs two levels up and the code location
|
||||
- Console and file outputs
|
||||
"""
|
||||
|
||||
def __init__(self, log_base_path: str, console_output: bool = True):
|
||||
self.log_base_path = Path(log_base_path)
|
||||
self.console_output = console_output
|
||||
|
||||
self.log_dir = self._create_log_dir()
|
||||
|
||||
self.file_handlers = {}
|
||||
|
||||
self.logger = self._setup_logger()
|
||||
|
||||
def _create_log_dir(self) -> Path:
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M")
|
||||
log_dir = self.log_base_path / f"build_images/{timestamp}"
|
||||
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return log_dir
|
||||
|
||||
def _setup_logger(self) -> logging.Logger:
|
||||
logger = logging.getLogger("image_builder_logger")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
logger.handlers.clear()
|
||||
|
||||
log_format = (
|
||||
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
|
||||
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
|
||||
)
|
||||
formatter = logging.Formatter(log_format)
|
||||
|
||||
if self.console_output:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
self._add_file_handlers(logger, formatter)
|
||||
|
||||
return logger
|
||||
|
||||
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
|
||||
log_levels = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
}
|
||||
|
||||
for level_name, level_value in log_levels.items():
|
||||
log_file = self.log_dir / f"{level_name}.log"
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
if level_name == "debug":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
|
||||
elif level_name == "info":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
|
||||
elif level_name == "warning":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
|
||||
elif level_name == "error":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
self.file_handlers[level_name] = file_handler
|
||||
|
||||
all_log_file = self.log_dir / "all.log"
|
||||
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
|
||||
all_file_handler.setFormatter(formatter)
|
||||
all_file_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(all_file_handler)
|
||||
self.file_handlers["all"] = all_file_handler
|
||||
|
||||
def debug(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.debug(message, *args, **kwargs)
|
||||
|
||||
def info(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.info(message, *args, **kwargs)
|
||||
|
||||
def warning(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.warning(message, *args, **kwargs)
|
||||
|
||||
def error(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.error(message, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def log_file(self) -> str:
|
||||
return str(self.log_dir / "all.log")
|
||||
|
||||
def get_log_dir(self) -> str:
|
||||
return str(self.log_dir.absolute())
|
||||
|
||||
def get_log_files(self) -> dict:
|
||||
return {
|
||||
"debug": str(self.log_dir / "debug.log"),
|
||||
"info": str(self.log_dir / "info.log"),
|
||||
"warning": str(self.log_dir / "warning.log"),
|
||||
"error": str(self.log_dir / "error.log"),
|
||||
"all": str(self.log_dir / "all.log"),
|
||||
}
|
||||
|
||||
def close(self):
|
||||
for handler in self.file_handlers.values():
|
||||
handler.close()
|
||||
|
||||
for handler in self.logger.handlers[:]:
|
||||
self.logger.removeHandler(handler)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ImageBuilderLogger(dir={self.log_dir})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"ImageBuilderLogger("
|
||||
f"base_path='{self.log_base_path}', "
|
||||
f"log_dir='{self.log_dir}', "
|
||||
f"console_output={self.console_output})"
|
||||
)
|
||||
|
||||
class Logger:
|
||||
"""
|
||||
Log manager
|
||||
|
||||
Function:
|
||||
- Create sub dir based on timestamp
|
||||
- Create debug.log, info.log, notice.log, warning.log and error.log
|
||||
- Provide standard log format
|
||||
- Console and file outputs
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_base_path: str,
|
||||
logger_name: str = "tokfinity_logger",
|
||||
console_output: bool = True,
|
||||
log_format: Optional[str] = None,
|
||||
instance_id: Optional[str] = None,
|
||||
):
|
||||
self.log_base_path = Path(log_base_path)
|
||||
self.logger_name = logger_name
|
||||
self.console_output = console_output
|
||||
self.instance_id = instance_id
|
||||
|
||||
self.log_format = log_format or (
|
||||
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
|
||||
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
|
||||
)
|
||||
|
||||
self.log_dir = self._create_log_dir()
|
||||
|
||||
self.file_handlers = {}
|
||||
|
||||
self.logger = self._setup_logger()
|
||||
|
||||
def _create_log_dir(self) -> Path:
|
||||
if self.instance_id:
|
||||
log_dir = self.log_base_path / self.instance_id
|
||||
else:
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M")
|
||||
log_dir = self.log_base_path / timestamp
|
||||
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return log_dir
|
||||
|
||||
def _setup_logger(self) -> logging.Logger:
|
||||
logger = logging.getLogger(self.logger_name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
logger.handlers.clear()
|
||||
|
||||
formatter = logging.Formatter(self.log_format)
|
||||
|
||||
if self.console_output:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
self._add_file_handlers(logger, formatter)
|
||||
|
||||
return logger
|
||||
|
||||
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
|
||||
log_levels = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"notice": NOTICE_LEVEL,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
}
|
||||
|
||||
for level_name, level_value in log_levels.items():
|
||||
log_file = self.log_dir / f"{level_name}.log"
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
if level_name == "debug":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
|
||||
elif level_name == "info":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
|
||||
elif level_name == "notice":
|
||||
file_handler.addFilter(lambda record: record.levelno == NOTICE_LEVEL)
|
||||
elif level_name == "warning":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
|
||||
elif level_name == "error":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
self.file_handlers[level_name] = file_handler
|
||||
|
||||
all_log_file = self.log_dir / "all.log"
|
||||
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
|
||||
all_file_handler.setFormatter(formatter)
|
||||
all_file_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(all_file_handler)
|
||||
self.file_handlers["all"] = all_file_handler
|
||||
|
||||
def debug(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.debug(message, *args, **kwargs)
|
||||
|
||||
def info(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.info(message, *args, **kwargs)
|
||||
|
||||
def notice(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.log(NOTICE_LEVEL, message, *args, **kwargs)
|
||||
|
||||
def warning(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.warning(message, *args, **kwargs)
|
||||
|
||||
def error(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.error(message, *args, **kwargs)
|
||||
|
||||
def get_log_dir(self) -> str:
|
||||
return str(self.log_dir.absolute())
|
||||
|
||||
def get_log_files(self) -> dict:
|
||||
return {
|
||||
"debug": str(self.log_dir / "debug.log"),
|
||||
"info": str(self.log_dir / "info.log"),
|
||||
"notice": str(self.log_dir / "notice.log"),
|
||||
"warning": str(self.log_dir / "warning.log"),
|
||||
"error": str(self.log_dir / "error.log"),
|
||||
"all": str(self.log_dir / "all.log"),
|
||||
}
|
||||
|
||||
@property
|
||||
def log_file(self) -> str:
|
||||
return str(self.log_dir / "all.log")
|
||||
|
||||
def close(self):
|
||||
for handler in self.file_handlers.values():
|
||||
handler.close()
|
||||
for handler in self.logger.handlers[:]:
|
||||
self.logger.removeHandler(handler)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Logger(name={self.logger_name}, dir={self.log_dir})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Logger("
|
||||
f"name='{self.logger_name}', "
|
||||
f"base_path='{self.log_base_path}', "
|
||||
f"log_dir='{self.log_dir}', "
|
||||
f"console_output={self.console_output})"
|
||||
)
|
||||
|
||||
|
||||
def create_logger(
|
||||
log_base_path: str,
|
||||
logger_name: str = "tokfinity_logger",
|
||||
console_output: bool = True,
|
||||
) -> Logger:
|
||||
return Logger(
|
||||
log_base_path=log_base_path,
|
||||
logger_name=logger_name,
|
||||
console_output=console_output,
|
||||
)
|
||||
|
||||
|
||||
# Global log manager instance (optional)
|
||||
_global_logger: Optional[Logger] = None
|
||||
|
||||
|
||||
def get_global_logger() -> Optional[Logger]:
|
||||
return _global_logger
|
||||
|
||||
|
||||
def set_global_logger(logger: Logger):
|
||||
global _global_logger
|
||||
_global_logger = logger
|
||||
|
||||
|
||||
def init_global_logger(
|
||||
log_base_path: str, logger_name: str = "global_logger"
|
||||
) -> Logger:
|
||||
global _global_logger
|
||||
_global_logger = Logger(log_base_path, logger_name)
|
||||
return _global_logger
|
||||
275
src/managers/loop/base.py
Normal file
275
src/managers/loop/base.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
from src.managers.llm_api.api_manager import LLMAPIManager
|
||||
from src.managers.prompts.prompts_manager import PromptsManager
|
||||
from src.tools.base import (
|
||||
ToolExecutor,
|
||||
BASH_TOOL_NAME,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
from src.managers.loop.types import ToolStats, LLMUsage
|
||||
|
||||
|
||||
|
||||
class BaseLoop:
|
||||
def __init__(self, instance_id: str, instance_data: Dict[str, Any], logger: Logger, prompts_manager: PromptsManager | None, llm_manager: LLMAPIManager | None, tool_executor: ToolExecutor, config: Dict[str, Any] | None = None):
|
||||
self.instance_id = instance_id
|
||||
self.instance_data = instance_data
|
||||
self.logger = logger
|
||||
self.prompts_manager = prompts_manager
|
||||
self.llm_manager = llm_manager
|
||||
self.tool_executor = tool_executor
|
||||
self.config = config or {}
|
||||
self.component_name = self.__class__.__name__
|
||||
|
||||
|
||||
def _make_assistant(
|
||||
self, content: str | None, tool_calls: Any, messages: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Construct an assistant message based on the current content and tool calls, and append it to the messages.
|
||||
"""
|
||||
safe_content = content or ""
|
||||
if not safe_content and not tool_calls:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] Assistant returned an empty message with no tool calls; skipping this message and prompting to continue"
|
||||
)
|
||||
messages.append(
|
||||
{"role": "user", "content": "请继续分析问题并使用工具来解决问题。"}
|
||||
)
|
||||
return False
|
||||
assistant_message: Dict[str, Any] = {"role": "assistant"}
|
||||
if tool_calls and not safe_content:
|
||||
assistant_message["content"] = ""
|
||||
elif safe_content:
|
||||
assistant_message["content"] = safe_content
|
||||
if tool_calls:
|
||||
assistant_message["tool_calls"] = tool_calls
|
||||
messages.append(assistant_message)
|
||||
return True
|
||||
|
||||
def _make_tool_response(
|
||||
self, tool_results: List[Any], messages: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Convert tool execution results into standard tool messages (role=tool) and append them to the messages.
|
||||
|
||||
- Generate content per result: use prompts_manager.tool_response_prompts([{...}]) to produce the content
|
||||
- Set tool_call_id: prefer ToolResult.id; fallback to ToolResult.call_id
|
||||
"""
|
||||
if not tool_results:
|
||||
return
|
||||
for result in tool_results:
|
||||
single_dict = [
|
||||
{
|
||||
"name": getattr(result, "name", "unknown"),
|
||||
"success": getattr(result, "success", False),
|
||||
"result": getattr(result, "result", None) or "",
|
||||
"error": getattr(result, "error", None) or "",
|
||||
}
|
||||
]
|
||||
content_text = (
|
||||
self.prompts_manager.tool_response_prompts(single_dict)
|
||||
if self.prompts_manager
|
||||
else ""
|
||||
)
|
||||
tool_call_id = getattr(result, "id", None) or getattr(
|
||||
result, "call_id", None
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": content_text,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
def _response_log(
|
||||
self, response: Any, first_content: str, first_tool_calls: Any, total_turns: int
|
||||
) -> None:
|
||||
"""notice log for the current turn's LLM output"""
|
||||
try:
|
||||
response_log: Dict[str, Any] = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
response_log["usage"] = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", None),
|
||||
"completion_tokens": getattr(response.usage, "completion_tokens", None),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", None),
|
||||
}
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
response_log["choice"] = {
|
||||
"message": {
|
||||
"content": first_content,
|
||||
"tool_calls": first_tool_calls,
|
||||
}
|
||||
}
|
||||
if response_log:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] The {total_turns}th turn output: {json.dumps(response_log, ensure_ascii=False)}"
|
||||
)
|
||||
else:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] The {total_turns}th turn output: {str(response)}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] 第 {total_turns} 轮: LLM 输出序列化失败,使用字符串表示: {str(response)}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
def _debug_messages(
|
||||
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
|
||||
) -> None:
|
||||
"""debug log for the messages to be sent to the model"""
|
||||
try:
|
||||
self.logger.debug(f"[{self.component_name}] msg:")
|
||||
recent_messages = messages[-2:] if len(messages) > 2 else messages
|
||||
base_index = len(messages) - len(recent_messages)
|
||||
for offset, msg in enumerate[Dict[str, Any]](recent_messages):
|
||||
idx = base_index + offset
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
content_str = content if isinstance(content, str) else ""
|
||||
preview = content_str[:prefix_len]
|
||||
content_len = len(content_str)
|
||||
extra = ""
|
||||
if role == "assistant":
|
||||
tool_calls = msg.get("tool_calls")
|
||||
has_tool = tool_calls is not None and tool_calls != []
|
||||
try:
|
||||
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] In debug_messages function, fail: {format_exc()}, tool calls: {tool_calls}."
|
||||
)
|
||||
tool_calls_json = str(tool_calls)
|
||||
extra = f", has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
|
||||
elif role == "tool":
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
extra = f", tool_call_id={tool_call_id}"
|
||||
self.logger.debug(
|
||||
f"[{self.component_name}] {turn+1}th, msg#{idx}: role={role}, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}{extra}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] In debug_messages function, fail msg: {format_exc()}."
|
||||
)
|
||||
|
||||
def _debug_last_message(
|
||||
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
|
||||
) -> None:
|
||||
"""debug last turn msg"""
|
||||
try:
|
||||
if not messages:
|
||||
return
|
||||
last_assistant_idx = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if messages[i].get("role") == "assistant":
|
||||
last_assistant_idx = i
|
||||
break
|
||||
if last_assistant_idx is None:
|
||||
return
|
||||
msg = messages[last_assistant_idx]
|
||||
content = msg.get("content")
|
||||
content_str = content if isinstance(content, str) else ""
|
||||
preview = content_str[:prefix_len]
|
||||
content_len = len(content_str)
|
||||
tool_calls = msg.get("tool_calls")
|
||||
has_tool = tool_calls is not None and tool_calls != []
|
||||
try:
|
||||
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] In debug_last_message function, fail: {format_exc()}, tool calls: {tool_calls}."
|
||||
)
|
||||
tool_calls_json = str(tool_calls)
|
||||
self.logger.debug(
|
||||
f"[{self.component_name}] {turn+1}th turn, output_preview: role=assistant, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}, has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] In debug_last_message function, last turn fail: {format_exc()}."
|
||||
)
|
||||
|
||||
def _debug_tools(self, tools: List[Dict[str, Any]]) -> None:
|
||||
"""debug tools msg"""
|
||||
try:
|
||||
self.logger.debug(f"[{self.component_name}] tools num: {len(tools)}")
|
||||
for i, tool in enumerate(tools):
|
||||
try:
|
||||
tool_json = json.dumps(tool, ensure_ascii=False)
|
||||
self.logger.debug(f"[{self.component_name}] tool #{i+1}: {tool_json}")
|
||||
except Exception:
|
||||
self.logger.debug(
|
||||
f"[{self.component_name}] tool #{i+1} fail: {format_exc()}, string: {str(tool)}."
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] fail; traceback: {format_exc()}."
|
||||
)
|
||||
self.logger.warning(f"[{self.component_name}] tools string: {str(tools)}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_tools(self) -> List[Dict[str, Any]]:
|
||||
pass
|
||||
|
||||
def _is_bash_tool(self, tool_name: str) -> bool:
|
||||
return BASH_TOOL_NAME in tool_name
|
||||
|
||||
def _is_edit_tool(self, tool_name: str) -> bool:
|
||||
return "edit" in tool_name or "str_replace" in tool_name or STR_REPLACE_BASED_EDIT_TOOL_NAME in tool_name
|
||||
|
||||
def _is_search_tool(self, tool_name: str) -> bool:
|
||||
return SEARCH_TOOL_NAME in tool_name or "search" in tool_name
|
||||
|
||||
def _is_submit_result_tool(self, tool_name: str) -> bool:
|
||||
return SUBMIT_RESULT_TOOL_NAME in tool_name
|
||||
|
||||
def _update_usage(self, response: Any, usage_stats: LLMUsage) -> None:
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage_stats.prompt_tokens += int(getattr(response.usage, "prompt_tokens", 0) or 0)
|
||||
usage_stats.completion_tokens += int(
|
||||
getattr(response.usage, "completion_tokens", 0) or 0
|
||||
)
|
||||
usage_stats.total_tokens += int(getattr(response.usage, "total_tokens", 0) or 0)
|
||||
|
||||
def _init_usage_stats(self) -> LLMUsage:
|
||||
return LLMUsage()
|
||||
|
||||
def _init_tools_stats(self) -> ToolStats:
|
||||
return ToolStats()
|
||||
|
||||
def _update_tool_call_statistic(
|
||||
self, tool_results: List[Any], tool_stats: ToolStats
|
||||
) -> None:
|
||||
for result in tool_results:
|
||||
try:
|
||||
tool_name = getattr(result, "name", "")
|
||||
tool_name = tool_name.lower() if isinstance(tool_name, str) else ""
|
||||
success = bool(getattr(result, "success", False))
|
||||
|
||||
if self._is_bash_tool(tool_name):
|
||||
tool_stats.bash["count"] += 1
|
||||
if not success:
|
||||
tool_stats.bash["failed"] += 1
|
||||
|
||||
elif self._is_edit_tool(tool_name):
|
||||
tool_stats.edit["count"] += 1
|
||||
if not success:
|
||||
tool_stats.edit["failed"] += 1
|
||||
|
||||
elif self._is_search_tool(tool_name):
|
||||
tool_stats.search["count"] += 1
|
||||
if not success:
|
||||
tool_stats.search["failed"] += 1
|
||||
|
||||
elif self._is_submit_result_tool(tool_name):
|
||||
tool_stats.submit_result["count"] += 1
|
||||
if not success:
|
||||
tool_stats.submit_result["failed"] += 1
|
||||
except Exception:
|
||||
continue
|
||||
339
src/managers/loop/patch_generator.py
Normal file
339
src/managers/loop/patch_generator.py
Normal file
@@ -0,0 +1,339 @@
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
from src.managers.llm_api.api_manager import LLMAPIManager
|
||||
from src.managers.prompts.prompts_manager import PromptsManager
|
||||
from src.managers.loop.base import BaseLoop
|
||||
from src.tools.base import (
|
||||
ToolExecutor,
|
||||
ToolResult,
|
||||
SubmitToolResult,
|
||||
BASH_TOOL_NAME,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
|
||||
|
||||
class PatchGenerator(BaseLoop):
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
instance_data: Dict[str, Any],
|
||||
logger: Logger,
|
||||
prompts_manager: PromptsManager | None,
|
||||
llm_manager: LLMAPIManager | None,
|
||||
tool_executor: ToolExecutor,
|
||||
config: Dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
|
||||
|
||||
|
||||
async def _submit_all_tool_calls(
|
||||
self, other_tool_calls: List[Dict[str, Any]]
|
||||
) -> List[Any]:
|
||||
"""execute tool calls, return tool execution results list"""
|
||||
if not other_tool_calls:
|
||||
return []
|
||||
from src.tools.base import ToolCall
|
||||
|
||||
tool_call_objects = []
|
||||
for tool_call_dict in other_tool_calls:
|
||||
raw_args = tool_call_dict.get("function", {}).get("arguments", {})
|
||||
parsed_args = raw_args
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed_args = json.loads(raw_args)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"[{self.component_name}] In _submit_all_tool_calls function, fail: {e}, traceback: {format_exc()}, args: {raw_args}.")
|
||||
parsed_args = {}
|
||||
tool_call_obj = ToolCall(
|
||||
name=tool_call_dict.get("function", {}).get("name", ""),
|
||||
call_id=tool_call_dict.get("id", ""),
|
||||
arguments=parsed_args,
|
||||
id=tool_call_dict.get("id", ""),
|
||||
)
|
||||
tool_call_objects.append(tool_call_obj)
|
||||
return await self.tool_executor.container_sequential_tool_call(
|
||||
tool_call_objects
|
||||
)
|
||||
|
||||
def _process_submit_result_tool_result(
|
||||
self,
|
||||
submit_result: ToolResult,
|
||||
golden_patch: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""process submit_result tool call, fill golden_patch and log"""
|
||||
if not submit_result.success or not submit_result.result:
|
||||
self.logger.warning(f"[{self.component_name}] submit_result failed and no result.")
|
||||
return
|
||||
|
||||
try:
|
||||
submit_tool_result = SubmitToolResult.from_string(submit_result.result)
|
||||
|
||||
if submit_tool_result.output:
|
||||
patch_info = {
|
||||
"patch_content": submit_tool_result.output,
|
||||
"test_status": submit_tool_result.test_status,
|
||||
"reasoning": submit_tool_result.reasoning,
|
||||
}
|
||||
golden_patch.clear()
|
||||
golden_patch.append(patch_info)
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] patch len: {len(submit_tool_result.output)}."
|
||||
)
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] test status: {submit_tool_result.test_status}."
|
||||
)
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] reasoning: {submit_tool_result.reasoning[:100]}..."
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] submit_result success but no patch content."
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"[{self.component_name}] parse submit_result result fail: {e}, traceback: {format_exc()}.")
|
||||
|
||||
def _get_tools(self) -> List[Dict[str, Any]]:
|
||||
tools = []
|
||||
#use_openai_format = self._should_use_openai_format()
|
||||
use_openai_format = True
|
||||
|
||||
for tool in self.tool_executor.tools.values():
|
||||
if use_openai_format:
|
||||
tool_def = tool._definition_for_openai_fmt()
|
||||
else:
|
||||
tool_def = tool._definition_for_claude_fmt()
|
||||
tools.append(tool_def)
|
||||
|
||||
return tools
|
||||
|
||||
def _should_use_openai_format(self) -> bool:
|
||||
|
||||
if not self.llm_manager or not hasattr(self.llm_manager, "get_model_name"):
|
||||
return True # openAI format by default
|
||||
|
||||
model_name = self.llm_manager.get_model_name().lower()
|
||||
|
||||
return "claude" not in model_name
|
||||
|
||||
def _get_issue_prompt(self) -> str:
|
||||
"""generate issue prompt based on instance data"""
|
||||
if not self.prompts_manager:
|
||||
self.logger.warning("PromptsManager not initialized, cannot generate issue prompt.")
|
||||
return ""
|
||||
|
||||
#instance_id = self.instance_data.get("instance_id", "")
|
||||
#repo = self.instance_data.get("repo", "")
|
||||
created_at = self.instance_data.get("created_at", "")
|
||||
base_commit = self.instance_data.get("base_commit", "")
|
||||
environment_setup_commit = self.instance_data.get(
|
||||
"environment_setup_commit", ""
|
||||
)
|
||||
version = self.instance_data.get("version", "")
|
||||
problem_statement = self.instance_data.get("problem_statement", "")
|
||||
difficulty = self.instance_data.get("difficulty", "")
|
||||
|
||||
return self.prompts_manager.format_issue_prompt(
|
||||
created_at=created_at,
|
||||
base_commit=base_commit,
|
||||
environment_setup_commit=environment_setup_commit,
|
||||
version=version,
|
||||
problem_statement=problem_statement,
|
||||
difficulty=difficulty,
|
||||
)
|
||||
|
||||
async def _generate_patch(self) -> Dict[str, Any] | None:
|
||||
"""main loop logic for generating candidate patch"""
|
||||
|
||||
usage_stats = self._init_usage_stats()
|
||||
tool_stats = self._init_tools_stats()
|
||||
|
||||
if not self.llm_manager or not self.prompts_manager:
|
||||
self.logger.error(f"[{self.component_name}] LLM manager or prompts manager not initialized.")
|
||||
return {
|
||||
"success": False,
|
||||
"golden_patch": [],
|
||||
"llm_usage": usage_stats.to_dict(),
|
||||
"tool_stats": tool_stats.to_dict(),
|
||||
"total_turns": 0,
|
||||
}
|
||||
|
||||
tools = self._get_tools()
|
||||
|
||||
self._debug_tools(tools)
|
||||
|
||||
root_path = self.config.get("builder", {}).get("repo_root_path", "")
|
||||
max_turn = (
|
||||
self.config.get("runner", {}).get("generator_loop", {}).get("max_turn", 10)
|
||||
)
|
||||
temperature = (
|
||||
self.config.get("runner", {})
|
||||
.get("generator_loop", {})
|
||||
.get("temperature", 0.2)
|
||||
)
|
||||
|
||||
issue_prompt = self._get_issue_prompt()
|
||||
user_prompt = self.prompts_manager.get_generator_user(root_path, issue_prompt)
|
||||
system_prompt = self.prompts_manager.get_generator_system(root_path)
|
||||
|
||||
|
||||
total_turns = 0
|
||||
golden_patch = []
|
||||
|
||||
try:
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] {self.instance_id}: start generating candidate patch, max turn: {max_turn}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
|
||||
)
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
for turn in range(max_turn):
|
||||
total_turns = turn + 1
|
||||
self.logger.info(f"[{self.component_name}] The {total_turns}th turn started.")
|
||||
|
||||
try:
|
||||
current_input_msg = messages[-1] if messages else None
|
||||
if current_input_msg is not None:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] The {total_turns}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] {total_turns}th turn: LLM input fail: {messages[-1] if messages else None}, error: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
self._debug_messages(turn, messages)
|
||||
|
||||
response = self.llm_manager.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
first_content: str = ""
|
||||
first_tool_calls: Any = None
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
ch0 = response.choices[0]
|
||||
first_content = (
|
||||
getattr(getattr(ch0, "message", None), "content", None) or ""
|
||||
)
|
||||
first_tool_calls = getattr(
|
||||
getattr(ch0, "message", None), "tool_calls", None
|
||||
)
|
||||
|
||||
self._response_log(
|
||||
response, first_content, first_tool_calls, total_turns
|
||||
)
|
||||
|
||||
self._update_usage(response, usage_stats)
|
||||
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
content = first_content
|
||||
tool_calls = first_tool_calls
|
||||
|
||||
if not self._make_assistant(content, tool_calls, messages):
|
||||
continue
|
||||
|
||||
if tool_calls:
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] {total_turns}th turn: call {len(tool_calls)} tools."
|
||||
)
|
||||
|
||||
tool_results = await self._submit_all_tool_calls(tool_calls)
|
||||
|
||||
self._update_tool_call_statistic(tool_results, tool_stats)
|
||||
|
||||
if tool_results:
|
||||
submit_result = None
|
||||
other_tool_results = []
|
||||
|
||||
for tool_result in tool_results:
|
||||
tool_name = getattr(tool_result, "name", "")
|
||||
if tool_name == SUBMIT_RESULT_TOOL_NAME:
|
||||
submit_result = tool_result
|
||||
else:
|
||||
other_tool_results.append(tool_result)
|
||||
|
||||
if submit_result:
|
||||
self.logger.debug(
|
||||
f"[{self.component_name}] {total_turns}th turn: got submit_result tool call."
|
||||
)
|
||||
self.logger.debug(f"[{self.component_name}] {total_turns}th turn: submit_result result: {submit_result}")
|
||||
|
||||
self._process_submit_result_tool_result(
|
||||
submit_result, golden_patch
|
||||
)
|
||||
|
||||
self._debug_last_message(turn, messages)
|
||||
break
|
||||
|
||||
if other_tool_results:
|
||||
self._make_tool_response(other_tool_results, messages)
|
||||
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "请继续分析问题并使用工具来解决问题。",
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.debug(f"[{self.component_name}] final golden_patch: {golden_patch}")
|
||||
|
||||
success = (
|
||||
len(golden_patch) > 0 and golden_patch[0].get("patch_content", "") != ""
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] status={success}, total_turns={total_turns}, tools_stats={tool_stats}"
|
||||
)
|
||||
|
||||
result_payload = {
|
||||
"success": success,
|
||||
"golden_patch": golden_patch,
|
||||
"llm_usage": usage_stats.to_dict(),
|
||||
"tool_stats": tool_stats.to_dict(),
|
||||
"total_turns": total_turns,
|
||||
}
|
||||
try:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] final output: {json.dumps(result_payload, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] output: {str(result_payload)}, error: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
return result_payload
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"[{self.component_name}] fail: {e}, traceback: {format_exc()}.")
|
||||
result_payload = {
|
||||
"success": False,
|
||||
"golden_patch": [],
|
||||
"llm_usage": usage_stats.to_dict(),
|
||||
"tool_stats": tool_stats.to_dict(),
|
||||
"total_turns": total_turns,
|
||||
}
|
||||
try:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] 最终返回数据(失败): {json.dumps(result_payload, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] 最终返回数据(失败, 字符串回退): {str(result_payload)}, error: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
return result_payload
|
||||
338
src/managers/loop/patch_selector.py
Normal file
338
src/managers/loop/patch_selector.py
Normal file
@@ -0,0 +1,338 @@
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
from src.managers.llm_api.api_manager import LLMAPIManager
|
||||
from src.managers.prompts.prompts_manager import PromptsManager
|
||||
from src.managers.loop.types import GeneratorResult, SelectorResult, LLMUsage, ToolStats, PatchInfo
|
||||
from src.tools.base import ToolExecutor, ToolCall, ToolResult
|
||||
from src.managers.loop.base import BaseLoop
|
||||
|
||||
SELECTOR_SUBMIT_TOOL_NAME = "submit_result"
|
||||
|
||||
class PatchSelector(BaseLoop):
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
instance_data: Dict[str, Any],
|
||||
logger: Logger,
|
||||
prompts_manager: PromptsManager | None,
|
||||
llm_manager: LLMAPIManager | None,
|
||||
tool_executor: ToolExecutor,
|
||||
config: Dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
|
||||
|
||||
def _get_submit_result_tool_name(self):
|
||||
return SELECTOR_SUBMIT_TOOL_NAME
|
||||
|
||||
def _definition_for_submit_tool(self, use_openai_format: bool) -> Dict[str, Any]:
|
||||
"""submit_result tool"""
|
||||
if use_openai_format:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self._get_submit_result_tool_name(),
|
||||
"description": "Submit the final selected patch index and reasoning.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index": {
|
||||
"type": "integer",
|
||||
"description": "The chosen patch index (0-based).",
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Detailed reasoning for the selection.",
|
||||
},
|
||||
},
|
||||
"required": ["index", "reason"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self._get_submit_result_tool_name(),
|
||||
"description": "Submit the final selected patch index and reasoning.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index": {
|
||||
"type": "integer",
|
||||
"description": "The chosen patch index (0-based).",
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Detailed reasoning for the selection.",
|
||||
},
|
||||
},
|
||||
"required": ["index", "reason"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _build_user_prompt(self, candidates: List[GeneratorResult], root_path: str) -> str:
|
||||
|
||||
if not self.prompts_manager:
|
||||
return ""
|
||||
return self.prompts_manager.get_selector_user(self.instance_data, candidates, root_path)
|
||||
|
||||
def _get_system_prompt(self, patches_count: int, root_path: str) -> str:
|
||||
if not self.prompts_manager:
|
||||
return ""
|
||||
return self.prompts_manager.get_selector_system(patches_count, root_path)
|
||||
|
||||
def _get_tools(self) -> List[Dict[str, Any]]:
|
||||
|
||||
tool_defs: List[Dict[str, Any]] = []
|
||||
try:
|
||||
for tool in self.tool_executor.tools.values():
|
||||
try:
|
||||
tool_defs.append(tool._definition_for_openai_fmt())
|
||||
except Exception:
|
||||
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tool_defs.append(self._definition_for_submit_tool(True))
|
||||
return tool_defs
|
||||
|
||||
def _extract_submit_choice(self, tool_call: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
|
||||
if not tool_call:
|
||||
return None
|
||||
fn = tool_call.get("function", {})
|
||||
if fn.get("name") != self._get_submit_result_tool_name():
|
||||
return None
|
||||
raw_args = fn.get("arguments", {})
|
||||
try:
|
||||
args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
|
||||
except Exception:
|
||||
args = {}
|
||||
index = args.get("index")
|
||||
reason = args.get("reason")
|
||||
if isinstance(index, int) and index >= 0:
|
||||
return {"index": index, "reason": reason or ""}
|
||||
return None
|
||||
|
||||
async def _submit_other_tool_calls(
|
||||
self, tool_calls: List[Dict[str, Any]]
|
||||
) -> List[ToolResult]:
|
||||
|
||||
if not tool_calls:
|
||||
return []
|
||||
to_run: List[ToolCall] = []
|
||||
for tool_call_dict in tool_calls:
|
||||
fn = tool_call_dict.get("function", {})
|
||||
name = fn.get("name", "")
|
||||
if name == SELECTOR_SUBMIT_TOOL_NAME:
|
||||
continue
|
||||
raw_args = fn.get("arguments", {})
|
||||
parsed_args = raw_args
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed_args = json.loads(raw_args)
|
||||
except Exception:
|
||||
parsed_args = {}
|
||||
to_run.append(
|
||||
ToolCall(
|
||||
name=name,
|
||||
call_id=tool_call_dict.get("id", ""),
|
||||
arguments=parsed_args,
|
||||
id=tool_call_dict.get("id", ""),
|
||||
)
|
||||
)
|
||||
if not to_run:
|
||||
return []
|
||||
results: List[ToolResult] = await self.tool_executor.container_sequential_tool_call(to_run)
|
||||
return results
|
||||
|
||||
async def _select_patch(self, candidates: List[GeneratorResult]) -> SelectorResult:
|
||||
|
||||
if not candidates:
|
||||
raise ValueError("No candidates provided")
|
||||
if not self.llm_manager:
|
||||
raise ValueError("LLM manager is not initialized")
|
||||
|
||||
|
||||
tools = self._get_tools()
|
||||
|
||||
self._debug_tools(tools)
|
||||
|
||||
root_path = self.config.get("builder", {}).get("repo_root_path", "")
|
||||
system_prompt = self._get_system_prompt(len(candidates), root_path)
|
||||
user_prompt = self._build_user_prompt(candidates, root_path)
|
||||
|
||||
messages: List[Dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
|
||||
try:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
|
||||
)
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] Initial fail in selector loop: SP={str(messages[0])}, UP={str(messages[1])}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
max_turn = int(
|
||||
self.config.get("runner", {})
|
||||
.get("selector_loop", {})
|
||||
.get("max_turn", 200)
|
||||
)
|
||||
temperature = (
|
||||
self.config.get("runner", {})
|
||||
.get("selector_loop", {})
|
||||
.get("temperature", 0.2)
|
||||
)
|
||||
|
||||
usage_stats = self._init_usage_stats()
|
||||
tool_stats = self._init_tools_stats()
|
||||
|
||||
total_turns = 0
|
||||
|
||||
chosen_index: int | None = None
|
||||
select_reason: str = ""
|
||||
for turn in range(max_turn):
|
||||
try:
|
||||
try:
|
||||
current_input_msg = messages[-1] if messages else None
|
||||
if current_input_msg is not None:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] The {turn+1}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] {turn+1}th turn fail: {messages[-1] if messages else None}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
self._debug_messages(turn, messages)
|
||||
|
||||
response = self.llm_manager.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
first_tool_calls = None
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
ch0 = response.choices[0]
|
||||
first_tool_calls = getattr(getattr(ch0, "message", None), "tool_calls", None)
|
||||
first_content = getattr(getattr(ch0, "message", None), "content", None) or ""
|
||||
else:
|
||||
first_content = ""
|
||||
|
||||
total_turns = turn + 1
|
||||
|
||||
self._response_log(response, first_content, first_tool_calls, turn + 1)
|
||||
|
||||
self._update_usage(response, usage_stats)
|
||||
|
||||
|
||||
if first_tool_calls:
|
||||
|
||||
if not self._make_assistant(first_content, first_tool_calls, messages):
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
|
||||
}
|
||||
)
|
||||
continue
|
||||
submit_found = False
|
||||
for tc in first_tool_calls:
|
||||
choice = self._extract_submit_choice(tc)
|
||||
if choice is not None:
|
||||
chosen_index = choice["index"]
|
||||
reason = choice.get("reason", "")
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] choose: index={chosen_index}, reason={reason}"
|
||||
)
|
||||
select_reason = reason or ""
|
||||
submit_found = True
|
||||
|
||||
self._debug_last_message(turn, messages)
|
||||
break
|
||||
|
||||
|
||||
if not submit_found:
|
||||
|
||||
results = await self._submit_other_tool_calls(first_tool_calls)
|
||||
|
||||
self._make_tool_response(results, messages)
|
||||
|
||||
self._update_tool_call_statistic(results, tool_stats)
|
||||
|
||||
else:
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
|
||||
}
|
||||
)
|
||||
|
||||
if chosen_index is not None:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] fail: {e}, traceback: {format_exc()}"
|
||||
)
|
||||
break
|
||||
|
||||
if chosen_index is None:
|
||||
# If the model provides no choice, fallback: pick the first successful one; otherwise the first
|
||||
for i, r in enumerate(candidates):
|
||||
try:
|
||||
if r.success:
|
||||
chosen_index = i
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if chosen_index is None:
|
||||
chosen_index = 0
|
||||
|
||||
if not (0 <= chosen_index < len(candidates)):
|
||||
chosen_index = 0
|
||||
|
||||
selected = candidates[chosen_index]
|
||||
|
||||
try:
|
||||
gp = selected.golden_patch[0] if selected.golden_patch else None
|
||||
if gp is None:
|
||||
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
|
||||
else:
|
||||
patch_info = PatchInfo(
|
||||
patch_content=gp.patch_content,
|
||||
test_status=gp.test_status,
|
||||
reasoning=gp.reasoning,
|
||||
)
|
||||
except Exception:
|
||||
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
|
||||
|
||||
selector_result = SelectorResult(
|
||||
instance_id=selected.instance_id,
|
||||
generator_id=selected.generator_id,
|
||||
image=selected.image,
|
||||
success=True,
|
||||
golden_patch=patch_info,
|
||||
llm_usage=usage_stats,
|
||||
tool_stats=tool_stats,
|
||||
total_turns=total_turns,
|
||||
select_reason=select_reason,
|
||||
error=None,
|
||||
)
|
||||
return selector_result
|
||||
|
||||
|
||||
254
src/managers/loop/types.py
Normal file
254
src/managers/loop/types.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
This module defines the GeneratorResult data structure for patch generation results.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Any, Optional
|
||||
from src.tools.base import (
|
||||
BASH_TOOL_NAME,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMUsage:
|
||||
"""LLM usage statistics."""
|
||||
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, int]:
|
||||
"""Serialize LLMUsage to a plain dictionary."""
|
||||
return {
|
||||
"prompt_tokens": int(self.prompt_tokens),
|
||||
"completion_tokens": int(self.completion_tokens),
|
||||
"total_tokens": int(self.total_tokens),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolStats:
|
||||
"""Tool usage statistics per tool.
|
||||
|
||||
Each tool is represented by a small map with two fields:
|
||||
- count: total invocation count
|
||||
- failed: failed invocation count
|
||||
"""
|
||||
|
||||
bash: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
|
||||
edit: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
|
||||
search: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
|
||||
submit_result: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, int]]:
|
||||
"""Serialize ToolStats to a plain dictionary."""
|
||||
return {
|
||||
BASH_TOOL_NAME: {"count": int(self.bash.get("count", 0)), "failed": int(self.bash.get("failed", 0))},
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: {"count": int(self.edit.get("count", 0)), "failed": int(self.edit.get("failed", 0))},
|
||||
SEARCH_TOOL_NAME: {"count": int(self.search.get("count", 0)), "failed": int(self.search.get("failed", 0))},
|
||||
SUBMIT_RESULT_TOOL_NAME: {"count": int(self.submit_result.get("count", 0)), "failed": int(self.submit_result.get("failed", 0))},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchInfo:
|
||||
"""Information about a generated patch."""
|
||||
|
||||
patch_content: str
|
||||
test_status: str
|
||||
reasoning: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorResult:
|
||||
"""Result from a patch generator."""
|
||||
|
||||
instance_id: str
|
||||
generator_id: int
|
||||
image: str
|
||||
success: bool
|
||||
golden_patch: List[
|
||||
PatchInfo
|
||||
]
|
||||
llm_usage: LLMUsage
|
||||
tool_stats: ToolStats
|
||||
total_turns: int
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GeneratorResult":
|
||||
"""Create GeneratorResult from dictionary."""
|
||||
# Handle golden_patch conversion
|
||||
golden_patch = []
|
||||
if data.get("golden_patch"):
|
||||
for patch_data in data["golden_patch"]:
|
||||
if isinstance(patch_data, dict):
|
||||
golden_patch.append(
|
||||
PatchInfo(
|
||||
patch_content=patch_data.get("patch_content", ""),
|
||||
test_status=patch_data.get("test_status", ""),
|
||||
reasoning=patch_data.get("reasoning", ""),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Legacy format: just patch content string
|
||||
golden_patch.append(
|
||||
PatchInfo(
|
||||
patch_content=str(patch_data), test_status="", reasoning=""
|
||||
)
|
||||
)
|
||||
|
||||
# Handle LLM usage
|
||||
llm_usage_data = data.get("llm_usage", {})
|
||||
llm_usage = LLMUsage(
|
||||
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=llm_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=llm_usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
# Handle tool stats
|
||||
tool_stats_data = data.get("tool_stats", {})
|
||||
tool_stats = ToolStats(
|
||||
bash=tool_stats_data.get(BASH_TOOL_NAME, 0),
|
||||
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, 0),
|
||||
search=tool_stats_data.get(SEARCH_TOOL_NAME, 0),
|
||||
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, 0),
|
||||
)
|
||||
|
||||
return cls(
|
||||
instance_id=data.get("instance_id", ""),
|
||||
generator_id=data.get("generator_id", 0),
|
||||
image=data.get("image", ""),
|
||||
success=data.get("success", False),
|
||||
golden_patch=golden_patch,
|
||||
llm_usage=llm_usage,
|
||||
tool_stats=tool_stats,
|
||||
total_turns=data.get("total_turns", 0),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert GeneratorResult to dictionary."""
|
||||
return {
|
||||
"instance_id": self.instance_id,
|
||||
"generator_id": self.generator_id,
|
||||
"image": self.image,
|
||||
"success": self.success,
|
||||
"golden_patch": [
|
||||
{
|
||||
"patch_content": patch.patch_content,
|
||||
"test_status": patch.test_status,
|
||||
"reasoning": patch.reasoning,
|
||||
}
|
||||
for patch in self.golden_patch
|
||||
],
|
||||
"llm_usage": {
|
||||
"prompt_tokens": self.llm_usage.prompt_tokens,
|
||||
"completion_tokens": self.llm_usage.completion_tokens,
|
||||
"total_tokens": self.llm_usage.total_tokens,
|
||||
},
|
||||
"tool_stats": {
|
||||
BASH_TOOL_NAME: self.tool_stats.bash,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
|
||||
SEARCH_TOOL_NAME: self.tool_stats.search,
|
||||
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
|
||||
},
|
||||
"total_turns": self.total_turns,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SelectorResult:
|
||||
"""Result from a patch selector.
|
||||
"""
|
||||
|
||||
instance_id: str
|
||||
generator_id: int
|
||||
image: str
|
||||
success: bool
|
||||
golden_patch: PatchInfo
|
||||
llm_usage: LLMUsage
|
||||
tool_stats: ToolStats
|
||||
total_turns: int
|
||||
select_reason: str
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SelectorResult":
|
||||
"""Create SelectorResult from dictionary."""
|
||||
|
||||
gp_data = data.get("golden_patch", {})
|
||||
if isinstance(gp_data, dict):
|
||||
golden_patch = PatchInfo(
|
||||
patch_content=gp_data.get("patch_content", ""),
|
||||
test_status=gp_data.get("test_status", ""),
|
||||
reasoning=gp_data.get("reasoning", ""),
|
||||
)
|
||||
else:
|
||||
golden_patch = PatchInfo(
|
||||
patch_content=str(gp_data) if gp_data is not None else "",
|
||||
test_status="",
|
||||
reasoning="",
|
||||
)
|
||||
|
||||
# LLM usage
|
||||
llm_usage_data = data.get("llm_usage", {})
|
||||
llm_usage = LLMUsage(
|
||||
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=llm_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=llm_usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
# Tool stats
|
||||
tool_stats_data = data.get("tool_stats", {})
|
||||
tool_stats = ToolStats(
|
||||
bash=tool_stats_data.get(BASH_TOOL_NAME, 0),
|
||||
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, 0),
|
||||
search=tool_stats_data.get(SEARCH_TOOL_NAME, 0),
|
||||
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, 0),
|
||||
)
|
||||
|
||||
return cls(
|
||||
instance_id=data.get("instance_id", ""),
|
||||
generator_id=data.get("generator_id", 0),
|
||||
image=data.get("image", ""),
|
||||
success=data.get("success", False),
|
||||
golden_patch=golden_patch,
|
||||
llm_usage=llm_usage,
|
||||
tool_stats=tool_stats,
|
||||
total_turns=data.get("total_turns", 0),
|
||||
select_reason=data.get("select_reason", ""),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert SelectorResult to dictionary."""
|
||||
return {
|
||||
"instance_id": self.instance_id,
|
||||
"generator_id": self.generator_id,
|
||||
"image": self.image,
|
||||
"success": self.success,
|
||||
"golden_patch": {
|
||||
"patch_content": self.golden_patch.patch_content,
|
||||
"test_status": self.golden_patch.test_status,
|
||||
"reasoning": self.golden_patch.reasoning,
|
||||
},
|
||||
"llm_usage": {
|
||||
"prompt_tokens": self.llm_usage.prompt_tokens,
|
||||
"completion_tokens": self.llm_usage.completion_tokens,
|
||||
"total_tokens": self.llm_usage.total_tokens,
|
||||
},
|
||||
"tool_stats": {
|
||||
BASH_TOOL_NAME: self.tool_stats.bash,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
|
||||
SEARCH_TOOL_NAME: self.tool_stats.search,
|
||||
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
|
||||
},
|
||||
"total_turns": self.total_turns,
|
||||
"select_reason": self.select_reason,
|
||||
"error": self.error,
|
||||
}
|
||||
268
src/managers/prompts/prompts_manager.py
Normal file
268
src/managers/prompts/prompts_manager.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from typing import Any, List, Dict
|
||||
|
||||
|
||||
class PromptsManager:
|
||||
def __init__(self, config):
|
||||
self.candidate_length = config.get("runner", {}).get("generator_concurrency", 5)
|
||||
|
||||
def get_generator_system(self, root_path: str | None = None):
|
||||
return f"""
|
||||
# You are a highly skilled expert in software engineering focused on resolving complex GitHub issues by effectively analyzing codebases, implementing fixes, and ensuring code reliability through rigorous testing.
|
||||
|
||||
## Skills
|
||||
1. Code Analysis and Debugging
|
||||
- Issue Exploration: Ability to explore and comprehend codebases within repositories.
|
||||
- Workflow Tracing: Skilled in using debugging techniques to trace issues through code.
|
||||
- Root Cause Identification: Proficient in pinpointing the underlying causes of software issues.
|
||||
- Test Creation: Expertise in establishing tests that replicate and validate issues.
|
||||
|
||||
2. Solution Implementation and Testing
|
||||
- Fix Implementation: Experience in crafting precise and minimal code patches.
|
||||
- Comprehensive Testing: Skilled in running and analyzing both existing and newly created tests.
|
||||
- Regression Prevention: Ensures all changes maintain overall code stability.
|
||||
- Continuous Improvement: Iterates on solutions based on test results to achieve optimal functionality.
|
||||
|
||||
## Task
|
||||
Your task is resolve the given GitHub issue by understanding the repository and the issue, implementing a fix, and checking your changes against existing tests and your own test(s).
|
||||
|
||||
Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project.
|
||||
|
||||
For example, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/test_bed`.
|
||||
|
||||
Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and I activated the virtual environment for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
|
||||
|
||||
Follow these steps:
|
||||
|
||||
1. Problem Analysis:
|
||||
- Read the issue description carefully to fully grasp the issue and explore the repository (source code, tests, examples) to understand expected behavior of relevant components.
|
||||
- Identify the full scope. Does the issue mention multiple components, backends, or functions? Your solution must address all of them.
|
||||
|
||||
2. Reproduce the issue (IMPORTANT):
|
||||
- Create a test that reproduces the issue as a baseline for verification.
|
||||
- Check that the output of your test matches your understanding of the issue in step 1.
|
||||
|
||||
3. Identify the root cause:
|
||||
- Go through relavant files, create debugging scripts with print statements or use other methods if necessary,to trace the workflow and exact cause of the issue.
|
||||
- Trace the problem to its root cause.** Do not just patch the symptom where the error appears. Trace the data and execution flow upstream to find where the problem originates.
|
||||
|
||||
4. Implement a Fix:
|
||||
- Once you have identified the root cause, develop a precise and targeted fix and then apply it as a minimal patch using the `str_replace_based_edit_tool` tools.
|
||||
|
||||
5. Test comprehensively:
|
||||
- Verify the Fix: Run your initial reproduction script to confirm that the bug is resolved.
|
||||
- Prevent Regressions:
|
||||
--Identify the right tests: Once you have verified your fix, identify the most relevant tests within the project's existing test suite that correspond to your code changes.
|
||||
--Run the tests: Then you **must** run these tests to ensure that your fix does not introduce any new bugs.
|
||||
--Analyze failures carefully:
|
||||
---If tests fail, do not immediately assume your fix is wrong. Critically analyze the failure.
|
||||
---Is it a **regression**? Did your change break existing, valid functionality? If so, you must refine your fix.
|
||||
---Is it an **unrelated failure**? It could be an environmental issue (e.g., missing dependency, network error) or a pre-existing flaky test. If you suspect this, try to run a more focused test and note the issue in your final reasoning.
|
||||
---Is the **test now obsolete**? If your fix improves behavior in a way that makes an old test's assertions incorrect, you should **update the test** to match the new, correct behavior and explain why in your reasoning.
|
||||
- Write New Tests: Create new, specific test cases (e.g., using `pytest`) that cover the original bug scenario.
|
||||
- Consider Edge Cases: Think about and test potential edge cases related to your changes.
|
||||
|
||||
6. Revisit step 1 through 5 if unexpected behavior occurs, then call `submit_result` to submit the reliable and verified solution patch after successful testing and validation.
|
||||
|
||||
**Mandatory Workflow** As a senior engineer, ensure solution correctness and safety. Upon successful verification, immediately conclude the task by calling `submit_result`.
|
||||
"""
|
||||
|
||||
def format_issue_prompt(
|
||||
self,
|
||||
created_at: str,
|
||||
base_commit: str,
|
||||
environment_setup_commit: str,
|
||||
version: str,
|
||||
problem_statement: str,
|
||||
difficulty: str,
|
||||
) -> str:
|
||||
|
||||
template = f"""
|
||||
[📝 Issue Description]
|
||||
|
||||
**Created at**: {created_at}
|
||||
**Base commit**: {base_commit}
|
||||
---
|
||||
|
||||
### 📌 Problem Statement
|
||||
{problem_statement}
|
||||
---
|
||||
|
||||
### ⚙️ Difficulty Level
|
||||
{difficulty}
|
||||
---
|
||||
"""
|
||||
return template.strip()
|
||||
|
||||
def get_generator_user(self, root_path: str, issue_text: str):
|
||||
return (
|
||||
f"""
|
||||
[Project root path]:
|
||||
{root_path}
|
||||
[Issue Information]:
|
||||
{issue_text}
|
||||
"""
|
||||
+ self.get_generator_notice()
|
||||
)
|
||||
|
||||
def get_generator_notice(self):
|
||||
return """
|
||||
[notice]
|
||||
1. Use the available tools to locate the root cause.
|
||||
2. Prioritize using the `search_tool` to retrieve and locate the precise location of key information in the project.
|
||||
3. Collect supporting evidence: stack traces, logs, configs, recent changes, related modules.
|
||||
"""
|
||||
|
||||
def get_selector_system(self, patches_count: int, root_path: str):
|
||||
return f"""
|
||||
# ROLE:
|
||||
|
||||
*You are a highly proficient software engineer tasked with evaluating and selecting optimal code patches to resolve specific issues within a given project.
|
||||
|
||||
*You colleagus worked on {patches_count} potential patches for an github issue. Select ONE correct patch to solve the issue.
|
||||
|
||||
*Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and the virtual environment has been activated for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
|
||||
|
||||
*Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project. For instance, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/test_bed`.
|
||||
|
||||
# WORKFLOWS:
|
||||
*Follow these steps without any skipping:
|
||||
1.Problem Analysis:
|
||||
- Read the issue description and the current code that needs to be fixed. Explore the repository (source code, tests, examples) to understand expected behavior of relevant components, and gather comprehensive information about the problem area
|
||||
|
||||
2.Conduct a thorough review of each patch:
|
||||
- Scrutinize all code modifications.
|
||||
- Decipher the core logic and problem-solving methodology.
|
||||
- Evaluate potential edge cases and unintended consequences.
|
||||
- Validate that each patch fully addresses the initial issue specifications.
|
||||
|
||||
3.Verify Your Analysis
|
||||
- Use available tools to verify your analysis works of this issue.
|
||||
- Test your conclusions against relevant code sections.
|
||||
- Ensure full contextual understanding.
|
||||
|
||||
4.Proceed with Your Decision
|
||||
- Upon completion of the preceding three steps, utilize the `submit_result` tool with your detailed reasoning.
|
||||
|
||||
#RULES:
|
||||
1.It is MANDATORY to utilize both available tools prior to finalizing any selectio:
|
||||
-- Start with `bash` to explore the codebase structure;
|
||||
-- Employ the str_replace_based_edit_tool to inspect the current code;
|
||||
-- Use `search_tool` to search related code and file;
|
||||
2.You MUST first explore the codebase before using the `submit_result` tool.
|
||||
3.Substantiate your reasoning with evidence from your analysis.
|
||||
4.Only selections made after employing the tools will be accepted.
|
||||
|
||||
#FINAL DECISION:
|
||||
Upon completion of your tool-based analysis, finalize the process by submitting your choice via the `submit_result` tool.
|
||||
|
||||
#NOTICE:
|
||||
1. Tool usage is MANDATORY - do not skip this step.
|
||||
2. Without making a decision after completing analysis is not permitted.
|
||||
3. Never generate new patches by your own, just make the selection.
|
||||
4. Always provide detailed reasoning for the selection based on your tool-based investigation
|
||||
"""
|
||||
|
||||
def get_selector_user(
|
||||
self, instance_data: Dict[str, Any] | None = None, candidates: List[Any] | None = None, root_path: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate user prompt of selector, including issue information and the first golden patch of each candidate.
|
||||
|
||||
- instance_data: Current instance metadata (issue description etc.)
|
||||
- candidates: Candidates list (.to_dict() supported), only get golden_patch[0].patch_content
|
||||
"""
|
||||
if not instance_data or not candidates:
|
||||
return ""
|
||||
|
||||
created_at = instance_data.get("created_at", "")
|
||||
base_commit = instance_data.get("base_commit", "")
|
||||
environment_setup_commit = instance_data.get("environment_setup_commit", "")
|
||||
version = instance_data.get("version", "")
|
||||
problem_statement = instance_data.get("problem_statement", "")
|
||||
difficulty = instance_data.get("difficulty", "")
|
||||
|
||||
issue_block = self.format_issue_prompt(
|
||||
created_at=created_at,
|
||||
base_commit=base_commit,
|
||||
environment_setup_commit=environment_setup_commit,
|
||||
version=version,
|
||||
problem_statement=problem_statement,
|
||||
difficulty=difficulty,
|
||||
)
|
||||
|
||||
root_path_block = f"""
|
||||
[Project root path]:
|
||||
{root_path or "/testbed"}
|
||||
"""
|
||||
|
||||
parts: List[str] = [root_path_block, issue_block, "\n[🔎 Candidates]\n"]
|
||||
for idx, r in enumerate(candidates):
|
||||
try:
|
||||
data = r.to_dict() if hasattr(r, "to_dict") else {}
|
||||
except Exception:
|
||||
data = {}
|
||||
golden_patch = data.get("golden_patch", [])
|
||||
patch_content = golden_patch[0].get("patch_content", "") if golden_patch else ""
|
||||
test_status = golden_patch[0].get("test_status", "") if golden_patch else ""
|
||||
reasoning = golden_patch[0].get("reasoning", "") if golden_patch else ""
|
||||
parts.append(self.format_selector_candidate(idx, patch_content, test_status, reasoning))
|
||||
|
||||
parts.append(
|
||||
"\nPlease analyze the candidates, then call the submit_result tool with the final index and reasoning."
|
||||
)
|
||||
return "\n".join(parts)
|
||||
|
||||
def get_terminal_response(self, exit_code: int, output: str, timeout_status: bool):
|
||||
if timeout_status == True:
|
||||
return f"""[Terminal response]
|
||||
Exit code: {exit_code}
|
||||
Output: {output}"""
|
||||
else:
|
||||
return f"""[Terminal response]
|
||||
Terminal time out."""
|
||||
|
||||
def tool_response_prompts(self, tool_results: list) -> str:
|
||||
if not tool_results:
|
||||
return ""
|
||||
|
||||
response_parts = ["[tool_response]"]
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
output = result.get("result", "")
|
||||
error = result.get("error", "")
|
||||
|
||||
response_parts.append(f"Tool {i}: {tool_name}")
|
||||
response_parts.append(f"Success: {success}")
|
||||
|
||||
if success and output:
|
||||
response_parts.append(f"Output:\n{output}")
|
||||
elif error:
|
||||
response_parts.append(f"Error: {error}")
|
||||
else:
|
||||
response_parts.append("No output")
|
||||
|
||||
response_parts.append("") # 空行分隔
|
||||
|
||||
return "\n".join(response_parts)
|
||||
|
||||
def format_selector_candidate(self, index: int, patch_content: str, test_status: str, reasoning: str) -> str:
|
||||
"""
|
||||
Generate description of selector candidate items, including key information of the first golden_patch
|
||||
|
||||
- index: Candidate index(0-based)
|
||||
- patch_content: golden_patch[0].patch_content
|
||||
- test_status: golden_patch[0].test_status test status in generating stage
|
||||
- reasoning: golden_patch[0].reasoning model reasoning in generating stage
|
||||
"""
|
||||
header = f"- Candidate #{index}:"
|
||||
patch_block = patch_content or ""
|
||||
status_block = test_status or ""
|
||||
reasoning_block = reasoning or ""
|
||||
return (
|
||||
f"--{header}\n"
|
||||
f"--Patch content (the proposed fix):\n{patch_block}\n\n"
|
||||
f"--Test status during generation: {status_block}\n\n"
|
||||
f"--Reasoning during generation (model's logic):\n{reasoning_block}"
|
||||
)
|
||||
103
src/managers/result_builder/result_builder.py
Normal file
103
src/managers/result_builder/result_builder.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
import json
|
||||
|
||||
|
||||
class ResultBuilder:
|
||||
"""
|
||||
Build preds.json
|
||||
|
||||
Iterate through JSON files named by instance_id in the directory specified by `runner.selector_result_dump_path` in the configuration.
|
||||
- Parse the `golden_patch` field from each JSON file and extract the patch text as `model_patch`.
|
||||
- Read the first top-level field from providers and the first model under it, and concatenate them to form `model_name_or_path`.
|
||||
- The output location is `{workspace.path}/{result.preds.path}/preds.json`.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config or {}
|
||||
|
||||
def _get_selector_dump_dir(self) -> Path:
|
||||
runner_cfg = self.config.get("runner", {}) if isinstance(self.config, dict) else {}
|
||||
dump_dir_str = runner_cfg.get(
|
||||
"selector_result_dump_path", "workspace/selector_result_dump"
|
||||
)
|
||||
return Path(dump_dir_str)
|
||||
|
||||
def _get_preds_output_dir(self) -> Path:
|
||||
workspace_cfg = self.config.get("workspace", {}) if isinstance(self.config, dict) else {}
|
||||
result_cfg = self.config.get("result", {}) if isinstance(self.config, dict) else {}
|
||||
preds_cfg = result_cfg.get("preds", {}) if isinstance(result_cfg, dict) else {}
|
||||
|
||||
workspace_path = workspace_cfg.get("path", "workspace")
|
||||
preds_path = preds_cfg.get("path", "result")
|
||||
return Path(workspace_path) / preds_path
|
||||
|
||||
def _get_model_name_or_path(self) -> str:
|
||||
providers = self.config.get("providers", {}) if isinstance(self.config, dict) else {}
|
||||
if not isinstance(providers, dict) or not providers:
|
||||
return ""
|
||||
first_provider_name = next(iter(providers.keys()))
|
||||
first_models = providers.get(first_provider_name, [])
|
||||
if isinstance(first_models, list) and first_models:
|
||||
first_model = first_models[0]
|
||||
else:
|
||||
first_model = ""
|
||||
return f"{first_provider_name}/{first_model}" if first_provider_name and first_model else ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_patch(golden_patch: Any) -> str:
|
||||
"""
|
||||
Extract patch content from golden_patch
|
||||
|
||||
Forms supported:
|
||||
- dict: prioritize extract 'patch_content', then attempt `model_patch`
|
||||
- string: Directly return
|
||||
- other: return empty string
|
||||
"""
|
||||
if isinstance(golden_patch, dict):
|
||||
if "patch_content" in golden_patch and isinstance(golden_patch["patch_content"], str):
|
||||
return golden_patch["patch_content"]
|
||||
if "model_patch" in golden_patch and isinstance(golden_patch["model_patch"], str):
|
||||
return golden_patch["model_patch"]
|
||||
return ""
|
||||
if isinstance(golden_patch, str):
|
||||
return golden_patch
|
||||
return ""
|
||||
|
||||
def build_preds(self) -> Path:
|
||||
dump_dir = self._get_selector_dump_dir()
|
||||
output_dir = self._get_preds_output_dir()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_file = output_dir / "preds.json"
|
||||
|
||||
model_name_or_path = self._get_model_name_or_path()
|
||||
|
||||
# SWE-bench evaluation expects: list[dict], each element includes instance_id / model_patch / model
|
||||
predictions: list[dict[str, str]] = []
|
||||
|
||||
if dump_dir.exists() and dump_dir.is_dir():
|
||||
for path in sorted(dump_dir.glob("*.json")):
|
||||
try:
|
||||
instance_id = path.stem
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
golden_patch = data.get("golden_patch", {}) if isinstance(data, dict) else {}
|
||||
model_patch = self._extract_model_patch(golden_patch)
|
||||
predictions.append(
|
||||
{
|
||||
"instance_id": instance_id,
|
||||
"model_patch": model_patch,
|
||||
"model_name_or_path": model_name_or_path,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(predictions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return output_file
|
||||
|
||||
|
||||
Reference in New Issue
Block a user