Mirror of https://github.com/Tokfinity/InfCode.git, synced 2026-02-12 13:12:45 +00:00
225
.gitignore
vendored
Normal file
@@ -0,0 +1,225 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
#poetry.toml
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||
#pdm.lock
|
||||
#pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# pixi
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||
#pixi.lock
|
||||
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||
.pixi
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.envrc
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Abstra
|
||||
# Abstra is an AI-powered process automation framework.
|
||||
# Ignore directories containing user credentials, local state, and settings.
|
||||
# Learn more at https://abstra.io/docs
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
# .vscode/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Cursor
|
||||
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
||||
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
||||
# refer to https://docs.cursor.com/context/ignore-files
|
||||
.cursorignore
|
||||
.cursorindexingignore
|
||||
|
||||
# Marimo
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
|
||||
# workspace
|
||||
workspace/
|
||||
|
||||
# logs
|
||||
logs/
|
||||
|
||||
nohup.out
|
||||
|
||||
gold.validate-gold.json
|
||||
|
||||
.DS_Store
|
||||
.idea/
|
||||
|
||||
# instance_to_image.json
|
||||
src/managers/image_builder/instance_to_image.json
|
||||
|
||||
/batch_out
|
||||
1
.python-version
Normal file
@@ -0,0 +1 @@
3.12.9
7
LICENSE
Normal file
@@ -0,0 +1,7 @@
Copyright 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1
__init__.py
Normal file
@@ -0,0 +1 @@
# project root directory marker file
275
_batch_run.py
Normal file
@@ -0,0 +1,275 @@
import argparse
from dataclasses import dataclass
import logging
from pathlib import Path
import shutil
import signal
import sys
import time
import yaml
from pydantic import BaseModel, Field
from typing import Any, List, Optional
import multiprocessing
from functools import partial


class Config(BaseModel):
    issue_list: Optional[List[str]] = None


class Args:
    config: str
    output: str
    parallel: int
    name: str
    issue_list: str
    custom: str
    clean: bool
    dry_run: bool


class Context:
    config: Config
    args: Args

    def __init__(self, config: Config, args: Args):
        self.args = args
        self.config = config


def main():
    (args, cfg) = parse_args()
    ctx = Context(config=cfg, args=args)
    print(f"Input args: {args} {args.config}")

    out_dir = Path(ctx.args.output).joinpath(f"batch-{ctx.args.name}")
    out_dir.mkdir(parents=True, exist_ok=True)
    if not out_dir.is_dir():
        raise ValueError(f"{out_dir} is not a directory")
    if ctx.args.clean:
        shutil.rmtree(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

    now = time.time()
    p = BatchProcessExecutor(ctx, out_dir=out_dir)
    p.execute_all_tasks()
    duration = int(time.time() - now)
    print(
        f"DONE Total time cost {int(duration/3600)}h {int(duration/60%60)}m {int(duration%60)}s"
    )


@dataclass
class ProcResult:
    issue: str
    idx: int
    duration: int
    summary: dict[str, Any]


def worker_func_for_gen_result(ctx: Context, out_dir: Path):
    print("worker_func_for_gen_result...")
    import _run

    config_path = Path(__file__).parent / "config" / "config.yaml"
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}
    cfg["workspace"]["path"] = str(out_dir)
    cfg["result"]["preds"]["result"] = "result"
    cfg["runner"]["selector_result_dump_path"] = str(
        out_dir.joinpath("selector_result_dump")
    )

    start_time = time.time()

    if not ctx.args.dry_run:
        _run.gen_result(cfg)
        time.sleep(1)
    else:
        time.sleep(2)

    duration = time.time() - start_time
    print(
        f"worker_func_for_gen_result DONE time cost: {int(duration/60)}m {int(duration%60)}s"
    )


def worker_function(ctx: Context, issue: str, idx: int, out_dir: Path) -> ProcResult:
    print(f"worker_function idx:{idx} issue:{issue}")

    issue_out_dir = out_dir.joinpath("issues").joinpath(f"{idx:03d}-{issue}")
    issue_out_dir.mkdir(parents=True, exist_ok=True)

    issue_out_log_dir = issue_out_dir.joinpath("logs")
    issue_out_log_dir.mkdir(parents=True, exist_ok=True)

    generator_result_dump_path = out_dir.joinpath("generator_result_dump")
    generator_result_dump_path.mkdir(parents=True, exist_ok=True)

    selector_result_dump_path = out_dir.joinpath("selector_result_dump")
    selector_result_dump_path.mkdir(parents=True, exist_ok=True)

    # Redirect this worker's output to a per-issue log file; stderr is merged
    # into the same stdout.log.
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    sys.stdout = open(issue_out_dir.joinpath("stdout.log"), "a")
    sys.stderr = open(issue_out_dir.joinpath("stdout.log"), "a")
    # signal.signal(signal.SIGINT, signal.SIG_IGN)

    start_time = time.time()

    import _run

    config_path = Path(__file__).parent / "config" / "config.yaml"
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}
    cfg["log"]["base_path"] = str(issue_out_log_dir)
    cfg["runner"]["generator_result_dump_path"] = str(generator_result_dump_path)
    cfg["runner"]["selector_result_dump_path"] = str(selector_result_dump_path)
    if not ctx.args.dry_run:
        summary = _run.run_one_issue(cfg, issue=issue)
        time.sleep(1)
    else:
        time.sleep(2)
        summary = {}
        summary["success"] = 1
        summary["failed"] = 0
        summary["total"] = 1

    sys.stdout.flush()
    sys.stderr.flush()
    sys.stdout = original_stdout
    sys.stderr = original_stderr

    duration = time.time() - start_time
    print(
        f"worker_function DONE idx:{idx} issue:{issue} time cost: {int(duration/60)}m {int(duration%60)}s"
    )
    if summary.get("success", 0) > 0:
        add_done_set(out_dir, issue)

    return ProcResult(duration=duration, issue=issue, idx=idx, summary=summary)


class BatchProcessExecutor:
    ctx: Context
    out_dir: Path

    def __init__(self, ctx: Context, out_dir: Path):
        self.ctx = ctx
        self.out_dir = out_dir

    def execute_all_tasks(self):
        done_set = load_done_set(self.out_dir)
        parallel = self.ctx.args.parallel
        formatted_issues = []
        for idx, issue in enumerate(self.ctx.config.issue_list, 1):
            if issue in done_set:
                print(f"done set: skip {idx}, {issue}")
            else:
                formatted_issues.append((issue, idx, self.out_dir))

        results: List[ProcResult] = []
        with multiprocessing.Pool(processes=parallel, maxtasksperchild=1) as pool:
            try:
                worker = partial(worker_function, self.ctx)
                results = pool.starmap(worker, formatted_issues)
            except KeyboardInterrupt:
                print("Ctrl-C received, terminating pool")
                pool.terminate()

        cum_total = 0
        cum_success = 0
        cum_failed = 0
        for result in results:
            total = result.summary["total"]
            success = result.summary["success"]
            failed = result.summary["failed"]
            cum_total += total
            cum_success += success
            cum_failed += failed

        print(f"Total instances: {cum_total}")
        print(f"Success: {cum_success}")
        print(f"Fail: {cum_failed}")
        print(f"Success rate: {(cum_success/cum_total*100):.1f}%" if cum_total > 0 else "0%")

        print("Start generating final results")
        process = multiprocessing.Process(
            target=worker_func_for_gen_result, args=(self.ctx, self.out_dir)
        )
        process.start()
        process.join()


def parse_args() -> tuple[Args, Config]:
    parser = argparse.ArgumentParser(description="Concurrent batch runner for SWE-bench instances.")
    parser.add_argument("-c", "--config", help="config file", required=True)
    parser.add_argument(
        "--name", help="task name, appended to the output path as the directory batch-<name>", required=True
    )
    parser.add_argument("-i", "--issue-list", help="issue list file, one instance id per line", type=str)
    parser.add_argument(
        "-o",
        "--output",
        help="output directory, ./batch_out/ by default",
        type=str,
        default="./batch_out/",
    )
    parser.add_argument(
        "-p",
        "--parallel",
        help="parallelism, 20 by default",
        type=int,
        default=20,
    )
    parser.add_argument(
        "--clean", help="Clean up data with the same name in the output directory before starting.", action="store_true"
    )
    parser.add_argument(
        "--dry-run", help="Skip the actual inference task execution, but proceed with all other logic.", action="store_true"
    )

    args: Args = parser.parse_args()
    setattr(args, "custom", "str")

    with open(args.config, "r") as f:
        config_data = yaml.safe_load(f)

    cfg = Config(**config_data)

    if args.issue_list:
        issues: list[str] = []
        with open(args.issue_list, "r", encoding="utf-8") as f:
            for line in f:
                line_clean = line.strip()
                if line_clean:
                    issues.append(str(line_clean))
        cfg.issue_list = issues

    print(f"list len = {len(cfg.issue_list)}")
    return (args, cfg)


def signal_handler(signum, frame):
    print(f"Received signal {signum}, terminating child processes...")
    # sys.exit(0)


def load_done_set(out_dir: Path) -> set[str]:
    file_name = out_dir.joinpath("done.txt")
    if not file_name.exists():
        return set()
    with open(file_name, "r") as f:
        return set(line.strip() for line in f if line.strip())


def add_done_set(out_dir: Path, issue: str):
    file_name = out_dir.joinpath("done.txt")
    with open(file_name, "a") as f:
        f.write(f"{issue}\n")


if __name__ == "__main__":
    main()
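For illustration, a minimal sketch (not part of the repository) of launching the batch runner from the repository root with the flags defined in parse_args() above; the run name and parallelism are made-up values, and output lands under ./batch_out/batch-<name>/ as built in main():

import subprocess

subprocess.run(
    [
        "python", "_batch_run.py",
        "-c", "config/batch.yaml",      # YAML parsed into Config (may carry issue_list)
        "--name", "verified-run-01",    # hypothetical run name -> ./batch_out/batch-verified-run-01/
        "-i", "config/list_all.txt",    # one SWE-bench instance id per line
        "-p", "8",                      # worker processes (default is 20)
        "--dry-run",                    # exercise the plumbing without running inference
    ],
    check=True,
)

Completed instances are appended to done.txt inside the batch directory, so re-running the same command resumes where the previous run stopped.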
50
_build_image.py
Normal file
@@ -0,0 +1,50 @@
from src.managers.image_builder.build_image import SWEBenchImageBuilder
import yaml
from pathlib import Path
from src.managers.log.logger import ImageBuilderLogger
from typing import List


def load_instance_ids(file_path: str) -> List[str]:
    with open(file_path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


class BuildImage:
    def __init__(self, instance_ids: List[str] = None):
        self.instance_ids = instance_ids or []
        self.config = self._load_config()
        self.logger = ImageBuilderLogger(
            log_base_path=self.config.get("log", {}).get("image_builder", "workspace/image_logs"),
            console_output=True,
        )
        self.builder = self._init_builder()

    def _init_builder(self):
        dataset_name = self.config.get("dataset", {}).get("name", "princeton-nlp/SWE-bench_Lite")
        dataset_split = self.config.get("dataset", {}).get("split", "dev")
        return SWEBenchImageBuilder(
            dataset_name=dataset_name,
            split=dataset_split,
            logger=self.logger,
        )

    def _load_config(self):
        config_path = Path(__file__).parent / "config" / "config.yaml"
        with open(config_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f) or {}

    def _build(self):
        max_workers = self.config.get("builder", {}).get("max_workers", 3)
        max_retries = self.config.get("builder", {}).get("max_retries", 1)
        self.builder.build_target_images(instance_ids=self.instance_ids, max_workers=max_workers, max_retries=max_retries)


if __name__ == "__main__":
    # instance_ids = ["django__django-10097", "django__django-10880"]
    file_path = ""  # fill in the list path
    instance_ids = load_instance_ids(file_path=file_path)
    print(instance_ids)
    build_image = BuildImage(instance_ids)
    build_image._build()
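As a usage sketch (an assumption, not part of the repository): instead of editing file_path in the __main__ block above, the same classes can be driven with the instance list that ships in config/list_all.txt:

from _build_image import BuildImage, load_instance_ids

instance_ids = load_instance_ids("config/list_all.txt")  # one SWE-bench instance id per line
BuildImage(instance_ids)._build()  # uses builder.max_workers / max_retries from config/config.yaml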
783
_run.py
Normal file
@@ -0,0 +1,783 @@
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
import json
|
||||
import yaml
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from traceback import format_exc
|
||||
from src.tools.executor import Executor
|
||||
from src.tools import BashTool, TextEditorTool, SearchTool, SubmitResultTool
|
||||
from src.managers.result_builder.result_builder import ResultBuilder
|
||||
from src.tools.base import (
|
||||
ToolExecutor,
|
||||
BASH_TOOL_NAME,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
from src.managers.log.logger import create_logger, Logger
|
||||
from src.managers.llm_api.api_manager import LLMAPIManager
|
||||
from src.managers.image_builder.build_image import SWEBenchImageBuilder
|
||||
from src.managers.prompts.prompts_manager import PromptsManager
|
||||
from src.managers.loop.patch_generator import PatchGenerator
|
||||
from src.managers.loop.patch_selector import PatchSelector
|
||||
from src.managers.loop.types import GeneratorResult, SelectorResult
|
||||
|
||||
|
||||
class SelectorLoop:
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
image_name: str,
|
||||
runner_log_base: Path,
|
||||
llm_manager: LLMAPIManager | None,
|
||||
prompts_manager: PromptsManager | None,
|
||||
instance_data: Dict[str, Any],
|
||||
config: Dict[str, Any],
|
||||
):
|
||||
self.instance_id = instance_id
|
||||
self.image_name = image_name
|
||||
self.llm_manager = llm_manager
|
||||
self.prompts_manager = prompts_manager
|
||||
self.instance_data = instance_data
|
||||
self.config = config
|
||||
self.log_dir = runner_log_base / "run" / instance_id / "selector"
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.logger = Logger(
|
||||
log_base_path=str(self.log_dir.parent),
|
||||
logger_name=f"selector_{instance_id}",
|
||||
console_output=True,
|
||||
instance_id=self.log_dir.name,
|
||||
)
|
||||
|
||||
def _dump_select_result(self, result: SelectorResult) -> None:
|
||||
try:
|
||||
runner_cfg = (
|
||||
self.config.get("runner", {}) if isinstance(self.config, dict) else {}
|
||||
)
|
||||
dump_dir_str = runner_cfg.get(
|
||||
"selector_result_dump_path", "workspace/selector_result_dump"
|
||||
)
|
||||
dump_dir = Path(dump_dir_str)
|
||||
dump_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
out_path = dump_dir / f"{self.instance_id}.json"
|
||||
|
||||
payload = (
|
||||
result.to_dict()
|
||||
if hasattr(result, "to_dict") and callable(getattr(result, "to_dict"))
|
||||
else {}
|
||||
)
|
||||
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to dump selector result: {e}, traceback: {format_exc()}")
|
||||
|
||||
def _load_dumped_generator_results(self) -> List[GeneratorResult]:
|
||||
try:
|
||||
runner_cfg = (
|
||||
self.config.get("runner", {}) if isinstance(self.config, dict) else {}
|
||||
)
|
||||
dump_dir_str = runner_cfg.get(
|
||||
"generator_result_dump_path", "workspace/generator_result_dump"
|
||||
)
|
||||
dump_path = Path(dump_dir_str) / f"{self.instance_id}.json"
|
||||
if not dump_path.exists():
|
||||
self.logger.warning(f"Failed to find dump file: {dump_path}")
|
||||
return []
|
||||
with open(dump_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f) or []
|
||||
results: List[GeneratorResult] = []
|
||||
for item in data:
|
||||
try:
|
||||
if isinstance(item, dict):
|
||||
results.append(GeneratorResult.from_dict(item))
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to parse dump item: {e}")
|
||||
continue
|
||||
return results
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to load dump: {e}, traceback: {format_exc()}")
|
||||
return []
|
||||
|
||||
async def select(self, generator_results: List[GeneratorResult]) -> SelectorResult:
|
||||
|
||||
if bool(self.config.get("runner", {}).get("skip_generator", False)):
|
||||
self.logger.info("Skipping generator; selector will load generator results from dump files")
|
||||
generator_results = self._load_dumped_generator_results()
|
||||
self.logger.info(f"Loaded {len(generator_results)} candidates from dump")
|
||||
|
||||
if not generator_results:
|
||||
from src.managers.loop.types import (
|
||||
SelectorResult,
|
||||
PatchInfo,
|
||||
LLMUsage,
|
||||
ToolStats,
|
||||
)
|
||||
|
||||
self.logger.error("No selectable candidates found")
|
||||
return SelectorResult(
|
||||
instance_id=self.instance_id,
|
||||
generator_id=-1,
|
||||
image="",
|
||||
success=False,
|
||||
golden_patch=PatchInfo(patch_content="", test_status="", reasoning=""),
|
||||
llm_usage=LLMUsage(
|
||||
prompt_tokens=0, completion_tokens=0, total_tokens=0
|
||||
),
|
||||
tool_stats=ToolStats(bash=0, edit=0, search=0, submit_result=0),
|
||||
total_turns=0,
|
||||
select_reason="",
|
||||
error="No candidates available",
|
||||
)
|
||||
|
||||
self.logger.info(f"Start selecting the best result, {len(generator_results)} candidates in total")
|
||||
executor = Executor(self.image_name, self.logger)
|
||||
bash_tool = BashTool(
|
||||
model_provider=None,
|
||||
executor=executor,
|
||||
logger=self.logger,
|
||||
config=self.config,
|
||||
)
|
||||
edit_tool = TextEditorTool(
|
||||
model_provider=None,
|
||||
executor=executor,
|
||||
logger=self.logger,
|
||||
config=self.config,
|
||||
)
|
||||
search_tool = SearchTool(
|
||||
model_provider=None,
|
||||
executor=executor,
|
||||
logger=self.logger,
|
||||
config=self.config,
|
||||
)
|
||||
tool_executor = ToolExecutor([bash_tool, edit_tool, search_tool], self.logger)
|
||||
|
||||
code, out = executor.execute("0", "echo READY && rg --version || true")
|
||||
self.logger.info(f"Container Health check: exit={code}, out=\n{out}")
|
||||
|
||||
successful_candidates = [r for r in generator_results if r.success]
|
||||
|
||||
if not successful_candidates:
|
||||
self.logger.warning("No successful candidates found, falling back to all candidates")
|
||||
candidates = generator_results
|
||||
else:
|
||||
self.logger.info(f"Found {len(successful_candidates)} successful candidates")
|
||||
candidates = successful_candidates
|
||||
|
||||
patch_selector = PatchSelector(
|
||||
instance_id=self.instance_id,
|
||||
instance_data=self.instance_data,
|
||||
logger=self.logger,
|
||||
prompts_manager=self.prompts_manager,
|
||||
llm_manager=self.llm_manager,
|
||||
tool_executor=tool_executor,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
selected = await patch_selector._select_patch(candidates=candidates)
|
||||
|
||||
try:
|
||||
runner_cfg = (
|
||||
self.config.get("runner", {}) if isinstance(self.config, dict) else {}
|
||||
)
|
||||
if bool(runner_cfg.get("selector_result_dump", False)):
|
||||
self._dump_select_result(selected)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error occurred while dumping selector result: {e}, traceback: {format_exc()}")
|
||||
|
||||
# import random
|
||||
# selected = random.choice(candidates)
|
||||
|
||||
self.logger.info(f"Choosing complete: choose#{selected.generator_id}.")
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class GeneratorLoop:
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
image_name: str,
|
||||
runner_log_base: Path,
|
||||
llm_manager: LLMAPIManager | None,
|
||||
prompts_manager: PromptsManager | None,
|
||||
instance_data: Dict[str, Any],
|
||||
config: Dict[str, Any],
|
||||
generator_id: int = 0,
|
||||
):
|
||||
self.instance_id = instance_id
|
||||
self.image_name = image_name
|
||||
self.generator_id = generator_id
|
||||
self.llm_manager = llm_manager
|
||||
self.prompts_manager = prompts_manager
|
||||
self.instance_data = instance_data
|
||||
self.config = config
|
||||
self.log_dir = (
|
||||
runner_log_base / "run" / instance_id / "generator" / f"{generator_id:03d}"
|
||||
)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.logger = Logger(
|
||||
log_base_path=str(self.log_dir.parent),
|
||||
logger_name=f"generator_{instance_id}_{generator_id:03d}",
|
||||
console_output=True,
|
||||
instance_id=self.log_dir.name,
|
||||
)
|
||||
|
||||
async def generate(self) -> GeneratorResult:
|
||||
executor: Executor | None = None
|
||||
try:
|
||||
self.logger.info(f"Starting GeneratorLoop #{self.generator_id:03d} for instance {self.instance_id} -> {self.image_name}")
|
||||
self.logger.info(f"Use image: {self.image_name}")
|
||||
executor = Executor(self.image_name, self.logger)
|
||||
|
||||
bash_tool = BashTool(
|
||||
model_provider=None,
|
||||
executor=executor,
|
||||
logger=self.logger,
|
||||
config=self.config,
|
||||
)
|
||||
edit_tool = TextEditorTool(
|
||||
model_provider=None,
|
||||
executor=executor,
|
||||
logger=self.logger,
|
||||
config=self.config,
|
||||
)
|
||||
search_tool = SearchTool(
|
||||
model_provider=None,
|
||||
executor=executor,
|
||||
logger=self.logger,
|
||||
config=self.config,
|
||||
)
|
||||
submit_result_tool = SubmitResultTool(
|
||||
model_provider=None,
|
||||
executor=executor,
|
||||
logger=self.logger,
|
||||
config=self.config,
|
||||
)
|
||||
tool_executor = ToolExecutor(
|
||||
[bash_tool, edit_tool, search_tool, submit_result_tool], self.logger
|
||||
)
|
||||
# tool_executor = ToolExecutor([bash_tool, edit_tool])
|
||||
|
||||
# optional: do a container health check
|
||||
code, out = executor.execute("0", "echo READY && rg --version || true")
|
||||
self.logger.info(f"Container Health Check: exit={code}, out=\n{out}")
|
||||
|
||||
|
||||
patch_generator = PatchGenerator(
|
||||
instance_id=self.instance_id,
|
||||
instance_data=self.instance_data,
|
||||
logger=self.logger,
|
||||
prompts_manager=self.prompts_manager,
|
||||
llm_manager=self.llm_manager,
|
||||
tool_executor=tool_executor,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
patch_result = await patch_generator._generate_patch()
|
||||
|
||||
if patch_result is None:
|
||||
result_data = {
|
||||
"instance_id": self.instance_id,
|
||||
"generator_id": self.generator_id,
|
||||
"image": self.image_name,
|
||||
"success": False,
|
||||
"golden_patch": [],
|
||||
"llm_usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
"tool_stats": {
|
||||
BASH_TOOL_NAME: 0,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: 0,
|
||||
SEARCH_TOOL_NAME: 0,
|
||||
SUBMIT_RESULT_TOOL_NAME: 0,
|
||||
},
|
||||
"total_turns": 0,
|
||||
}
|
||||
else:
|
||||
result_data = {
|
||||
"instance_id": self.instance_id,
|
||||
"generator_id": self.generator_id,
|
||||
"image": self.image_name,
|
||||
"success": patch_result.get("success", False),
|
||||
"golden_patch": patch_result.get("golden_patch", []),
|
||||
"llm_usage": patch_result.get(
|
||||
"llm_usage",
|
||||
{
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
),
|
||||
"tool_stats": patch_result.get(
|
||||
"tool_stats",
|
||||
{
|
||||
BASH_TOOL_NAME: 0,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: 0,
|
||||
SEARCH_TOOL_NAME: 0,
|
||||
SUBMIT_RESULT_TOOL_NAME: 0,
|
||||
},
|
||||
),
|
||||
"total_turns": patch_result.get("total_turns", 0),
|
||||
}
|
||||
self.logger.debug(f"[Generator Loop] result_data: {result_data}")
|
||||
return GeneratorResult.from_dict(result_data)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Instance {self.instance_id} Generator #{self.generator_id:03d} failed: {e}, traceback: {format_exc()}")
|
||||
error_data = {
|
||||
"instance_id": self.instance_id,
|
||||
"generator_id": self.generator_id,
|
||||
"image": self.image_name,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"golden_patch": [],
|
||||
"llm_usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
"tool_stats": {
|
||||
BASH_TOOL_NAME: 0,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: 0,
|
||||
SEARCH_TOOL_NAME: 0,
|
||||
SUBMIT_RESULT_TOOL_NAME: 0,
|
||||
},
|
||||
"total_turns": 0,
|
||||
}
|
||||
return GeneratorResult.from_dict(error_data)
|
||||
finally:
|
||||
if executor:
|
||||
try:
|
||||
executor.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class Runner:
|
||||
def __init__(self, cfg: Dict[str, Any], instance_ids: List[str] = None):
|
||||
self.cfg = cfg
|
||||
dataset_cfg = cfg.get("dataset", {})
|
||||
workspace_cfg = cfg.get("workspace", {})
|
||||
builder_cfg = cfg.get("builder", {})
|
||||
log_cfg = cfg.get("log", {})
|
||||
runner_cfg = cfg.get("runner", {})
|
||||
providers_cfg = cfg.get("providers", {})
|
||||
self.instance_ids = instance_ids
|
||||
self.dataset_name = dataset_cfg.get("name", "princeton-nlp/SWE-bench_Lite")
|
||||
self.dataset_split = dataset_cfg.get("split", "dev")
|
||||
self.max_workers = int(builder_cfg.get("max_workers", 2))
|
||||
self.generator_loop_concurrency = int(
|
||||
runner_cfg.get("generator_concurrency", 2)
|
||||
)
|
||||
|
||||
log_base_path = log_cfg.get("base_path", "workspace/logs")
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M")
|
||||
self.logs_base = Path(log_base_path) / timestamp
|
||||
self.logs_base.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.logger = Logger(
|
||||
log_base_path=str(self.logs_base.parent),
|
||||
logger_name="main",
|
||||
console_output=True,
|
||||
instance_id=self.logs_base.name,
|
||||
)
|
||||
|
||||
self.builder = self.load_images()
|
||||
self.logger.debug(f"builder instance_to_image: {self.builder.instance_to_image}.")
|
||||
self.llm_manager = LLMAPIManager(logger=self.logger, config=cfg)
|
||||
|
||||
self.prompts_manager: PromptsManager | None = None
|
||||
try:
|
||||
self.prompts_manager = PromptsManager(cfg)
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"Failed to initialize PromptsManager: {e}, traceback: {format_exc()}"
|
||||
)
|
||||
self.prompts_manager = None
|
||||
|
||||
def dump_generator_results(
|
||||
self, instance_id: str, generator_results: List[GeneratorResult]
|
||||
) -> None:
|
||||
"""Dump generator results to disk if enabled in config.
|
||||
|
||||
- Reads runner.generator_result_dump (bool) and runner.generator_result_dump_path (str)
|
||||
- When enabled, writes JSON file named by instance_id under the dump path
|
||||
"""
|
||||
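# Illustrative note (not in the original source): with the defaults documented
# above, the results for instance "astropy__astropy-14309" would be written to
#   workspace/generator_result_dump/astropy__astropy-14309.json
# as a JSON array of serialized GeneratorResult dicts, which is what
# _load_generator_dump_result() reads back.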
try:
|
||||
runner_cfg: Dict[str, Any] = (
|
||||
self.cfg.get("runner", {}) if isinstance(self.cfg, dict) else {}
|
||||
)
|
||||
enabled: bool = bool(runner_cfg.get("generator_result_dump", False))
|
||||
if not enabled:
|
||||
return
|
||||
|
||||
dump_dir_str: str = runner_cfg.get(
|
||||
"generator_result_dump_path", "workspace/generator_result_dump"
|
||||
)
|
||||
dump_dir = Path(dump_dir_str)
|
||||
dump_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
out_path = dump_dir / f"{instance_id}.json"
|
||||
|
||||
# Convert results to serializable form
|
||||
serialized: List[Dict[str, Any]] = []
|
||||
for r in generator_results:
|
||||
try:
|
||||
if hasattr(r, "to_dict") and callable(getattr(r, "to_dict")):
|
||||
serialized.append(r.to_dict())
|
||||
elif isinstance(r, dict):
|
||||
serialized.append(r)
|
||||
else:
|
||||
# Fallback: best-effort representation
|
||||
serialized.append({"repr": repr(r)})
|
||||
except Exception as e: # noqa: BLE001
|
||||
serialized.append({"error": f"failed to serialize item: {e}"})
|
||||
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(serialized, f, ensure_ascii=False, indent=2)
|
||||
|
||||
self.logger.info(f"Generator results written to: {out_path} ({len(serialized)} items in total)")
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"dump_generator_results failed: {e}, traceback: {format_exc()}"
|
||||
)
|
||||
|
||||
def load_images(self):
|
||||
self.logger.info("Initialize SWEBenchImageBuilder and ready to load image...")
|
||||
builder = SWEBenchImageBuilder(
|
||||
dataset_name=self.dataset_name,
|
||||
split=self.dataset_split,
|
||||
logger=self.logger,
|
||||
)
|
||||
builder.load_images(instance_ids=self.instance_ids)
|
||||
return builder
|
||||
|
||||
async def _run_one(
|
||||
self,
|
||||
instance_id: str,
|
||||
image_name: str,
|
||||
instance_data: Dict[str, Any],
|
||||
generator_id: int = 0,
|
||||
) -> GeneratorResult:
|
||||
loop = GeneratorLoop(
|
||||
instance_id,
|
||||
image_name,
|
||||
self.logs_base,
|
||||
self.llm_manager,
|
||||
self.prompts_manager,
|
||||
instance_data,
|
||||
self.cfg,
|
||||
generator_id,
|
||||
)
|
||||
return await loop.generate()
|
||||
|
||||
async def process_one_instance(self, instance: Dict[str, Any]) -> SelectorResult:
|
||||
instance_id = instance["instance_id"]
|
||||
try:
|
||||
image_name = self.builder.get_image_name(instance_id)
|
||||
except KeyError:
|
||||
self.logger.warning(f"Skipping instance (no image mapping found): {instance_id}")
|
||||
return None
|
||||
|
||||
self.logger.info(f"Start processing instance: {instance_id}")
|
||||
|
||||
# optional: jump generator, let selector load from dump directly and choose
|
||||
skip_generator = bool(self.cfg.get("runner", {}).get("skip_generator", False))
|
||||
if skip_generator:
|
||||
self.logger.info("Skipping GeneratorLoop; selector will load from dump directly")
|
||||
selector = SelectorLoop(
|
||||
instance_id=instance_id,
|
||||
image_name=image_name,
|
||||
runner_log_base=self.logs_base,
|
||||
llm_manager=self.llm_manager,
|
||||
prompts_manager=self.prompts_manager,
|
||||
instance_data=instance,
|
||||
config=self.cfg,
|
||||
)
|
||||
selected_result = await selector.select(
|
||||
[]
|
||||
)
|
||||
self.logger.info(f"Instance {instance_id} done (generation skipped), Generator #{selected_result.generator_id:03d} chosen")
|
||||
return selected_result
|
||||
|
||||
generator_load_dump = bool(
|
||||
self.cfg.get("runner", {}).get("generator_load_dump_result", False)
|
||||
)
|
||||
valid_results = []
|
||||
|
||||
if generator_load_dump:
|
||||
runner_cfg = self.cfg.get("runner", {})
|
||||
dump_path = runner_cfg.get(
|
||||
"generator_result_dump_path", "workspace/generator_result_dump"
|
||||
)
|
||||
|
||||
if self._check_dump_result(dump_path, instance_id):
|
||||
self.logger.info(f"Loading generator results from dump: {instance_id}")
|
||||
valid_results = self._load_generator_dump_result(dump_path, instance_id)
|
||||
else:
|
||||
self.logger.info(f"Generator dump file not found, generating candidate patches concurrently: {instance_id}")
|
||||
|
||||
if not valid_results:
|
||||
generator_tasks = []
|
||||
for generator_id in range(self.generator_loop_concurrency):
|
||||
task = asyncio.create_task(
|
||||
self._run_one(instance_id, image_name, instance, generator_id)
|
||||
)
|
||||
generator_tasks.append(task)
|
||||
|
||||
generator_results = await asyncio.gather(
|
||||
*generator_tasks, return_exceptions=True
|
||||
)
|
||||
self.logger.debug(
|
||||
f"In process_one_instance, generator_results len: {len(generator_results)}"
|
||||
)
|
||||
|
||||
for result in generator_results:
|
||||
if isinstance(result, Exception):
|
||||
self.logger.error(f"GeneratorLoop exception: {result}")
|
||||
else:
|
||||
valid_results.append(result)
|
||||
self.logger.debug(
|
||||
f"In process_one_instance, valid_results len: {len(valid_results)}"
|
||||
)
|
||||
# optional: Dump the generator results for subsequent selector debugging.
|
||||
try:
|
||||
self.dump_generator_results(instance_id, valid_results)
|
||||
except Exception:
|
||||
# Dump failure should not block the main process/flow.
|
||||
pass
|
||||
|
||||
if not valid_results:
|
||||
self.logger.warning(f"Instance {instance_id} has no valid GeneratorLoop results")
|
||||
return None
|
||||
|
||||
selector_load_dump = bool(
|
||||
self.cfg.get("runner", {}).get("selector_load_dump_result", False)
|
||||
)
|
||||
|
||||
if selector_load_dump:
|
||||
runner_cfg = self.cfg.get("runner", {})
|
||||
dump_path = runner_cfg.get(
|
||||
"selector_result_dump_path", "workspace/selector_result_dump"
|
||||
)
|
||||
|
||||
if self._check_dump_result(dump_path, instance_id):
|
||||
self.logger.info(f"Loading selector result from dump file: {instance_id}")
|
||||
try:
|
||||
dump_dir = Path(dump_path)
|
||||
file_path = dump_dir / f"{instance_id}.json"
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
selected_data = json.load(f)
|
||||
|
||||
from src.managers.loop.types import SelectorResult
|
||||
|
||||
selected_result = SelectorResult.from_dict(selected_data)
|
||||
|
||||
self.logger.info(f"Instance {instance_id} done (loaded from dump), Generator #{selected_result.generator_id:03d} chosen")
|
||||
return selected_result
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to load selector dump result: {e}, falling back to the normal selection procedure")
|
||||
else:
|
||||
self.logger.info(f"Selector dump file not found, running the normal selection procedure: {instance_id}")
|
||||
|
||||
self.logger.info(f"Start selecting the best result for instance {instance_id}")
|
||||
selector = SelectorLoop(
|
||||
instance_id=instance_id,
|
||||
image_name=image_name,
|
||||
runner_log_base=self.logs_base,
|
||||
llm_manager=self.llm_manager,
|
||||
prompts_manager=self.prompts_manager,
|
||||
instance_data=instance,
|
||||
config=self.cfg,
|
||||
)
|
||||
selected_result = await selector.select(valid_results)
|
||||
|
||||
self.logger.info(f"Instance {instance_id} processed, Generator #{selected_result.generator_id:03d} chosen")
|
||||
return selected_result
|
||||
|
||||
async def run(self) -> Dict[str, Any]:
|
||||
|
||||
assert self.builder is not None
|
||||
|
||||
if self.instance_ids:
|
||||
target_ids = set(self.instance_ids)
|
||||
instances_to_run = [
|
||||
inst
|
||||
for inst in self.builder.full_dataset
|
||||
if inst.get("instance_id") in target_ids
|
||||
]
|
||||
else:
|
||||
instances_to_run = list(self.builder.full_dataset)
|
||||
|
||||
self.logger.info(f"Start to process {len(instances_to_run)} instances")
|
||||
|
||||
final_results = []
|
||||
for i, instance in enumerate(instances_to_run, 1):
|
||||
self.logger.info(f"Processing instance {i}/{len(instances_to_run)}: {instance['instance_id']}")
|
||||
try:
|
||||
result = await self.process_one_instance(instance)
|
||||
if result is not None:
|
||||
final_results.append(result)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Instance {instance['instance_id']} processing failed: {e}, traceback: {format_exc()}.")
|
||||
final_results.append(
|
||||
SelectorResult(
|
||||
instance_id=instance["instance_id"],
|
||||
generator_id=0,
|
||||
success=False,
|
||||
golden_patch=None,
|
||||
)
|
||||
)
|
||||
|
||||
summary = self._calculate_summary(final_results)
|
||||
|
||||
self.logger.info(
|
||||
f"Done. Total={summary['total']} Success={summary['success']} Fail={summary['failed']}"
|
||||
)
|
||||
return summary
|
||||
|
||||
def _check_dump_result(self, dump_path: str, instance_id: str) -> bool:
|
||||
try:
|
||||
dump_dir = Path(dump_path)
|
||||
file_path = dump_dir / f"{instance_id}.json"
|
||||
return file_path.exists()
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to check dump file: {e}")
|
||||
return False
|
||||
|
||||
def _load_generator_dump_result(
|
||||
self, dump_path: str, instance_id: str
|
||||
) -> List[GeneratorResult]:
|
||||
try:
|
||||
dump_dir = Path(dump_path)
|
||||
file_path = dump_dir / f"{instance_id}.json"
|
||||
|
||||
if not file_path.exists():
|
||||
self.logger.warning(f"Generator dump file does not exist: {file_path}")
|
||||
return []
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f) or []
|
||||
|
||||
results: List[GeneratorResult] = []
|
||||
for item in data:
|
||||
try:
|
||||
if isinstance(item, dict):
|
||||
results.append(GeneratorResult.from_dict(item))
|
||||
else:
|
||||
self.logger.warning(f"Skipping non-dict dump item: {type(item)}")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to parse dump item: {e}")
|
||||
continue
|
||||
|
||||
self.logger.info(f"Loaded {len(results)} GeneratorResults from dump")
|
||||
return results
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"Load generator dump results failed: {e}, traceback: {format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def _calculate_summary(self, results: List[SelectorResult]) -> Dict[str, Any]:
|
||||
summary: Dict[str, Any] = {
|
||||
"total": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
}
|
||||
|
||||
for r in results:
|
||||
summary["total"] += 1
|
||||
if r.success:
|
||||
summary["success"] += 1
|
||||
else:
|
||||
summary["failed"] += 1
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def run_one_issue(cfg, issue: str) -> dict[str, Any]:
|
||||
ids = [issue]
|
||||
summary = {}
|
||||
if not cfg.get("runner", {}).get("skip_selector", False):
|
||||
runner = Runner(cfg, ids)
|
||||
summary = asyncio.run(runner.run())
|
||||
print("\n" + "=" * 80)
|
||||
print("Total results")
|
||||
print("=" * 80)
|
||||
print(f"Total instances: {summary['total']}")
|
||||
print(f"Success: {summary['success']}")
|
||||
print(f"Fail: {summary['failed']}")
|
||||
print(
|
||||
f"Success rate: {(summary['success']/summary['total']*100):.1f}%"
|
||||
if summary["total"] > 0
|
||||
else "0%"
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def gen_result(cfg):
|
||||
result_builder = ResultBuilder(cfg)
|
||||
result_path = result_builder.build_preds()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
config_path = Path(__file__).parent / "config" / "config.yaml"
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
test_instance_ids = ["astropy__astropy-14309"]
|
||||
if not cfg.get("runner", {}).get("skip_selector", False):
|
||||
runner = Runner(cfg, test_instance_ids)
|
||||
summary = asyncio.run(runner.run())
|
||||
print("\n" + "=" * 80)
|
||||
print("Total results")
|
||||
print("=" * 80)
|
||||
print(f"Total instances: {summary['total']}")
|
||||
print(f"Success: {summary['success']}")
|
||||
print(f"Fail: {summary['failed']}")
|
||||
print(
|
||||
f"Success rate: {(summary['success']/summary['total']*100):.1f}%"
|
||||
if summary["total"] > 0
|
||||
else "0%"
|
||||
)
|
||||
|
||||
result_builder = ResultBuilder(cfg)
|
||||
result_builder.build_preds()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
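A minimal sketch (assuming it runs from the repository root) of driving this module for a single instance, mirroring what _batch_run.py's worker does: run_one_issue() generates and selects a patch, and gen_result() then builds the prediction files via ResultBuilder:

from pathlib import Path
import yaml
import _run

with open(Path("config") / "config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f) or {}

summary = _run.run_one_issue(cfg, issue="astropy__astropy-14309")  # instance id borrowed from main() above
print(summary)        # {"total": ..., "success": ..., "failed": ...}
_run.gen_result(cfg)  # writes predictions under result.preds.path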
1
config/batch.yaml
Normal file
@@ -0,0 +1 @@
providers:
61
config/config.yaml
Normal file
@@ -0,0 +1,61 @@
providers:
  openrouter:
    - "anthropic/claude-sonnet-4.5"

dataset:
  name: "princeton-nlp/SWE-bench_Verified"
  split: "test"

workspace:
  path: "workspace"

# Runner settings
runner:

  # generator
  skip_generator: false          # Whether to skip the generator; for selector testing you can skip it and load generator results directly from dump files
  generator_concurrency: 5       # Number of GeneratorLoop candidate patches per instance
  generator_loop:                # GeneratorLoop settings
    max_turn: 140                # Max turns
    temperature: 0.2             # Temperature
  generator_result_dump: true    # Whether to dump generator results for later selector testing
  generator_result_dump_path: "workspace/generator_result_dump"   # Generator result dump path
  generator_load_dump_result: true   # Whether to skip instances that already have generated results

  # selector
  skip_selector: false           # Whether to skip the selector; useful to test validation directly
  selector_loop:                 # SelectorLoop settings
    max_turn: 200                # Max turns
    temperature: 0.2             # Temperature
  selector_result_dump: true     # Whether to dump selector results
  selector_result_dump_path: "workspace/selector_result_dump"
  selector_load_dump_result: true    # Whether to skip instances that already have selection results

# Image builder
builder:
  max_workers: 16                # Image build concurrency
  max_retries: 3                 # Image build retry count
  repo_root_path: "/testbed"     # Repository root path

# Log settings
log:
  base_path: "workspace/logs"
  image_builder: "workspace/image_logs"    # Use a temporary directory for image builder logs
  disable_image_builder_logs: false        # Set to true to prevent saving log files

# Final result paths
result:
  preds:
    path: "result"

# token_cache settings for Claude models
token_cache:
  role_turns:
    - turn: 0
      role: "user"
    - turn: 25
      role: "assistant"
    - turn: 50
      role: "assistant"
    - turn: 75
      role: "assistant"
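A minimal sketch of how the runner code consumes this file: keys are read with nested .get() calls, so a missing key falls back to the same defaults hard-coded in _run.py (the paths and dataset name below are those defaults):

import yaml

with open("config/config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f) or {}

runner_cfg = cfg.get("runner", {})
skip_generator = bool(runner_cfg.get("skip_generator", False))
n_generators = int(runner_cfg.get("generator_concurrency", 2))
dump_path = runner_cfg.get("generator_result_dump_path", "workspace/generator_result_dump")
dataset_name = cfg.get("dataset", {}).get("name", "princeton-nlp/SWE-bench_Lite")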
500
config/list_all.txt
Normal file
@@ -0,0 +1,500 @@
|
||||
astropy__astropy-12907
|
||||
astropy__astropy-13033
|
||||
astropy__astropy-13236
|
||||
astropy__astropy-13398
|
||||
astropy__astropy-13453
|
||||
astropy__astropy-13579
|
||||
astropy__astropy-13977
|
||||
astropy__astropy-14096
|
||||
astropy__astropy-14182
|
||||
astropy__astropy-14309
|
||||
astropy__astropy-14365
|
||||
astropy__astropy-14369
|
||||
astropy__astropy-14508
|
||||
astropy__astropy-14539
|
||||
astropy__astropy-14598
|
||||
astropy__astropy-14995
|
||||
astropy__astropy-7166
|
||||
astropy__astropy-7336
|
||||
astropy__astropy-7606
|
||||
astropy__astropy-7671
|
||||
astropy__astropy-8707
|
||||
astropy__astropy-8872
|
||||
django__django-10097
|
||||
django__django-10554
|
||||
django__django-10880
|
||||
django__django-10914
|
||||
django__django-10973
|
||||
django__django-10999
|
||||
django__django-11066
|
||||
django__django-11087
|
||||
django__django-11095
|
||||
django__django-11099
|
||||
django__django-11119
|
||||
django__django-11133
|
||||
django__django-11138
|
||||
django__django-11141
|
||||
django__django-11149
|
||||
django__django-11163
|
||||
django__django-11179
|
||||
django__django-11206
|
||||
django__django-11211
|
||||
django__django-11239
|
||||
django__django-11265
|
||||
django__django-11276
|
||||
django__django-11292
|
||||
django__django-11299
|
||||
django__django-11333
|
||||
django__django-11400
|
||||
django__django-11433
|
||||
django__django-11451
|
||||
django__django-11477
|
||||
django__django-11490
|
||||
django__django-11532
|
||||
django__django-11551
|
||||
django__django-11555
|
||||
django__django-11603
|
||||
django__django-11728
|
||||
django__django-11734
|
||||
django__django-11740
|
||||
django__django-11749
|
||||
django__django-11790
|
||||
django__django-11815
|
||||
django__django-11820
|
||||
django__django-11848
|
||||
django__django-11880
|
||||
django__django-11885
|
||||
django__django-11951
|
||||
django__django-11964
|
||||
django__django-11999
|
||||
django__django-12039
|
||||
django__django-12050
|
||||
django__django-12125
|
||||
django__django-12143
|
||||
django__django-12155
|
||||
django__django-12193
|
||||
django__django-12209
|
||||
django__django-12262
|
||||
django__django-12273
|
||||
django__django-12276
|
||||
django__django-12304
|
||||
django__django-12308
|
||||
django__django-12325
|
||||
django__django-12406
|
||||
django__django-12419
|
||||
django__django-12663
|
||||
django__django-12708
|
||||
django__django-12713
|
||||
django__django-12741
|
||||
django__django-12754
|
||||
django__django-12774
|
||||
django__django-12858
|
||||
django__django-12965
|
||||
django__django-13012
|
||||
django__django-13023
|
||||
django__django-13028
|
||||
django__django-13033
|
||||
django__django-13089
|
||||
django__django-13109
|
||||
django__django-13112
|
||||
django__django-13121
|
||||
django__django-13128
|
||||
django__django-13158
|
||||
django__django-13195
|
||||
django__django-13212
|
||||
django__django-13279
|
||||
django__django-13297
|
||||
django__django-13315
|
||||
django__django-13343
|
||||
django__django-13344
|
||||
django__django-13346
|
||||
django__django-13363
|
||||
django__django-13401
|
||||
django__django-13406
|
||||
django__django-13410
|
||||
django__django-13417
|
||||
django__django-13449
|
||||
django__django-13512
|
||||
django__django-13513
|
||||
django__django-13516
|
||||
django__django-13551
|
||||
django__django-13568
|
||||
django__django-13569
|
||||
django__django-13590
|
||||
django__django-13658
|
||||
django__django-13670
|
||||
django__django-13741
|
||||
django__django-13786
|
||||
django__django-13794
|
||||
django__django-13807
|
||||
django__django-13809
|
||||
django__django-13810
|
||||
django__django-13820
|
||||
django__django-13821
|
||||
django__django-13837
|
||||
django__django-13925
|
||||
django__django-13933
|
||||
django__django-13964
|
||||
django__django-14007
|
||||
django__django-14011
|
||||
django__django-14017
|
||||
django__django-14034
|
||||
django__django-14053
|
||||
django__django-14089
|
||||
django__django-14122
|
||||
django__django-14140
|
||||
django__django-14155
|
||||
django__django-14170
|
||||
django__django-14238
|
||||
django__django-14311
|
||||
django__django-14315
|
||||
django__django-14349
|
||||
django__django-14351
|
||||
django__django-14373
|
||||
django__django-14376
|
||||
django__django-14404
|
||||
django__django-14434
|
||||
django__django-14493
|
||||
django__django-14500
|
||||
django__django-14534
|
||||
django__django-14539
|
||||
django__django-14559
|
||||
django__django-14580
|
||||
django__django-14608
|
||||
django__django-14631
|
||||
django__django-14672
|
||||
django__django-14725
|
||||
django__django-14752
|
||||
django__django-14765
|
||||
django__django-14771
|
||||
django__django-14787
|
||||
django__django-14792
|
||||
django__django-14855
|
||||
django__django-14915
|
||||
django__django-14999
|
||||
django__django-15022
|
||||
django__django-15037
|
||||
django__django-15098
|
||||
django__django-15103
|
||||
django__django-15104
|
||||
django__django-15127
|
||||
django__django-15128
|
||||
django__django-15161
|
||||
django__django-15252
|
||||
django__django-15268
|
||||
django__django-15277
|
||||
django__django-15278
|
||||
django__django-15280
|
||||
django__django-15315
|
||||
django__django-15368
|
||||
django__django-15375
|
||||
django__django-15380
|
||||
django__django-15382
|
||||
django__django-15467
|
||||
django__django-15499
|
||||
django__django-15503
|
||||
django__django-15525
|
||||
django__django-15554
|
||||
django__django-15561
|
||||
django__django-15563
|
||||
django__django-15569
|
||||
django__django-15572
|
||||
django__django-15629
|
||||
django__django-15695
|
||||
django__django-15731
|
||||
django__django-15732
|
||||
django__django-15741
|
||||
django__django-15814
|
||||
django__django-15851
|
||||
django__django-15863
|
||||
django__django-15916
|
||||
django__django-15930
|
||||
django__django-15957
|
||||
django__django-15973
|
||||
django__django-15987
|
||||
django__django-16032
|
||||
django__django-16082
|
||||
django__django-16100
|
||||
django__django-16116
|
||||
django__django-16136
|
||||
django__django-16139
|
||||
django__django-16145
|
||||
django__django-16255
|
||||
django__django-16256
|
||||
django__django-16263
|
||||
django__django-16315
|
||||
django__django-16333
|
||||
django__django-16429
|
||||
django__django-16454
|
||||
django__django-16485
|
||||
django__django-16493
|
||||
django__django-16502
|
||||
django__django-16527
|
||||
django__django-16560
|
||||
django__django-16569
|
||||
django__django-16595
|
||||
django__django-16612
|
||||
django__django-16631
|
||||
django__django-16642
|
||||
django__django-16661
|
||||
django__django-16662
|
||||
django__django-16667
|
||||
django__django-16801
|
||||
django__django-16819
|
||||
django__django-16877
|
||||
django__django-16899
|
||||
django__django-16901
|
||||
django__django-16938
|
||||
django__django-16950
|
||||
django__django-17029
|
||||
django__django-17084
|
||||
django__django-17087
|
||||
django__django-7530
|
||||
django__django-9296
|
||||
matplotlib__matplotlib-13989
|
||||
matplotlib__matplotlib-14623
|
||||
matplotlib__matplotlib-20488
|
||||
matplotlib__matplotlib-20676
|
||||
matplotlib__matplotlib-20826
|
||||
matplotlib__matplotlib-20859
|
||||
matplotlib__matplotlib-21568
|
||||
matplotlib__matplotlib-22719
|
||||
matplotlib__matplotlib-22865
|
||||
matplotlib__matplotlib-22871
|
||||
matplotlib__matplotlib-23299
|
||||
matplotlib__matplotlib-23314
|
||||
matplotlib__matplotlib-23412
|
||||
matplotlib__matplotlib-23476
|
||||
matplotlib__matplotlib-24026
|
||||
matplotlib__matplotlib-24149
|
||||
matplotlib__matplotlib-24177
|
||||
matplotlib__matplotlib-24570
|
||||
matplotlib__matplotlib-24627
|
||||
matplotlib__matplotlib-24637
|
||||
matplotlib__matplotlib-24870
|
||||
matplotlib__matplotlib-24970
|
||||
matplotlib__matplotlib-25122
|
||||
matplotlib__matplotlib-25287
|
||||
matplotlib__matplotlib-25311
|
||||
matplotlib__matplotlib-25332
|
||||
matplotlib__matplotlib-25479
|
||||
matplotlib__matplotlib-25775
|
||||
matplotlib__matplotlib-25960
|
||||
matplotlib__matplotlib-26113
|
||||
matplotlib__matplotlib-26208
|
||||
matplotlib__matplotlib-26291
|
||||
matplotlib__matplotlib-26342
|
||||
matplotlib__matplotlib-26466
|
||||
mwaskom__seaborn-3069
|
||||
mwaskom__seaborn-3187
|
||||
pallets__flask-5014
|
||||
psf__requests-1142
|
||||
psf__requests-1724
|
||||
psf__requests-1766
|
||||
psf__requests-1921
|
||||
psf__requests-2317
|
||||
psf__requests-2931
|
||||
psf__requests-5414
|
||||
psf__requests-6028
|
||||
pydata__xarray-2905
|
||||
pydata__xarray-3095
|
||||
pydata__xarray-3151
|
||||
pydata__xarray-3305
|
||||
pydata__xarray-3677
|
||||
pydata__xarray-3993
|
||||
pydata__xarray-4075
|
||||
pydata__xarray-4094
|
||||
pydata__xarray-4356
|
||||
pydata__xarray-4629
|
||||
pydata__xarray-4687
|
||||
pydata__xarray-4695
|
||||
pydata__xarray-4966
|
||||
pydata__xarray-6461
|
||||
pydata__xarray-6599
|
||||
pydata__xarray-6721
|
||||
pydata__xarray-6744
|
||||
pydata__xarray-6938
|
||||
pydata__xarray-6992
|
||||
pydata__xarray-7229
|
||||
pydata__xarray-7233
|
||||
pydata__xarray-7393
|
||||
pylint-dev__pylint-4551
|
||||
pylint-dev__pylint-4604
|
||||
pylint-dev__pylint-4661
|
||||
pylint-dev__pylint-4970
|
||||
pylint-dev__pylint-6386
|
||||
pylint-dev__pylint-6528
|
||||
pylint-dev__pylint-6903
|
||||
pylint-dev__pylint-7080
|
||||
pylint-dev__pylint-7277
|
||||
pylint-dev__pylint-8898
|
||||
pytest-dev__pytest-10051
|
||||
pytest-dev__pytest-10081
|
||||
pytest-dev__pytest-10356
|
||||
pytest-dev__pytest-5262
|
||||
pytest-dev__pytest-5631
|
||||
pytest-dev__pytest-5787
|
||||
pytest-dev__pytest-5809
|
||||
pytest-dev__pytest-5840
|
||||
pytest-dev__pytest-6197
|
||||
pytest-dev__pytest-6202
|
||||
pytest-dev__pytest-7205
|
||||
pytest-dev__pytest-7236
|
||||
pytest-dev__pytest-7324
|
||||
pytest-dev__pytest-7432
|
||||
pytest-dev__pytest-7490
|
||||
pytest-dev__pytest-7521
|
||||
pytest-dev__pytest-7571
|
||||
pytest-dev__pytest-7982
|
||||
pytest-dev__pytest-8399
|
||||
scikit-learn__scikit-learn-10297
|
||||
scikit-learn__scikit-learn-10844
|
||||
scikit-learn__scikit-learn-10908
|
||||
scikit-learn__scikit-learn-11310
|
||||
scikit-learn__scikit-learn-11578
|
||||
scikit-learn__scikit-learn-12585
|
||||
scikit-learn__scikit-learn-12682
|
||||
scikit-learn__scikit-learn-12973
|
||||
scikit-learn__scikit-learn-13124
|
||||
scikit-learn__scikit-learn-13135
|
||||
scikit-learn__scikit-learn-13142
|
||||
scikit-learn__scikit-learn-13328
|
||||
scikit-learn__scikit-learn-13439
|
||||
scikit-learn__scikit-learn-13496
|
||||
scikit-learn__scikit-learn-13779
|
||||
scikit-learn__scikit-learn-14053
|
||||
scikit-learn__scikit-learn-14087
|
||||
scikit-learn__scikit-learn-14141
|
||||
scikit-learn__scikit-learn-14496
|
||||
scikit-learn__scikit-learn-14629
|
||||
scikit-learn__scikit-learn-14710
|
||||
scikit-learn__scikit-learn-14894
|
||||
scikit-learn__scikit-learn-14983
|
||||
scikit-learn__scikit-learn-15100
|
||||
scikit-learn__scikit-learn-25102
|
||||
scikit-learn__scikit-learn-25232
|
||||
scikit-learn__scikit-learn-25747
|
||||
scikit-learn__scikit-learn-25931
|
||||
scikit-learn__scikit-learn-25973
|
||||
scikit-learn__scikit-learn-26194
|
||||
scikit-learn__scikit-learn-26323
|
||||
scikit-learn__scikit-learn-9288
|
||||
sphinx-doc__sphinx-10323
|
||||
sphinx-doc__sphinx-10435
|
||||
sphinx-doc__sphinx-10449
|
||||
sphinx-doc__sphinx-10466
|
||||
sphinx-doc__sphinx-10614
|
||||
sphinx-doc__sphinx-10673
|
||||
sphinx-doc__sphinx-11445
|
||||
sphinx-doc__sphinx-11510
|
||||
sphinx-doc__sphinx-7440
|
||||
sphinx-doc__sphinx-7454
|
||||
sphinx-doc__sphinx-7462
|
||||
sphinx-doc__sphinx-7590
|
||||
sphinx-doc__sphinx-7748
|
||||
sphinx-doc__sphinx-7757
|
||||
sphinx-doc__sphinx-7889
|
||||
sphinx-doc__sphinx-7910
|
||||
sphinx-doc__sphinx-7985
|
||||
sphinx-doc__sphinx-8035
|
||||
sphinx-doc__sphinx-8056
|
||||
sphinx-doc__sphinx-8120
|
||||
sphinx-doc__sphinx-8265
|
||||
sphinx-doc__sphinx-8269
|
||||
sphinx-doc__sphinx-8459
|
||||
sphinx-doc__sphinx-8475
|
||||
sphinx-doc__sphinx-8548
|
||||
sphinx-doc__sphinx-8551
|
||||
sphinx-doc__sphinx-8593
|
||||
sphinx-doc__sphinx-8595
|
||||
sphinx-doc__sphinx-8621
|
||||
sphinx-doc__sphinx-8638
|
||||
sphinx-doc__sphinx-8721
|
||||
sphinx-doc__sphinx-9229
|
||||
sphinx-doc__sphinx-9230
|
||||
sphinx-doc__sphinx-9258
|
||||
sphinx-doc__sphinx-9281
|
||||
sphinx-doc__sphinx-9320
|
||||
sphinx-doc__sphinx-9367
|
||||
sphinx-doc__sphinx-9461
|
||||
sphinx-doc__sphinx-9591
|
||||
sphinx-doc__sphinx-9602
|
||||
sphinx-doc__sphinx-9658
|
||||
sphinx-doc__sphinx-9673
|
||||
sphinx-doc__sphinx-9698
|
||||
sphinx-doc__sphinx-9711
|
||||
sympy__sympy-11618
|
||||
sympy__sympy-12096
|
||||
sympy__sympy-12419
|
||||
sympy__sympy-12481
|
||||
sympy__sympy-12489
|
||||
sympy__sympy-13031
|
||||
sympy__sympy-13091
|
||||
sympy__sympy-13372
|
||||
sympy__sympy-13480
|
||||
sympy__sympy-13551
|
||||
sympy__sympy-13615
|
||||
sympy__sympy-13647
|
||||
sympy__sympy-13757
|
||||
sympy__sympy-13798
|
||||
sympy__sympy-13852
|
||||
sympy__sympy-13877
|
||||
sympy__sympy-13878
|
||||
sympy__sympy-13974
|
||||
sympy__sympy-14248
|
||||
sympy__sympy-14531
|
||||
sympy__sympy-14711
|
||||
sympy__sympy-14976
|
||||
sympy__sympy-15017
|
||||
sympy__sympy-15345
|
||||
sympy__sympy-15349
|
||||
sympy__sympy-15599
|
||||
sympy__sympy-15809
|
||||
sympy__sympy-15875
|
||||
sympy__sympy-15976
|
||||
sympy__sympy-16450
|
||||
sympy__sympy-16597
|
||||
sympy__sympy-16766
|
||||
sympy__sympy-16792
|
||||
sympy__sympy-16886
|
||||
sympy__sympy-17139
|
||||
sympy__sympy-17318
|
||||
sympy__sympy-17630
|
||||
sympy__sympy-17655
|
||||
sympy__sympy-18189
|
||||
sympy__sympy-18199
|
||||
sympy__sympy-18211
|
||||
sympy__sympy-18698
|
||||
sympy__sympy-18763
|
||||
sympy__sympy-19040
|
||||
sympy__sympy-19346
|
||||
sympy__sympy-19495
|
||||
sympy__sympy-19637
|
||||
sympy__sympy-19783
|
||||
sympy__sympy-19954
|
||||
sympy__sympy-20154
|
||||
sympy__sympy-20428
|
||||
sympy__sympy-20438
|
||||
sympy__sympy-20590
|
||||
sympy__sympy-20801
|
||||
sympy__sympy-20916
|
||||
sympy__sympy-21379
|
||||
sympy__sympy-21596
|
||||
sympy__sympy-21612
|
||||
sympy__sympy-21847
|
||||
sympy__sympy-21930
|
||||
sympy__sympy-22080
|
||||
sympy__sympy-22456
|
||||
sympy__sympy-22714
|
||||
sympy__sympy-22914
|
||||
sympy__sympy-23262
|
||||
sympy__sympy-23413
|
||||
sympy__sympy-23534
|
||||
sympy__sympy-23824
|
||||
sympy__sympy-23950
|
||||
sympy__sympy-24066
|
||||
sympy__sympy-24213
|
||||
sympy__sympy-24443
|
||||
sympy__sympy-24539
|
||||
sympy__sympy-24562
|
||||
sympy__sympy-24661
|
||||
31
env.example
Normal file
@@ -0,0 +1,31 @@
|
||||
# API
|
||||
|
||||
## OpenRouter API
|
||||
OPENROUTER_API_KEY=your-openrouter-api-key-here
|
||||
|
||||
## Anthropic Claude API
|
||||
ANTHROPIC_API_KEY=your-anthropic-api-key-here
|
||||
|
||||
## DeepSeek API
|
||||
DEEPSEEK_API_KEY=your-deepseek-api-key-here
|
||||
|
||||
## OpenAI API
|
||||
OPENAI_API_KEY=your-openai-api-key-here
|
||||
|
||||
## Private API
|
||||
PRIVATE_API_KEY=EMPTY
|
||||
PRIVATE_MODEL_NAME=your-private-model-name-here
|
||||
PRIVATE_DEPLOYMENT_TYPE=vllm
|
||||
PRIVATE_URL=https://127.0.0.1:0/v1
|
||||
|
||||
OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
ANTHROPIC_BASE_URL=https://api.anthropic.com
|
||||
DEEPSEEK_BASE_URL=https://api.deepseek.com/v1
|
||||
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
|
||||
|
||||
## OpenRouter
|
||||
OPENROUTER_APP_NAME=tokfinity-llm-client
|
||||
OPENROUTER_SITE_URL=https://github.com/your-repo
|
||||
|
||||
# config path
|
||||
CONFIG_PATH=config/config.yaml
|
||||
63
eval.sh
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
|
||||
PROJECT_ROOT="$(cd "$(dirname "$0")" && pwd)"
|
||||
|
||||
|
||||
VENV_DIR="$PROJECT_ROOT/venv"
|
||||
if [[ ! -d "$VENV_DIR" ]]; then
|
||||
echo "[ERROR] wrong venv directory: $VENV_DIR" >&2
|
||||
exit 1
|
||||
fi
|
||||
source "$VENV_DIR/bin/activate"
|
||||
|
||||
|
||||
CONFIG_PATH="$PROJECT_ROOT/config/config.yaml"
|
||||
if [[ ! -f "$CONFIG_PATH" ]]; then
|
||||
echo "[ERROR] config file not found: $CONFIG_PATH" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
eval "$(python - <<'PY'
|
||||
from __future__ import annotations
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
project_root = Path(__file__).resolve().parent
|
||||
config_path = project_root / 'config' / 'config.yaml'
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
dataset_name = ((cfg.get('dataset') or {}).get('name')) or 'princeton-nlp/SWE-bench_Lite'
|
||||
workspace_path = ((cfg.get('workspace') or {}).get('path')) or 'workspace'
|
||||
preds_dir = (((cfg.get('result') or {}).get('preds') or {}).get('path')) or 'result'
|
||||
preds_path = (project_root / workspace_path / preds_dir / 'preds.json').resolve()
|
||||
|
||||
print(f'DATASET_NAME="{dataset_name}"')
|
||||
print(f'PREDICTIONS_PATH="{preds_path}"')
|
||||
PY
|
||||
)"
|
||||
|
||||
|
||||
if [[ $# -eq 0 ]]; then
|
||||
echo "[ERROR] run_id is required" >&2
|
||||
echo "Usage: $0 <run_id>" >&2
|
||||
exit 1
|
||||
fi
|
||||
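# Example invocation (hypothetical run id):
#   ./eval.sh swebench_lite_run_001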
|
||||
RUN_ID="$1"
|
||||
|
||||
echo "Using dataset: $DATASET_NAME"
|
||||
echo "Using predictions: $PREDICTIONS_PATH"
|
||||
echo "Using run_id: $RUN_ID"
|
||||
|
||||
# run evaluation
|
||||
python -m swebench.harness.run_evaluation \
|
||||
--dataset_name "$DATASET_NAME" \
|
||||
--predictions_path "$PREDICTIONS_PATH" \
|
||||
--max_workers 20 \
|
||||
--cache_level instance \
|
||||
--run_id "$RUN_ID"
|
||||
|
||||
|
||||
1914
figures/framework.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 236 KiB |
41
requirements.txt
Normal file
@@ -0,0 +1,41 @@
|
||||
# HTTP client
|
||||
httpx>=0.25.0
|
||||
requests>=2.28.0
|
||||
|
||||
# Data Validation and Serialization
|
||||
pydantic>=2.0.0
|
||||
|
||||
# logging
|
||||
loguru>=0.7.0
|
||||
|
||||
# type-checking support
|
||||
typing-extensions>=4.5.0
|
||||
|
||||
# environment variable management
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
# YAML config file parsing
|
||||
PyYAML>=6.0
|
||||
|
||||
# dataset processing
|
||||
datasets>=2.14.0
|
||||
|
||||
# JSONPath parsing
|
||||
jsonpath-ng>=1.6.0
|
||||
|
||||
# Docker support
|
||||
docker>=6.0.0
|
||||
|
||||
# CLI interaction support
|
||||
pexpect>=4.8.0
|
||||
|
||||
# SWE-bench related requirements
|
||||
swebench>=1.0.0
|
||||
|
||||
# development requirements (optional)
|
||||
# pytest>=7.0.0
|
||||
# pytest-asyncio>=0.21.0
|
||||
# black>=23.0.0
|
||||
# flake8>=6.0.0
|
||||
# mypy>=1.0.0
|
||||
pydantic>=2.11.9
|
||||
1
src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# src directory marker file
|
||||
16
src/managers/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Export all modules
|
||||
from .image_builder import *
|
||||
from .log import *
|
||||
from .llm_api import *
|
||||
from .prompts import *
|
||||
from .loop import *
|
||||
from .decorators import *
|
||||
|
||||
__all__ = [
|
||||
"image_builder",
|
||||
"log",
|
||||
"llm_api",
|
||||
"prompts",
|
||||
"loop",
|
||||
"decorators"
|
||||
]
|
||||
13
src/managers/decorators/singleton.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import threading
|
||||
|
||||
def singleton(cls):
|
||||
instances = {}
|
||||
lock = threading.Lock()
|
||||
|
||||
def get_instance(*args, **kwargs):
|
||||
if cls not in instances:
|
||||
with lock:
|
||||
if cls not in instances:
|
||||
instances[cls] = cls(*args, **kwargs)
|
||||
return instances[cls]
|
||||
return get_instance
|
||||
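A minimal usage sketch of the singleton decorator above; the ConfigStore class is hypothetical and only illustrates the double-checked locking behaviour:

from src.managers.decorators.singleton import singleton

@singleton
class ConfigStore:
    def __init__(self):
        self.values = {}

a = ConfigStore()
b = ConfigStore()
assert a is b  # every call returns the same cached instance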
8
src/managers/image_builder/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Export all modules
|
||||
from .build_image import *
|
||||
from .dockerfiles import *
|
||||
|
||||
__all__ = [
|
||||
"build_image",
|
||||
"dockerfiles"
|
||||
]
|
||||
1009
src/managers/image_builder/build_image.py
Normal file
File diff suppressed because it is too large
8
src/managers/image_builder/dockerfiles.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# User image Dockerfile template
|
||||
# image_key: instance_image_key corresponding to instance_id
|
||||
# apt_packages: space-separated list of packages to install in the Ubuntu-based image
|
||||
_DOCKERFILE_USER_IMAGE_PY = r"""
|
||||
FROM {image_key}
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends {apt_packages}
|
||||
"""
|
||||
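A quick sketch of how the template above can be rendered; the image key and package list are illustrative assumptions, not values taken from the project:

from src.managers.image_builder.dockerfiles import _DOCKERFILE_USER_IMAGE_PY

dockerfile_text = _DOCKERFILE_USER_IMAGE_PY.format(
    image_key="sweb.eval.x86_64.example_instance:latest",  # assumed instance image key
    apt_packages="git curl",                               # assumed packages, space-separated
)
print(dockerfile_text)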
125
src/managers/image_builder/logger_patch.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
User-defined logger support for the swebench harness
|
||||
"""
|
||||
import logging
|
||||
from typing import Union, Optional
|
||||
|
||||
from swebench.harness.docker_utils import remove_image
|
||||
from src.managers.log.logger import Logger as CustomLogger
|
||||
|
||||
from swebench.harness.docker_build import build_instance_image, run_threadpool
|
||||
|
||||
|
||||
def patched_build_instance_images(
|
||||
client,
|
||||
dataset: list,
|
||||
force_rebuild: bool = False,
|
||||
max_workers: int = 4,
|
||||
namespace: str = None,
|
||||
tag: str = None,
|
||||
env_image_tag: str = None,
|
||||
custom_logger: Optional[Union[logging.Logger, CustomLogger]] = None,
|
||||
):
|
||||
"""
|
||||
Monkey patched version of build_instance_images that supports custom logger
|
||||
|
||||
Args:
|
||||
client: Docker client
|
||||
dataset: List of test specs or dataset to build images for
|
||||
force_rebuild: Whether to force rebuild the images even if they already exist
|
||||
max_workers: Maximum number of worker threads
|
||||
namespace: Namespace for images
|
||||
tag: Tag for images
|
||||
env_image_tag: Environment image tag
|
||||
custom_logger: Custom logger to use (instead of creating new ones)
|
||||
"""
|
||||
from swebench.harness.docker_build import make_test_spec, build_env_images
|
||||
|
||||
test_specs = []
|
||||
for instance in dataset:
|
||||
spec = make_test_spec(
|
||||
instance,
|
||||
namespace=namespace,
|
||||
instance_image_tag=tag,
|
||||
env_image_tag=env_image_tag,
|
||||
)
|
||||
test_specs.append(spec)
|
||||
|
||||
if force_rebuild:
|
||||
for spec in test_specs:
|
||||
remove_image(client, spec.instance_image_key, "quiet")
|
||||
|
||||
_, env_failed = build_env_images(client, test_specs, force_rebuild, max_workers)
|
||||
|
||||
if len(env_failed) > 0:
|
||||
# Don't build images for instances that depend on failed-to-build env images
|
||||
dont_run_specs = [
|
||||
spec for spec in test_specs if spec.env_image_key in env_failed
|
||||
]
|
||||
test_specs = [
|
||||
spec for spec in test_specs if spec.env_image_key not in env_failed
|
||||
]
|
||||
if custom_logger:
|
||||
custom_logger.info(
|
||||
f"Skipping {len(dont_run_specs)} instances - due to failed env image builds"
|
||||
)
|
||||
else:
|
||||
print(f"Skipping {len(dont_run_specs)} instances - due to failed env image builds")
|
||||
|
||||
if custom_logger:
|
||||
custom_logger.info(f"Building instance images for {len(test_specs)} instances")
|
||||
else:
|
||||
print(f"Building instance images for {len(test_specs)} instances")
|
||||
|
||||
successful, failed = [], []
|
||||
|
||||
if custom_logger:
|
||||
payloads = [(spec, client, custom_logger, False) for spec in test_specs]
|
||||
else:
|
||||
payloads = [(spec, client, None, False) for spec in test_specs]
|
||||
|
||||
successful, failed = run_threadpool(build_instance_image, payloads, max_workers)
|
||||
|
||||
if len(failed) == 0:
|
||||
if custom_logger:
|
||||
custom_logger.info("All instance images built successfully.")
|
||||
else:
|
||||
print("All instance images built successfully.")
|
||||
else:
|
||||
if custom_logger:
|
||||
custom_logger.warning(f"{len(failed)} instance images failed to build.")
|
||||
else:
|
||||
print(f"{len(failed)} instance images failed to build.")
|
||||
return successful, failed
|
||||
|
||||
|
||||
def apply_logger_patch():
|
||||
import swebench.harness.docker_build as docker_build_module
|
||||
|
||||
original_build_instance_images = docker_build_module.build_instance_images
|
||||
|
||||
docker_build_module.build_instance_images = patched_build_instance_images
|
||||
|
||||
return original_build_instance_images
|
||||
|
||||
|
||||
def restore_logger_patch(original_function):
|
||||
"""Recover original logger actions"""
|
||||
import swebench.harness.docker_build as docker_build_module
|
||||
docker_build_module.build_instance_images = original_function
|
||||
|
||||
|
||||
class LoggerPatch:
|
||||
"""Context manager"""
|
||||
|
||||
def __init__(self, logger: Optional[Union[logging.Logger, CustomLogger]] = None):
|
||||
self.logger = logger
|
||||
self.original_function = None
|
||||
|
||||
def __enter__(self):
|
||||
self.original_function = apply_logger_patch()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.original_function:
|
||||
restore_logger_patch(self.original_function)
|
||||
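A minimal sketch of using the context manager above, assuming the swebench harness is installed; while the block is active, swebench.harness.docker_build.build_instance_images points at the patched version:

from src.managers.image_builder.logger_patch import LoggerPatch

with LoggerPatch():
    # calls into swebench's docker_build made here go through
    # patched_build_instance_images, which accepts a custom_logger argument
    pass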
132
src/managers/image_builder/print_redirect.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Redirect print output of a third party repo to a specific log
|
||||
"""
|
||||
import sys
|
||||
import logging
|
||||
import re
|
||||
from typing import Union, Set, Pattern
|
||||
from typing import Optional, TextIO
|
||||
from pathlib import Path
|
||||
from src.managers.log.logger import Logger as CustomLogger
|
||||
|
||||
|
||||
class StreamWrapper:
|
||||
"""Simple stream wrapper,for redirecting stdout/stderr"""
|
||||
|
||||
def __init__(self, original_stream, write_func):
|
||||
self.original_stream = original_stream
|
||||
self.write_func = write_func
|
||||
self.buffer = ""
|
||||
|
||||
def write(self, text):
|
||||
self.buffer += text
|
||||
|
||||
while '\n' in self.buffer:
|
||||
line, self.buffer = self.buffer.split('\n', 1)
|
||||
if line.strip():
|
||||
self.write_func(line + '\n')
|
||||
|
||||
return len(text)
|
||||
|
||||
def flush(self):
|
||||
if self.buffer:
|
||||
if self.buffer.strip():
|
||||
self.write_func(self.buffer)
|
||||
self.buffer = ""
|
||||
if hasattr(self.original_stream, 'flush'):
|
||||
self.original_stream.flush()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.original_stream, name)
|
||||
|
||||
|
||||
class PrintRedirector:
|
||||
"""Redirect print output of a third party repo to a specific log"""
|
||||
|
||||
TRACEBACK_START_PATTERN = re.compile(r'Traceback\s*\(most recent call last\):', re.IGNORECASE)
|
||||
EXCEPTION_END_PATTERN = re.compile(
|
||||
r'\b(Error|Exception|KeyError|ValueError|TypeError|AttributeError|'
|
||||
r'ImportError|BuildError|DockerError|IndentationError|SyntaxError|'
|
||||
r'RuntimeError|OSError|FileNotFoundError|PermissionError)\s*:',
|
||||
re.IGNORECASE
|
||||
)
|
||||
|
||||
ERROR_KEYWORDS = {'error', 'failed', 'exception', 'fatal', 'critical'}
|
||||
WARNING_KEYWORDS = {'warning', 'warn', 'deprecated'}
|
||||
SKIP_KEYWORDS = {'skipping', 'skip', 'ignoring', 'ignore'}
|
||||
|
||||
def __init__(self, logger: Union[logging.Logger, CustomLogger]):
|
||||
self.logger = logger
|
||||
self.original_print = None
|
||||
self.original_stdout = None
|
||||
self.original_stderr = None
|
||||
|
||||
self.traceback_buffer = []
|
||||
self.in_traceback = False
|
||||
|
||||
def __enter__(self):
|
||||
self.start_redirect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop_redirect()
|
||||
|
||||
def start_redirect(self):
|
||||
if isinstance(__builtins__, dict):
|
||||
self.original_print = __builtins__['print']
|
||||
else:
|
||||
self.original_print = getattr(__builtins__, 'print')
|
||||
if isinstance(__builtins__, dict):
|
||||
__builtins__['print'] = self._redirected_print
|
||||
else:
|
||||
setattr(__builtins__, 'print', self._redirected_print)
|
||||
|
||||
def stop_redirect(self):
|
||||
if self.original_print:
|
||||
if isinstance(__builtins__, dict):
|
||||
__builtins__['print'] = self.original_print
|
||||
else:
|
||||
setattr(__builtins__, 'print', self.original_print)
|
||||
|
||||
def _redirected_print(self, *args, **kwargs):
|
||||
if not args:
|
||||
return
|
||||
output = ' '.join(str(arg) for arg in args)
|
||||
|
||||
if self.TRACEBACK_START_PATTERN.search(output):
|
||||
self.in_traceback = True
|
||||
self.traceback_buffer = [output]
|
||||
return
|
||||
|
||||
if self.in_traceback:
|
||||
self.traceback_buffer.append(output)
|
||||
|
||||
if self.EXCEPTION_END_PATTERN.search(output):
|
||||
self._log_traceback_and_reset()
|
||||
return
|
||||
|
||||
self._log_by_level(output)
|
||||
|
||||
def _log_traceback_and_reset(self):
|
||||
full_traceback = '\n'.join(self.traceback_buffer)
|
||||
self.logger.error(f"[Third party repo exception stack]\n{full_traceback}")
|
||||
|
||||
self.in_traceback = False
|
||||
self.traceback_buffer.clear()
|
||||
|
||||
def _log_by_level(self, output: str):
|
||||
output_lower = output.lower()
|
||||
|
||||
if any(keyword in output_lower for keyword in self.ERROR_KEYWORDS):
|
||||
self.logger.error(f"[Third party] {output}")
|
||||
elif any(keyword in output_lower for keyword in self.WARNING_KEYWORDS):
|
||||
self.logger.warning(f"[Third party] {output}")
|
||||
elif any(keyword in output_lower for keyword in self.SKIP_KEYWORDS):
|
||||
self.logger.info(f"[Third party] {output}")
|
||||
else:
|
||||
self.logger.info(f"[Third party] {output}")
|
||||
|
||||
|
||||
def redirect_swebench_prints(logger: Union[logging.Logger, CustomLogger]):
|
||||
return PrintRedirector(logger)
|
||||
|
||||
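A small usage sketch of the redirector above with a stdlib logger (any object with info/warning/error methods works):

import logging
from src.managers.image_builder.print_redirect import redirect_swebench_prints

logger = logging.getLogger("swebench_build")
with redirect_swebench_prints(logger):
    print("Building image ... failed")   # routed to logger.error (matches the 'failed' keyword)
    print("Skipping cached layer")       # routed to logger.info (matches the 'skipping' keyword)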
98
src/managers/llm_api/__init__.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
LLM API Management Module
|
||||
|
||||
Provide a standard OpenAI format LLM API interface that supports:
|
||||
- Unified Chat Completions endpoint
|
||||
- Synchronous and asynchronous operations
|
||||
- Streaming responses
|
||||
- Tool Calling
|
||||
- Error handling and retry mechanism(s)
|
||||
|
||||
Supported providers:
|
||||
- OpenAI: OpenAI API and compatible services
|
||||
- Anthropic: Claude models
|
||||
- DeepSeek: DeepSeek models
|
||||
- Private: private models (vLLM, TGI, Ollama, etc.)
|
||||
|
||||
Usage example::
|
||||
# Option 1: use the general manager (recommended)
|
||||
from llm_api import LLMAPIManager
|
||||
|
||||
# create manager
|
||||
manager = LLMAPIManager(
|
||||
client_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
stream=False
|
||||
)
|
||||
|
||||
# send messages
|
||||
response = manager.chat("Hello world!")
|
||||
print(response)
|
||||
|
||||
# Option 2: use a client directly
|
||||
from llm_api import OpenAIClient, ChatMessage, MessageRole
|
||||
|
||||
# create client
|
||||
client = OpenAIClient(api_key="your-api-key")
|
||||
|
||||
# create message
|
||||
messages = [
|
||||
ChatMessage(role=MessageRole.USER, content="你好,世界!")
|
||||
]
|
||||
|
||||
# send request
|
||||
request = client.create_request(messages=messages, model="gpt-3.5-turbo")
|
||||
response = client.chat_completions_create(request)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
|
||||
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
|
||||
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
|
||||
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
|
||||
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
|
||||
|
||||
|
||||
from src.managers.llm_api.api_manager import (
|
||||
LLMAPIManager,
|
||||
create_manager,
|
||||
create_common_manager,
|
||||
COMMON_CONFIGS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMAPI",
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatCompletionResponse",
|
||||
"ChatCompletionChunk",
|
||||
"MessageRole",
|
||||
"Choice",
|
||||
"Usage",
|
||||
"OpenAIClient",
|
||||
"AnthropicClient",
|
||||
"DeepSeekClient",
|
||||
"OpenRouterClient",
|
||||
"PrivateModelClient",
|
||||
"LLMAPIManager",
|
||||
"create_manager",
|
||||
"create_common_manager",
|
||||
"COMMON_CONFIGS",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Tokfinity Team"
|
||||
__description__ = "Base library for LLM APIs in the standard OpenAI format"
|
||||
477
src/managers/llm_api/api_manager.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
LLM API manager
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Tuple, Union, Generator
|
||||
from dotenv import load_dotenv
|
||||
import yaml
|
||||
from traceback import format_exc
|
||||
from .base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
)
|
||||
|
||||
from .clients.openai.openai_client import OpenAIClient
|
||||
from .clients.anthropic.anthropic_client import AnthropicClient
|
||||
from .clients.deepseek.deepseek_client import DeepSeekClient
|
||||
from .clients.openrouter.openrouter_client import OpenRouterClient
|
||||
from .clients.private.private_client import PrivateModelClient
|
||||
|
||||
|
||||
class LLMAPIManager:
|
||||
SUPPORTED_CLIENTS = {
|
||||
"openai": OpenAIClient,
|
||||
"anthropic": AnthropicClient,
|
||||
"deepseek": DeepSeekClient,
|
||||
"openrouter": OpenRouterClient,
|
||||
"private": PrivateModelClient,
|
||||
}
|
||||
|
||||
DEFAULT_CONFIGS = {
|
||||
"openai": {"base_url": None, "api_key_env": "OPENAI_API_KEY"},
|
||||
"anthropic": {"base_url": None, "api_key_env": "ANTHROPIC_API_KEY"},
|
||||
"deepseek": {"base_url": None, "api_key_env": "DEEPSEEK_API_KEY"},
|
||||
"openrouter": {
|
||||
"base_url": None,
|
||||
"api_key_env": "OPENROUTER_API_KEY",
|
||||
"extra_config": {
|
||||
"app_name": "tokfinity-llm-client",
|
||||
"site_url": "https://github.com/your-repo",
|
||||
},
|
||||
},
|
||||
"private": {
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key_env": "PRIVATE_API_KEY",
|
||||
"extra_config": {"deployment_type": "vllm"},
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_name: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
auto_load_env: bool = True,
|
||||
logger: Optional[Any] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# if no client is given, use the first provider listed in the config as the default
|
||||
if client_name is None:
|
||||
default_client, default_model = (
|
||||
self._load_default_client_and_model_from_config()
|
||||
)
|
||||
self.client_name = default_client
|
||||
self.default_model = default_model
|
||||
else:
|
||||
self.client_name = client_name.lower()
|
||||
self.default_model = None
|
||||
|
||||
self.stream = stream
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.logger = logger
|
||||
self.logger.info(f"[LLMAPIManager]: Using client: {self.client_name}, model: {self.default_model}.")
|
||||
self.config = config
|
||||
|
||||
if auto_load_env:
|
||||
self._load_environment()
|
||||
|
||||
if self.client_name not in self.SUPPORTED_CLIENTS:
|
||||
raise ValueError(
|
||||
f"Unsupported client: {client_name}。"
|
||||
f"Support client: {list(self.SUPPORTED_CLIENTS.keys())}"
|
||||
)
|
||||
|
||||
self.client = self._create_client(api_key, base_url, logger, **kwargs)
|
||||
|
||||
|
||||
def _load_environment(self) -> None:
|
||||
"""load environment variables"""
|
||||
env_paths = [".env", "../.env", "../../.env", "../../../.env"]
|
||||
|
||||
for env_path in env_paths:
|
||||
if os.path.exists(env_path):
|
||||
load_dotenv(env_path)
|
||||
break
|
||||
|
||||
def _create_client(
|
||||
self, api_key: Optional[str] = None, base_url: Optional[str] = None, logger: Optional[Any] = None, **kwargs
|
||||
) -> BaseLLMAPI:
|
||||
client_class = self.SUPPORTED_CLIENTS[self.client_name]
|
||||
config = self.DEFAULT_CONFIGS[self.client_name]
|
||||
|
||||
if api_key is None:
|
||||
api_key = os.getenv(config["api_key_env"])
|
||||
if not api_key:
|
||||
if self.client_name == "private":
|
||||
api_key = "EMPTY" # private mode may not need a key
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Fail to find env variable, please set: {config['api_key_env']} "
|
||||
f"or upload ai_key parameter when initialize"
|
||||
)
|
||||
|
||||
if base_url is None:
|
||||
env_key = f"{self.client_name.upper()}_BASE_URL"
|
||||
if self.client_name == "private":
|
||||
env_key = "PRIVATE_URL"
|
||||
|
||||
base_url = os.getenv(env_key)
|
||||
if base_url is None:
|
||||
base_url = config.get("base_url")
|
||||
|
||||
client_kwargs = {
|
||||
"api_key": api_key,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
|
||||
if logger is not None:
|
||||
client_kwargs["logger"] = logger
|
||||
|
||||
extra_config = config.get("extra_config", {})
|
||||
client_kwargs.update(extra_config)
|
||||
client_kwargs.update(kwargs)
|
||||
|
||||
if self.client_name == "openrouter":
|
||||
client_kwargs.setdefault("app_name", "tokfinity-llm-client")
|
||||
elif self.client_name == "private":
|
||||
client_kwargs.setdefault("deployment_type", "vllm")
|
||||
|
||||
return client_class(**client_kwargs)
|
||||
|
||||
def _load_default_client_and_model_from_config(self) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
Take the first entry under providers in config/config.yaml as the default client,
|
||||
and take its first model as the default model.
|
||||
"""
|
||||
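# The providers section is assumed to look roughly like the sketch below
# (provider name -> list of model names; the first entries become the defaults):
#   providers:
#     anthropic:
#       - claude-3-5-sonnet-20241022
#     openai:
#       - gpt-4o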
# Parse the root of config file (relative to the project root)
|
||||
# This file is in src/managers/llm_api/api_manager.py
|
||||
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
|
||||
config_path = os.path.join(base_dir, "config", "config.yaml")
|
||||
if not os.path.exists(config_path):
|
||||
raise FileNotFoundError(f"No config file: {config_path}")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
providers = cfg.get("providers")
|
||||
if not isinstance(providers, dict) or len(providers) == 0:
|
||||
raise ValueError("config.yaml lack of providers config or format error")
|
||||
|
||||
first_provider_name = next(iter(providers.keys()))
|
||||
models = providers.get(first_provider_name) or []
|
||||
first_model = (
|
||||
models[0] if isinstance(models, list) and len(models) > 0 else None
|
||||
)
|
||||
|
||||
client_key = first_provider_name.strip().lower()
|
||||
if client_key not in self.SUPPORTED_CLIENTS:
|
||||
raise ValueError(
|
||||
f"Default provider '{first_provider_name}' in config not registered in SUPPORTED_CLIENTS"
|
||||
)
|
||||
|
||||
return client_key, first_model
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Union[Dict[str, Any], ChatMessage]],
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
retry: Optional[int] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
) -> Union[
|
||||
ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None
|
||||
]:
|
||||
"""
|
||||
Send chat message and get response
|
||||
|
||||
Args:
|
||||
model: model name (optional, default self.default_model)
|
||||
messages: Message list for the conversation.
|
||||
- If list of dict: [{"role": "system|user|assistant|tool", "content": "..."}]
|
||||
- Or `ChatMessage` list
|
||||
temperature: Temperature
|
||||
max_tokens: Max tokens
|
||||
timeout: Request timeout, use the value from initialization if not specified
|
||||
retry: Max retries, use the value from initialization if not specified
|
||||
tools: Tools description list (OpenAI tools format)
|
||||
tool_choice: Tool choice strategy("auto" | "none" | {"type":..., ...})
|
||||
**kwargs: Other request args
|
||||
|
||||
Returns:
|
||||
Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None]:
|
||||
- Non-streaming: Return the complete ChatCompletionResponse object
|
||||
- Stream: Return a streaming response generator
|
||||
- Fail: Return None
|
||||
"""
|
||||
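# Illustrative `tools` entry in the OpenAI function-calling format (names are hypothetical):
#   {"type": "function",
#    "function": {"name": "get_weather",
#                 "description": "Look up the weather for a city",
#                 "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}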
actual_timeout = timeout if timeout is not None else self.timeout
|
||||
actual_retry = retry if retry is not None else self.max_retries
|
||||
|
||||
actual_model = model if model is not None else self.default_model
|
||||
if not actual_model:
|
||||
raise ValueError("Model unprovided, and cannot get default model from config")
|
||||
|
||||
normalized_messages: List[ChatMessage] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
if msg.content is None:
|
||||
msg.content = ""
|
||||
normalized_messages.append(msg)
|
||||
else:
|
||||
role_value = msg.get("role")
|
||||
content_value = msg.get("content")
|
||||
|
||||
if content_value is None:
|
||||
content_value = ""
|
||||
|
||||
role_enum = MessageRole(role_value)
|
||||
normalized_messages.append(
|
||||
ChatMessage(
|
||||
role=role_enum,
|
||||
content=content_value,
|
||||
name=msg.get("name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
)
|
||||
)
|
||||
|
||||
last_exception = None
|
||||
for attempt in range(actual_retry + 1):
|
||||
try:
|
||||
request = self.client.create_request(
|
||||
messages=normalized_messages,
|
||||
model=actual_model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=self.stream,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
config=self.config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
response = self.client.chat_completions_create(request)
|
||||
|
||||
if self.stream:
|
||||
# Streaming response: returned as-is without further processing
|
||||
return response
|
||||
else:
|
||||
# Non-streaming response: return complete ChatCompletionResponse
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < actual_retry:
|
||||
delay = min(2**attempt, 30)
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"{attempt + 1}th try failed,retry after {delay} seconds: {str(e)}, traceback: {format_exc()}."
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"All {actual_retry + 1} tries failed: {str(e)}"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def chat_simple(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Union[Dict[str, Any], ChatMessage]],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
timeout: Optional[int] = None,
|
||||
retry: Optional[int] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
**kwargs,
|
||||
) -> Optional[str]:
|
||||
# Temporarily disable streaming
|
||||
original_stream = self.stream
|
||||
self.stream = False
|
||||
|
||||
try:
|
||||
response = self.chat(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
timeout=timeout,
|
||||
retry=retry,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if response and hasattr(response, "choices") and response.choices:
|
||||
return response.choices[0].message.content
|
||||
return None
|
||||
|
||||
finally:
|
||||
# Restore the original stream setting
|
||||
self.stream = original_stream
|
||||
|
||||
def create_embeddings(
|
||||
self,
|
||||
input_text: Union[str, List[str]],
|
||||
model: str,
|
||||
encoding_format: str = "float",
|
||||
dimensions: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
retry: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Optional[EmbeddingResponse]:
|
||||
actual_timeout = timeout if timeout is not None else self.timeout
|
||||
actual_retry = retry if retry is not None else self.max_retries
|
||||
|
||||
if isinstance(input_text, str):
|
||||
text_list = [input_text]
|
||||
single_input = True
|
||||
else:
|
||||
text_list = input_text
|
||||
single_input = False
|
||||
|
||||
if self.logger:
|
||||
self.logger.debug(
|
||||
f"Start embedding request - client: {self.client_name}, model: {model}, text num: {len(text_list)}"
|
||||
)
|
||||
|
||||
if not hasattr(self.client, "create_embeddings"):
|
||||
error_msg = f"Client {self.client_name} not support embedding"
|
||||
if self.logger:
|
||||
self.logger.error(error_msg)
|
||||
return None
|
||||
|
||||
last_exception = None
|
||||
for attempt in range(actual_retry + 1):
|
||||
try:
|
||||
if self.logger:
|
||||
self.logger.debug(f"{attempt + 1}th try to create embedding vector")
|
||||
|
||||
response = self.client.create_embeddings(
|
||||
input_text=text_list,
|
||||
model=model,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
timeout=actual_timeout,
|
||||
max_retries=1,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if response:
|
||||
return response
|
||||
else:
|
||||
raise Exception("Client return empty response")
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < actual_retry:
|
||||
delay = min(2**attempt, 30)
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"{attempt + 1}th embedding request failed,retry after {delay}s: {str(e)}, traceback: {format_exc()}."
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"All {actual_retry + 1} embedding tries failed: {str(e)}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def get_client_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"client_name": self.client_name,
|
||||
"stream": self.stream,
|
||||
"client_info": (
|
||||
self.client.get_model_info()
|
||||
if hasattr(self.client, "get_model_info")
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
return self.default_model or "unknown"
|
||||
|
||||
def close(self) -> None:
|
||||
if hasattr(self.client, "close"):
|
||||
self.client.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"LLMAPIManager(client={self.client_name}, stream={self.stream})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"LLMAPIManager("
|
||||
f"client_name='{self.client_name}', "
|
||||
f"stream={self.stream})"
|
||||
)
|
||||
|
||||
|
||||
def create_manager(
|
||||
client_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
|
||||
) -> LLMAPIManager:
|
||||
return LLMAPIManager(
|
||||
client_name=client_name, stream=stream, logger=logger, **kwargs
|
||||
)
|
||||
|
||||
|
||||
COMMON_CONFIGS = {
|
||||
"openai_gpt4": {"client_name": "openai", "model_name": "gpt-4o"},
|
||||
"openai_gpt35": {"client_name": "openai", "model_name": "gpt-3.5-turbo"},
|
||||
"claude_sonnet": {
|
||||
"client_name": "anthropic",
|
||||
"model_name": "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
"claude_haiku": {
|
||||
"client_name": "anthropic",
|
||||
"model_name": "claude-3-haiku-20240307",
|
||||
},
|
||||
"deepseek_chat": {"client_name": "deepseek", "model_name": "deepseek-chat"},
|
||||
"deepseek_coder": {"client_name": "deepseek", "model_name": "deepseek-coder"},
|
||||
}
|
||||
|
||||
|
||||
def create_common_manager(
|
||||
config_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
|
||||
) -> LLMAPIManager:
|
||||
if config_name not in COMMON_CONFIGS:
|
||||
raise ValueError(
|
||||
f"Unknown config: {config_name}. Available configs: {list(COMMON_CONFIGS.keys())}"
|
||||
)
|
||||
|
||||
config = COMMON_CONFIGS[config_name]
|
||||
return LLMAPIManager(
|
||||
client_name=config["client_name"], stream=stream, logger=logger, **kwargs
|
||||
)
|
||||
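A minimal end-to-end sketch of the manager above; it assumes OPENAI_API_KEY is set in the environment, and passes the model explicitly so no config.yaml lookup is needed:

from src.managers.llm_api.api_manager import create_common_manager

manager = create_common_manager("openai_gpt35")
reply = manager.chat_simple(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(reply)
manager.close()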
594
src/managers/llm_api/base_client.py
Normal file
@@ -0,0 +1,594 @@
|
||||
"""
|
||||
LLM API base class - standard OpenAI format
|
||||
A universal API supporting multiple LLM providers
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import requests
|
||||
import asyncio
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
|
||||
class MessageRole(Enum):
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
role: MessageRole
|
||||
content: Union[str, List[Dict[str, Any]]]
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionRequest:
|
||||
messages: List[ChatMessage]
|
||||
model: str
|
||||
temperature: float = 0.7
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: float = 1.0
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: bool = False
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
|
||||
response_format: Optional[Dict[str, Any]] = None
|
||||
seed: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Usage:
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Choice:
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Optional[str] = None
|
||||
logprobs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionResponse:
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
choices: List[Choice]
|
||||
usage: Optional[Usage] = None
|
||||
system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatCompletionChunk:
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
choices: List[Dict[str, Any]]
|
||||
system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingRequest:
|
||||
input: Union[str, List[str]]
|
||||
model: str
|
||||
encoding_format: str = "float"
|
||||
dimensions: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingData:
|
||||
object: str
|
||||
embedding: List[float]
|
||||
index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingUsage:
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingResponse:
|
||||
object: str
|
||||
data: List[EmbeddingData]
|
||||
model: str
|
||||
usage: EmbeddingUsage
|
||||
|
||||
|
||||
class BaseLLMAPI(ABC):
|
||||
"""
|
||||
LLM API base class
|
||||
|
||||
Provides a standard OpenAI-format API, supporting:
|
||||
- Synchronous/Asynchronous Chat Completions
|
||||
- Streaming Responses
|
||||
- Tool Calling
|
||||
- Error handling and Retry
|
||||
- Usage statistics
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.extra_config = kwargs
|
||||
|
||||
self.logger = logger
|
||||
|
||||
self.session: Optional[requests.Session] = None
|
||||
|
||||
self._initialize_client()
|
||||
|
||||
def _create_http_clients(self, headers: Dict[str, str]) -> None:
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update(headers)
|
||||
self.session.timeout = self.timeout
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_client(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
pass
|
||||
|
||||
def chat_completions_create(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None]]:
|
||||
self.validate_request(request)
|
||||
|
||||
def _make_request():
|
||||
payload = self._build_request_payload(request)
|
||||
endpoint = self._get_chat_endpoint()
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(payload, endpoint)
|
||||
else:
|
||||
|
||||
full_url = self.base_url.rstrip("/") + endpoint
|
||||
|
||||
headers = {}
|
||||
if self.session and self.session.headers:
|
||||
headers.update(self.session.headers)
|
||||
else:
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key and self.api_key != "EMPTY":
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
response = requests.post(
|
||||
full_url, json=payload, headers=headers, timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_msg = f"API Error {response.status_code}: {response.text}"
|
||||
print(error_msg)
|
||||
raise requests.exceptions.HTTPError(error_msg, response=response)
|
||||
response.raise_for_status()
|
||||
return self._parse_response(response.json())
|
||||
|
||||
return self._retry_with_backoff(_make_request)
|
||||
|
||||
async def achat_completions_create(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> Union[ChatCompletionResponse, AsyncGenerator[ChatCompletionChunk, None]]:
|
||||
self.validate_request(request)
|
||||
|
||||
def _make_async_request():
|
||||
payload = self._build_request_payload(request)
|
||||
endpoint = self._get_chat_endpoint()
|
||||
|
||||
if request.stream:
|
||||
return self._stream_chat_completion(
|
||||
payload, endpoint
|
||||
)
|
||||
else:
|
||||
full_url = self.base_url.rstrip("/") + endpoint
|
||||
|
||||
headers = {}
|
||||
if self.session and self.session.headers:
|
||||
headers.update(self.session.headers)
|
||||
else:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key and self.api_key != "EMPTY":
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
response = requests.post(
|
||||
full_url, json=payload, headers=headers, timeout=self.timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
return self._parse_response(response.json())
|
||||
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, _make_async_request)
|
||||
|
||||
def create_message(self, role: MessageRole, content: str, config: Optional[Dict[str, Any]] = None, **kwargs) -> ChatMessage:
|
||||
return ChatMessage(role=role, content=content, **kwargs)
|
||||
|
||||
def _make_token_cache_request(self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None) -> list[ChatMessage]:
|
||||
"""
|
||||
Insert cache_control blocks for Claude/Anthropic models according to config, keeping forward compatibility
|
||||
|
||||
- Only activate when model is "Claude/Anthropic"
|
||||
- Preserve the content as is for other providers to avoid breaking compatibility.
|
||||
- If token_cache.role_turns is configured, inject it matching the turn and role; otherwise, only inject it into the first system message.
|
||||
"""
|
||||
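# After _add_cache_flag, the first text block of a cached message looks like (illustrative):
#   {"type": "text", "text": "...", "cache_control": {"type": "ephemeral"}}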
role_turns = self._load_role_turns(config)
|
||||
#if self.logger:
|
||||
# self.logger.debug(f"In _make_token_cache_request, got role_turns: {role_turns}.")
|
||||
if not self._is_claude_like(model):
|
||||
return messages
|
||||
|
||||
turn_to_roles: Dict[int, List[str]] = {}
|
||||
try:
|
||||
for rt in role_turns or []:
|
||||
try:
|
||||
turn = int(rt.get("turn"))
|
||||
role = (rt.get("role") or "").strip().lower()
|
||||
if role:
|
||||
if turn not in turn_to_roles:
|
||||
turn_to_roles[turn] = []
|
||||
turn_to_roles[turn].append(role)
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
turn_to_roles = {}
|
||||
|
||||
current_turn = 0
|
||||
result_messages: List[ChatMessage] = []
|
||||
|
||||
for msg in messages:
|
||||
msg_blocks = self._to_claude_blocks(msg.content)
|
||||
|
||||
# Turn-0 rule: for the first user message, add the cache flag to its own content and append it to the result.
|
||||
if current_turn == 0 and msg.role == MessageRole.USER:
|
||||
cached_blocks = self._add_cache_flag(msg_blocks)
|
||||
result_messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=cached_blocks,
|
||||
name=msg.name,
|
||||
tool_calls=msg.tool_calls,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
)
|
||||
)
|
||||
#if self.logger:
|
||||
# self.logger.debug(
|
||||
# f"Applied cache to initial user message at turn {current_turn}."
|
||||
# )
|
||||
continue
|
||||
|
||||
# Other messages: just add it to result
|
||||
result_messages.append(
|
||||
ChatMessage(
|
||||
role=msg.role,
|
||||
content=msg_blocks,
|
||||
name=msg.name,
|
||||
tool_calls=msg.tool_calls,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Anchor point: when the current turn's configuration includes "assistant" and the current message is from the assistant,
|
||||
# append an extra system refresh message.
|
||||
roles_for_turn = turn_to_roles.get(current_turn, [])
|
||||
if msg.role == MessageRole.ASSISTANT and roles_for_turn and ("assistant" in roles_for_turn):
|
||||
refresh_text = f"refresh cache tokens."
|
||||
refresh_blocks = self._to_claude_blocks(refresh_text)
|
||||
refresh_blocks = self._add_cache_flag(refresh_blocks)
|
||||
result_messages.append(
|
||||
ChatMessage(
|
||||
role=MessageRole.SYSTEM,
|
||||
content=refresh_blocks,
|
||||
)
|
||||
)
|
||||
#if self.logger:
|
||||
# self.logger.debug(
|
||||
# f"Appended system refresh after assistant at turn {current_turn}, refresh_blocks: {refresh_blocks}."
|
||||
# )
|
||||
|
||||
# an assistant message advances the turn counter
|
||||
if msg.role == MessageRole.ASSISTANT:
|
||||
current_turn += 1
|
||||
|
||||
return result_messages
|
||||
|
||||
def _to_claude_blocks(self, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||
if isinstance(content, list):
|
||||
return content
|
||||
return [{"type": "text", "text": content}]
|
||||
|
||||
def _add_cache_flag(self, blocks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
out: List[Dict[str, Any]] = []
|
||||
added = False
|
||||
for blk in blocks:
|
||||
if not added and isinstance(blk, dict) and blk.get("type") == "text":
|
||||
out.append({**blk, "cache_control": {"type": "ephemeral"}})
|
||||
added = True
|
||||
else:
|
||||
out.append(blk)
|
||||
return out
|
||||
|
||||
def _is_claude_like(self, model: str) -> bool:
|
||||
try:
|
||||
model_lc = (model or "").lower()
|
||||
return ("claude" in model_lc) or ("anthropic" in model_lc)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _load_role_turns(self, config: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
|
||||
"""Load token_cache.role_turns,Priority:Invoker config > default value"""
|
||||
try:
|
||||
role_turns = None
|
||||
|
||||
# 1) Read from invoker's config
|
||||
if config and isinstance(config, dict):
|
||||
token_cache_cfg = config.get("token_cache")
|
||||
if isinstance(token_cache_cfg, dict):
|
||||
role_turns = token_cache_cfg.get("role_turns")
|
||||
|
||||
# 2) If still missing, fall back to the default values from the example config
|
||||
if not role_turns:
|
||||
role_turns = [
|
||||
{"turn": 0, "role": "user"},
|
||||
{"turn": 25, "role": "assistant"},
|
||||
{"turn": 50, "role": "assistant"},
|
||||
{"turn": 75, "role": "assistant"},
|
||||
]
|
||||
|
||||
return role_turns
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.warning(f"Read token_cache.role_turns fail,use default value: {e}")
|
||||
return [
|
||||
{"turn": 0, "role": "user"},
|
||||
{"turn": 25, "role": "assistant"},
|
||||
{"turn": 50, "role": "assistant"},
|
||||
{"turn": 75, "role": "assistant"},
|
||||
]
|
||||
|
||||
def create_request(
|
||||
self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None, **kwargs
|
||||
) -> ChatCompletionRequest:
|
||||
|
||||
messages = self._make_token_cache_request(messages, model, config)
|
||||
|
||||
return ChatCompletionRequest(messages=messages, model=model, **kwargs)
|
||||
|
||||
def _handle_error(self, error: Exception) -> None:
|
||||
if self.logger:
|
||||
self.logger.error(f"API request failed: {error}")
|
||||
else:
|
||||
print(f"API request failed: {error}")
|
||||
raise error
|
||||
|
||||
def _retry_with_backoff(self, func, *args, **kwargs):
|
||||
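# With the defaults (retry_delay=1.0, max_retries=3) the waits between attempts are 1s, 2s and 4s.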
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < self.max_retries:
|
||||
delay = self.retry_delay * (2**attempt)
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"{attempt + 1}th try failed,retry after {delay}s: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
if self.logger:
|
||||
self.logger.error(f"All retries failed: {e}, traceback: {format_exc()}.")
|
||||
|
||||
raise last_exception
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"provider": self.__class__.__name__,
|
||||
"base_url": self.base_url,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
|
||||
def validate_request(self, request: ChatCompletionRequest) -> bool:
|
||||
if not request.messages:
|
||||
raise ValueError("Message list is empty")
|
||||
|
||||
if not request.model:
|
||||
raise ValueError("Model name is empty")
|
||||
|
||||
for idx, message in enumerate(request.messages):
|
||||
if not message.content and not message.tool_calls:
|
||||
try:
|
||||
msg_info = {
|
||||
"index": idx,
|
||||
"role": getattr(message.role, "value", str(message.role)),
|
||||
"content_len": (
|
||||
len(message.content)
|
||||
if isinstance(message.content, str)
|
||||
else (len(message.content) if isinstance(message.content, list) else 0)
|
||||
),
|
||||
"has_tool_calls": bool(message.tool_calls),
|
||||
"tool_calls": message.tool_calls,
|
||||
}
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"Request validation failed: Invalid message exists (lacking both content and tool_calls): {json.dumps(msg_info, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.warning(
|
||||
f"Request validation failed: Invalid message exists (lacking both content and tool_calls), index={idx}, error: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
raise ValueError("Cannot lacking both content and tool_calls")
|
||||
|
||||
return True
|
||||
|
||||
def format_messages_for_api(
|
||||
self, messages: List[ChatMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
formatted_messages = []
|
||||
|
||||
for message in messages:
|
||||
msg_dict = {"role": message.role.value, "content": message.content}
|
||||
|
||||
if message.name:
|
||||
msg_dict["name"] = message.name
|
||||
|
||||
if message.tool_calls:
|
||||
msg_dict["tool_calls"] = message.tool_calls
|
||||
|
||||
if message.tool_call_id:
|
||||
msg_dict["tool_call_id"] = message.tool_call_id
|
||||
|
||||
formatted_messages.append(msg_dict)
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _stream_chat_completion(
|
||||
self, payload: Dict[str, Any], endpoint: str
|
||||
) -> Generator[ChatCompletionChunk, None, None]:
|
||||
full_url = self.base_url.rstrip("/") + endpoint
|
||||
|
||||
headers = {}
|
||||
if self.session and self.session.headers:
|
||||
headers.update(self.session.headers)
|
||||
else:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.api_key and self.api_key != "EMPTY":
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
response = requests.post(
|
||||
full_url, json=payload, headers=headers, timeout=self.timeout, stream=True
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
line_count = 0
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_count += 1
|
||||
line_str = line.decode("utf-8")
|
||||
chunk = self._parse_stream_line(line_str)
|
||||
if chunk:
|
||||
yield chunk
|
||||
else:
|
||||
print(f"Jump invalid line")
|
||||
print(f"Streaming request process done, {line_count} processed")
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
async def _astream_chat_completion(
|
||||
self, payload: Dict[str, Any], endpoint: str
|
||||
) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||
import asyncio
|
||||
|
||||
def _sync_stream():
|
||||
return list(self._stream_chat_completion(payload, endpoint))
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
chunks = await loop.run_in_executor(None, _sync_stream)
|
||||
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
|
||||
# Process standard SSE format
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
|
||||
if data.strip() == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
chunk_data = json.loads(data)
|
||||
return self._parse_stream_chunk(chunk_data)
|
||||
except json.JSONDecodeError:
|
||||
self.logger.warning(f"Unable to parse streaming data: {data}")
|
||||
return None
|
||||
|
||||
return None
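# Illustrative example (not in the original source): a typical SSE line such as
#   data: {"id": "chatcmpl-1", "object": "chat.completion.chunk", ...}
# is stripped of the "data: " prefix and parsed as JSON, while the terminator
#   data: [DONE]
# and any non-"data:" lines return None and are skipped by the streaming loop.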
|
||||
|
||||
def close(self) -> None:
|
||||
if self.session:
|
||||
self.session.close()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
if self.session:
|
||||
self.session.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.aclose()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}(base_url={self.base_url})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"{self.__class__.__name__}("
|
||||
f"base_url={self.base_url}, "
|
||||
f"timeout={self.timeout}, "
|
||||
f"max_retries={self.max_retries})"
|
||||
)
|
||||
23
src/managers/llm_api/clients/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
LLM client implementation module
|
||||
|
||||
Supported LLM providers:
|
||||
- OpenAI: OpenAI API and compatible service
|
||||
- Anthropic: Claude models
|
||||
- DeepSeek: DeepSeek models
|
||||
- Private: private models (vLLM, TGI, Ollama, etc.)
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
|
||||
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
|
||||
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
|
||||
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
|
||||
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
|
||||
|
||||
__all__ = [
|
||||
"OpenAIClient",
|
||||
"AnthropicClient",
|
||||
"DeepSeekClient",
|
||||
"OpenRouterClient",
|
||||
"PrivateModelClient",
|
||||
]
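# Minimal usage sketch (illustrative, not part of the original module; the
# keyword arguments are inferred from the client constructors in this package):
#
#   from src.managers.llm_api.clients import OpenAIClient
#
#   client = OpenAIClient(api_key="sk-...", timeout=60, max_retries=3)
#   models = client.list_models()
#   client.close()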
|
||||
7
src/managers/llm_api/clients/anthropic/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Anthropic Claude Client Module
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
|
||||
|
||||
__all__ = ["AnthropicClient"]
|
||||
288
src/managers/llm_api/clients/anthropic/anthropic_client.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Anthropic Claude Client Module
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicClient(BaseLLMAPI):
|
||||
"""
|
||||
Anthropic Claude API Client
|
||||
Supports Anthropic Claude models
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
anthropic_version: str = "2023-06-01",
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize Anthropic client
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key
|
||||
base_url: API base URL; defaults to the official Anthropic API
|
||||
timeout: timeout in seconds
|
||||
max_retries: max retries
|
||||
retry_delay: retry delay in seconds
|
||||
anthropic_version: API version
|
||||
**kwargs: other args
|
||||
"""
|
||||
self.anthropic_version = anthropic_version
|
||||
|
||||
if base_url is None:
|
||||
base_url = "https://api.anthropic.com"
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
"""Initialize HTTP client"""
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
"anthropic-version": self.anthropic_version,
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
}
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
return "/v1/messages"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
"""Build Anthropic API request payload"""
|
||||
messages, system_prompt = self._convert_messages_to_anthropic_format(
|
||||
request.messages
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": messages,
|
||||
"max_tokens": request.max_tokens or 1024,
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
if system_prompt:
|
||||
payload["system"] = system_prompt
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop_sequences"] = (
|
||||
request.stop if isinstance(request.stop, list) else [request.stop]
|
||||
)
|
||||
|
||||
if request.tools is not None:
|
||||
payload["tools"] = self._convert_tools_to_anthropic_format(request.tools)
|
||||
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = self._convert_tool_choice_to_anthropic_format(
|
||||
request.tool_choice
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def _convert_messages_to_anthropic_format(
|
||||
self, messages: List[ChatMessage]
|
||||
) -> tuple[List[Dict[str, Any]], Optional[str]]:
|
||||
"""Convert messages to anthropic format"""
|
||||
anthropic_messages = []
|
||||
system_prompt = None
|
||||
|
||||
for message in messages:
|
||||
if message.role == MessageRole.SYSTEM:
|
||||
system_prompt = message.content
|
||||
elif message.role in [MessageRole.USER, MessageRole.ASSISTANT]:
|
||||
anthropic_messages.append(
|
||||
{"role": message.role.value, "content": message.content}
|
||||
)
|
||||
|
||||
return anthropic_messages, system_prompt
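# Illustrative example (not in the original source): for the input messages
#   [system: "You are helpful", user: "Hi"]
# this conversion returns ([{"role": "user", "content": "Hi"}], "You are helpful"),
# since Anthropic's Messages API takes the system prompt as a separate top-level
# "system" field rather than as an entry in the message list.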
|
||||
|
||||
def _convert_tools_to_anthropic_format(
|
||||
self, tools: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert OpenAI format tools to anthropic format"""
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
function_def = tool.get("function", {})
|
||||
anthropic_tool = {
|
||||
"name": function_def.get("name", ""),
|
||||
"description": function_def.get("description", ""),
|
||||
"input_schema": function_def.get("parameters", {}),
|
||||
}
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
|
||||
return anthropic_tools
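# Illustrative example (not in the original source): an OpenAI-style tool
#   {"type": "function", "function": {"name": "get_weather",
#    "description": "...", "parameters": {...}}}
# becomes the Anthropic-style tool
#   {"name": "get_weather", "description": "...", "input_schema": {...}}.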
|
||||
|
||||
def _convert_tool_choice_to_anthropic_format(
|
||||
self, tool_choice: Union[str, Dict[str, Any]]
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""Convert OpenAI format tool choice to anthropic format"""
|
||||
if isinstance(tool_choice, str):
|
||||
if tool_choice == "auto":
|
||||
return "auto"
|
||||
elif tool_choice == "none":
|
||||
return "none"
|
||||
else:
|
||||
return {"type": "tool", "name": tool_choice}
|
||||
elif isinstance(tool_choice, dict):
|
||||
if tool_choice.get("type") == "function":
|
||||
return {
|
||||
"type": "tool",
|
||||
"name": tool_choice.get("function", {}).get("name", ""),
|
||||
}
|
||||
elif tool_choice.get("type") == "tool":
|
||||
return tool_choice
|
||||
|
||||
return "auto"
|
||||
|
||||
def _convert_anthropic_tool_calls(
|
||||
self, content_list: List[Dict[str, Any]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Convert anthropic tool calls to OpenAI format"""
|
||||
tool_calls = []
|
||||
|
||||
for item in content_list:
|
||||
if item.get("type") == "tool_use":
|
||||
tool_call = {
|
||||
"id": item.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item.get("name", ""),
|
||||
"arguments": item.get("input", {}),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls if tool_calls else None
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
"""Convert anthropic API response to OpenAI format"""
|
||||
content = ""
|
||||
tool_calls = None
|
||||
|
||||
if response_data.get("content"):
|
||||
content_data = (
|
||||
response_data["content"][0] if response_data["content"] else {}
|
||||
)
|
||||
content = content_data.get("text", "")
|
||||
|
||||
if content_data.get("type") == "tool_use":
|
||||
tool_calls = self._convert_anthropic_tool_calls(
|
||||
response_data.get("content", [])
|
||||
)
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole.ASSISTANT, content=content, tool_calls=tool_calls
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason=self._convert_stop_reason(response_data.get("stop_reason")),
|
||||
)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("input_tokens", 0),
|
||||
completion_tokens=usage_data.get("output_tokens", 0),
|
||||
total_tokens=usage_data.get("input_tokens", 0)
|
||||
+ usage_data.get("output_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model=response_data.get("model", ""),
|
||||
choices=[choice],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _convert_stop_reason(self, stop_reason: Optional[str]) -> Optional[str]:
|
||||
"""Convert anthropic stop reason to OpenAI format"""
|
||||
if stop_reason == "end_turn":
|
||||
return "stop"
|
||||
elif stop_reason == "max_tokens":
|
||||
return "length"
|
||||
elif stop_reason == "stop_sequence":
|
||||
return "stop"
|
||||
else:
|
||||
return stop_reason
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
"""Convert anthropic steam response to OpenAI format"""
|
||||
event_type = chunk_data.get("type")
|
||||
|
||||
if event_type == "content_block_delta":
|
||||
delta = chunk_data.get("delta", {})
|
||||
text = delta.get("text", "")
|
||||
|
||||
if text:
|
||||
choices = [
|
||||
{"index": 0, "delta": {"content": text}, "finish_reason": None}
|
||||
]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("message", {}).get("id", ""),
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=chunk_data.get("message", {}).get("model", ""),
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
elif event_type == "message_stop":
|
||||
choices = [{"index": 0, "delta": {}, "finish_reason": "stop"}]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("message", {}).get("id", ""),
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=chunk_data.get("message", {}).get("model", ""),
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
|
||||
try:
|
||||
chunk_data = json.loads(line)
|
||||
return self._parse_stream_chunk(chunk_data)
|
||||
except json.JSONDecodeError:
|
||||
# if not json, try SSE
|
||||
return super()._parse_stream_line(line)
|
||||
7
src/managers/llm_api/clients/deepseek/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
DeepSeek Client Module
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
|
||||
|
||||
__all__ = ["DeepSeekClient"]
|
||||
164
src/managers/llm_api/clients/deepseek/deepseek_client.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
DeepSeek Client Module
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class DeepSeekClient(BaseLLMAPI):
|
||||
"""
|
||||
DeepSeek API Client
|
||||
|
||||
Supports DeepSeek models, including DeepSeek-Coder, DeepSeek-Chat, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if base_url is None:
|
||||
base_url = "https://api.deepseek.com/v1"
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
}
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
return "/chat/completions"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
# optional
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
if request.frequency_penalty != 0.0:
|
||||
payload["frequency_penalty"] = request.frequency_penalty
|
||||
|
||||
if request.presence_penalty != 0.0:
|
||||
payload["presence_penalty"] = request.presence_penalty
|
||||
|
||||
if request.tools is not None:
|
||||
payload["tools"] = request.tools
|
||||
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = request.tool_choice
|
||||
|
||||
if request.response_format is not None:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
if request.seed is not None:
|
||||
payload["seed"] = request.seed
|
||||
|
||||
if request.user is not None:
|
||||
payload["user"] = request.user
|
||||
|
||||
return payload
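# Illustrative example (not in the original source): a request with
# model="deepseek-chat", a list of messages and max_tokens=256 yields a payload like
#   {"model": "deepseek-chat", "messages": [...], "temperature": ...,
#    "top_p": ..., "stream": False, "max_tokens": 256}
# with the remaining optional fields added only when they are set on the request.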
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
choices = []
|
||||
for choice_data in response_data.get("choices", []):
|
||||
message_data = choice_data.get("message", {})
|
||||
|
||||
tool_calls = message_data.get("tool_calls")
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole(message_data.get("role", "assistant")),
|
||||
content=message_data.get("content", ""),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=choice_data.get("index", 0),
|
||||
message=message,
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
logprobs=choice_data.get("logprobs"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object=response_data.get("object", "chat.completion"),
|
||||
created=response_data.get("created", int(time.time())),
|
||||
model=response_data.get("model", ""),
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
system_fingerprint=response_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
"""Parse stream chunk"""
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("id", ""),
|
||||
object=chunk_data.get("object", "chat.completion.chunk"),
|
||||
created=chunk_data.get("created", int(time.time())),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=chunk_data.get("choices", []),
|
||||
system_fingerprint=chunk_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def list_models(self) -> Dict[str, Any]:
|
||||
"""Get available models"""
|
||||
response = self.client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def alist_models(self) -> Dict[str, Any]:
|
||||
response = await self.async_client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
7
src/managers/llm_api/clients/openai/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
OpenAI Client Module
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
|
||||
|
||||
__all__ = ["OpenAIClient"]
|
||||
279
src/managers/llm_api/clients/openai/openai_client.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
OpenAI Client
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingData,
|
||||
EmbeddingUsage,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMAPI):
|
||||
"""
|
||||
OpenAI API Client
|
||||
|
||||
Supports the OpenAI API and other API services compatible with the OpenAI format
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.organization = organization
|
||||
|
||||
if base_url is None:
|
||||
base_url = "https://api.openai.com/v1"
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
}
|
||||
|
||||
if self.organization:
|
||||
headers["OpenAI-Organization"] = self.organization
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
return "/chat/completions"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
# Optional
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
if request.tools is not None:
|
||||
payload["tools"] = request.tools
|
||||
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = request.tool_choice
|
||||
|
||||
if request.response_format is not None:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
if request.seed is not None:
|
||||
payload["seed"] = request.seed
|
||||
|
||||
if request.user is not None:
|
||||
payload["user"] = request.user
|
||||
|
||||
return payload
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
choices = []
|
||||
for choice_data in response_data.get("choices", []):
|
||||
message_data = choice_data.get("message", {})
|
||||
|
||||
tool_calls = message_data.get("tool_calls")
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole(message_data.get("role", "assistant")),
|
||||
content=message_data.get("content", ""),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=choice_data.get("index", 0),
|
||||
message=message,
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
logprobs=choice_data.get("logprobs"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object=response_data.get("object", "chat.completion"),
|
||||
created=response_data.get("created", int(time.time())),
|
||||
model=response_data.get("model", ""),
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
system_fingerprint=response_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("id", ""),
|
||||
object=chunk_data.get("object", "chat.completion.chunk"),
|
||||
created=chunk_data.get("created", int(time.time())),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=chunk_data.get("choices", []),
|
||||
system_fingerprint=chunk_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def list_models(self) -> Dict[str, Any]:
|
||||
response = self.client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def alist_models(self) -> Dict[str, Any]:
|
||||
response = await self.async_client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def create_embeddings(
|
||||
self,
|
||||
input_text: Union[str, List[str]],
|
||||
model: str,
|
||||
encoding_format: str = "float",
|
||||
dimensions: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
retry_delay: Optional[float] = None,
|
||||
) -> EmbeddingResponse:
|
||||
|
||||
actual_timeout = timeout if timeout is not None else self.timeout
|
||||
actual_max_retries = (
|
||||
max_retries if max_retries is not None else self.max_retries
|
||||
)
|
||||
actual_retry_delay = (
|
||||
retry_delay if retry_delay is not None else self.retry_delay
|
||||
)
|
||||
|
||||
request = EmbeddingRequest(
|
||||
input=input_text,
|
||||
model=model,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
)
|
||||
|
||||
payload = self._build_embedding_request_payload(request)
|
||||
|
||||
for attempt in range(actual_max_retries + 1):
|
||||
try:
|
||||
print(f"Debug: Payload: {payload}")
|
||||
|
||||
response = self.session.post(
|
||||
f"{self.base_url}/embeddings", json=payload, timeout=actual_timeout
|
||||
)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
return self._parse_embedding_response(response_data)
|
||||
else:
|
||||
error_msg = f"embedding failed (try {attempt + 1}): HTTP {response.status_code}"
|
||||
if hasattr(response, "text"):
|
||||
error_msg += f" - {response.text}"
|
||||
print(f"Debug: {error_msg}")
|
||||
|
||||
if attempt < actual_max_retries:
|
||||
print(f"Debug: wait {actual_retry_delay} and retry...")
|
||||
time.sleep(actual_retry_delay)
|
||||
continue
|
||||
else:
|
||||
raise Exception(f"All retries failed: {error_msg}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Embedding request failed (try {attempt + 1}): {str(e)}, traceback: {format_exc()}."
|
||||
print(f"Debug: {error_msg}")
|
||||
|
||||
if attempt < actual_max_retries:
|
||||
print(f"Debug: wait {actual_retry_delay} and retry...")
|
||||
time.sleep(actual_retry_delay)
|
||||
continue
|
||||
else:
|
||||
raise Exception(f"All retries failed: {str(e)}, traceback: {format_exc()}.")
|
||||
|
||||
raise Exception("Unknown error")
|
||||
|
||||
def _build_embedding_request_payload(
|
||||
self, request: EmbeddingRequest
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
payload = {
|
||||
"input": request.input,
|
||||
"model": request.model,
|
||||
"encoding_format": request.encoding_format,
|
||||
}
|
||||
|
||||
if request.dimensions is not None:
|
||||
payload["dimensions"] = request.dimensions
|
||||
|
||||
if request.user is not None:
|
||||
payload["user"] = request.user
|
||||
|
||||
return payload
|
||||
|
||||
def _parse_embedding_response(
|
||||
self, response_data: Dict[str, Any]
|
||||
) -> EmbeddingResponse:
|
||||
embedding_data_list = []
|
||||
for data_item in response_data.get("data", []):
|
||||
embedding_data = EmbeddingData(
|
||||
object=data_item.get("object", "embedding"),
|
||||
embedding=data_item.get("embedding", []),
|
||||
index=data_item.get("index", 0),
|
||||
)
|
||||
embedding_data_list.append(embedding_data)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = EmbeddingUsage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
object=response_data.get("object", "list"),
|
||||
data=embedding_data_list,
|
||||
model=response_data.get("model", ""),
|
||||
usage=usage,
|
||||
)
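# Minimal usage sketch (illustrative, not part of the original file; the model
# name below is a placeholder, not a recommendation):
#
#   client = OpenAIClient(api_key="sk-...")
#   resp = client.create_embeddings(
#       input_text=["first text", "second text"],
#       model="text-embedding-3-small",
#   )
#   vectors = [item.embedding for item in resp.data]
#   client.close()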
|
||||
8
src/managers/llm_api/clients/openrouter/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
OpenRouter Client
|
||||
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
|
||||
|
||||
__all__ = ["OpenRouterClient"]
|
||||
329
src/managers/llm_api/clients/openrouter/openrouter_client.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
OpenRouter Client
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from src.managers.log.logger import Logger
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class OpenRouterClient(BaseLLMAPI):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
app_name: Optional[str] = None,
|
||||
site_url: Optional[str] = None,
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize OpenRouter Client
|
||||
|
||||
Args:
|
||||
api_key: OpenRouter API key
|
||||
base_url: API base URL, default to OpenRouter official API
|
||||
app_name: Application name (optional, for statistics)
|
||||
site_url: Website URL (optional, for statistics)
|
||||
timeout: Request timeout
|
||||
max_retries: Maximum number of retries
|
||||
retry_delay: Retry delay
|
||||
**kwargs: Other configuration parameters
|
||||
"""
|
||||
self.app_name = app_name or "tokfinity-llm-client"
|
||||
self.site_url = site_url
|
||||
|
||||
if base_url is None:
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
"X-Title": self.app_name,
|
||||
}
|
||||
|
||||
if self.site_url:
|
||||
headers["HTTP-Referer"] = self.site_url
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
|
||||
return "/chat/completions"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
if request.frequency_penalty != 0.0:
|
||||
payload["frequency_penalty"] = request.frequency_penalty
|
||||
|
||||
if request.presence_penalty != 0.0:
|
||||
payload["presence_penalty"] = request.presence_penalty
|
||||
|
||||
if request.tools is not None:
|
||||
payload["tools"] = request.tools
|
||||
|
||||
if request.tool_choice is not None:
|
||||
payload["tool_choice"] = request.tool_choice
|
||||
|
||||
if request.response_format is not None:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
if request.seed is not None:
|
||||
payload["seed"] = request.seed
|
||||
|
||||
if request.user is not None:
|
||||
payload["user"] = request.user
|
||||
|
||||
return payload
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
"""
|
||||
Parse OpenRouter API response
|
||||
|
||||
Args:
|
||||
response_data: API response data
|
||||
|
||||
Returns:
|
||||
ChatCompletionResponse: Parsed response object
|
||||
"""
|
||||
choices = []
|
||||
for choice_data in response_data.get("choices", []):
|
||||
message_data = choice_data.get("message", {})
|
||||
|
||||
tool_calls = message_data.get("tool_calls")
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole(message_data.get("role", "assistant")),
|
||||
content=message_data.get("content", ""),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=choice_data.get("index", 0),
|
||||
message=message,
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
logprobs=choice_data.get("logprobs"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object=response_data.get("object", "chat.completion"),
|
||||
created=response_data.get("created", int(time.time())),
|
||||
model=response_data.get("model", ""),
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
system_fingerprint=response_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
"""
|
||||
Parse stream chunk
|
||||
Args:
|
||||
chunk_data: Raw chunk data
|
||||
|
||||
Returns:
|
||||
Optional[ChatCompletionChunk]: Parsed chunk
|
||||
"""
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("id", ""),
|
||||
object=chunk_data.get("object", "chat.completion.chunk"),
|
||||
created=chunk_data.get("created", int(time.time())),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=chunk_data.get("choices", []),
|
||||
system_fingerprint=chunk_data.get("system_fingerprint"),
|
||||
)
|
||||
|
||||
def list_models(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get available models
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model list response
|
||||
"""
|
||||
response = self.client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def alist_models(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Async get available models
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model list response
|
||||
"""
|
||||
response = await self.async_client.get("/models")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def get_generation_info(self, generation_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get specific generation request details
|
||||
|
||||
Args:
|
||||
generation_id: Generation request ID
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Generation information
|
||||
"""
|
||||
response = self.client.get(f"/generation?id={generation_id}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def aget_generation_info(self, generation_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Async get specific generation request details
|
||||
|
||||
Args:
|
||||
generation_id: Generation request ID
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Generation information
|
||||
"""
|
||||
response = await self.async_client.get(f"/generation?id={generation_id}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def get_account_credits(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get account credits
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Account credits information
|
||||
"""
|
||||
response = self.client.get("/auth/key")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def aget_account_credits(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Async get account credits
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Account credits information
|
||||
"""
|
||||
response = await self.async_client.get("/auth/key")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_popular_models() -> List[str]:
|
||||
"""
|
||||
Get popular models
|
||||
|
||||
Returns:
|
||||
List[str]: Popular model names
|
||||
"""
|
||||
return [
|
||||
# OpenAI
|
||||
"openai/gpt-4-turbo-preview",
|
||||
"openai/gpt-4",
|
||||
"openai/gpt-3.5-turbo",
|
||||
# Anthropic
|
||||
"anthropic/claude-3-opus",
|
||||
"anthropic/claude-3-sonnet",
|
||||
"anthropic/claude-3-haiku",
|
||||
# Google
|
||||
"google/gemini-pro",
|
||||
"google/gemini-pro-vision",
|
||||
# Meta
|
||||
"meta-llama/llama-2-70b-chat",
|
||||
"meta-llama/llama-2-13b-chat",
|
||||
# Mistral
|
||||
"mistralai/mixtral-8x7b-instruct",
|
||||
"mistralai/mistral-7b-instruct",
|
||||
# Open Source
|
||||
"microsoft/wizardlm-2-8x22b",
|
||||
"databricks/dbrx-instruct",
|
||||
"cohere/command-r-plus",
|
||||
]
|
||||
|
||||
def get_model_info(self, model_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get specific model details
|
||||
|
||||
Args:
|
||||
model_name: Model name
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model information
|
||||
"""
|
||||
models_response = self.list_models()
|
||||
models = models_response.get("data", [])
|
||||
|
||||
for model in models:
|
||||
if model.get("id") == model_name:
|
||||
return model
|
||||
|
||||
raise ValueError(f"Model {model_name} not found")
|
||||
|
||||
async def aget_model_info(self, model_name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Async get specific model details
|
||||
|
||||
Args:
|
||||
model_name: Model name
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Model information
|
||||
"""
|
||||
models_response = await self.alist_models()
|
||||
models = models_response.get("data", [])
|
||||
|
||||
for model in models:
|
||||
if model.get("id") == model_name:
|
||||
return model
|
||||
|
||||
raise ValueError(f"Model {model_name} not found")
|
||||
7
src/managers/llm_api/clients/private/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Private model client
|
||||
"""
|
||||
|
||||
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
|
||||
|
||||
__all__ = ["PrivateModelClient"]
|
||||
321
src/managers/llm_api/clients/private/private_client.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Private model client
|
||||
Implementation of a private model API client based on BaseLLMAPI
|
||||
Supports vLLM, Text Generation Inference (TGI), and Ollama
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
from src.managers.llm_api.base_client import (
|
||||
BaseLLMAPI,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatMessage,
|
||||
MessageRole,
|
||||
Choice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class PrivateModelClient(BaseLLMAPI):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
timeout: int = 60,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
deployment_type: str = "vllm",
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
supports_tools: bool = True,
|
||||
logger: Optional[Logger] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.deployment_type = deployment_type.lower()
|
||||
self.custom_headers = custom_headers or {}
|
||||
self.supports_tools = supports_tools
|
||||
super().__init__(
|
||||
api_key=api_key or "EMPTY",
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
logger=logger,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _initialize_client(self) -> None:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "tokfinity-llm-client/1.0",
|
||||
}
|
||||
|
||||
if self.api_key and self.api_key != "EMPTY":
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
headers.update(self.custom_headers)
|
||||
|
||||
|
||||
if self.deployment_type == "ollama":
|
||||
pass
|
||||
elif self.deployment_type == "tgi":
|
||||
pass
|
||||
|
||||
self._create_http_clients(headers)
|
||||
|
||||
def _get_chat_endpoint(self) -> str:
|
||||
if self.deployment_type == "ollama":
|
||||
return "/api/chat"
|
||||
elif self.deployment_type == "tgi":
|
||||
return "/generate"
|
||||
else:
|
||||
return "/chat/completions"
|
||||
|
||||
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
if self.deployment_type == "ollama":
|
||||
return self._build_ollama_payload(request)
|
||||
elif self.deployment_type == "tgi":
|
||||
return self._build_tgi_payload(request)
|
||||
else:
|
||||
return self._build_openai_payload(request)
|
||||
|
||||
def _build_openai_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
||||
# Optional
|
||||
if request.max_tokens is not None:
|
||||
payload["max_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["stop"] = request.stop
|
||||
|
||||
if request.frequency_penalty != 0.0:
|
||||
payload["frequency_penalty"] = request.frequency_penalty
|
||||
|
||||
if request.presence_penalty != 0.0:
|
||||
payload["presence_penalty"] = request.presence_penalty
|
||||
|
||||
if self.supports_tools and request.tools is not None:
|
||||
payload["tools"] = request.tools
|
||||
|
||||
if self.supports_tools and request.tool_choice is not None:
|
||||
payload["tool_choice"] = request.tool_choice
|
||||
|
||||
if request.response_format is not None:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
return payload
|
||||
|
||||
def _build_ollama_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"messages": self.format_messages_for_api(request.messages),
|
||||
"stream": request.stream,
|
||||
"options": {
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
},
|
||||
}
|
||||
|
||||
if request.max_tokens is not None:
|
||||
payload["options"]["num_predict"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["options"]["stop"] = (
|
||||
request.stop if isinstance(request.stop, list) else [request.stop]
|
||||
)
|
||||
|
||||
return payload
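# Illustrative example (not in the original source): for deployment_type="ollama",
# a request with model="llama2", max_tokens=128 and stop=["\n\n"] is translated to
#   {"model": "llama2", "messages": [...], "stream": False,
#    "options": {"temperature": ..., "top_p": ..., "num_predict": 128, "stop": ["\n\n"]}}
# which matches the request shape of Ollama's /api/chat endpoint.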
|
||||
|
||||
def _build_tgi_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
|
||||
prompt = self._messages_to_prompt(request.messages)
|
||||
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"parameters": {
|
||||
"temperature": request.temperature,
|
||||
"top_p": request.top_p,
|
||||
"do_sample": True,
|
||||
"stream": request.stream,
|
||||
},
|
||||
}
|
||||
|
||||
if request.max_tokens is not None:
|
||||
payload["parameters"]["max_new_tokens"] = request.max_tokens
|
||||
|
||||
if request.stop is not None:
|
||||
payload["parameters"]["stop_sequences"] = (
|
||||
request.stop if isinstance(request.stop, list) else [request.stop]
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def _messages_to_prompt(self, messages: List[ChatMessage]) -> str:
|
||||
prompt_parts = []
|
||||
|
||||
for message in messages:
|
||||
if message.role == MessageRole.SYSTEM:
|
||||
prompt_parts.append(f"System: {message.content}")
|
||||
elif message.role == MessageRole.USER:
|
||||
prompt_parts.append(f"User: {message.content}")
|
||||
elif message.role == MessageRole.ASSISTANT:
|
||||
prompt_parts.append(f"Assistant: {message.content}")
|
||||
|
||||
prompt_parts.append("Assistant:")
|
||||
return "\n\n".join(prompt_parts)
|
||||
|
||||
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
|
||||
if self.deployment_type == "ollama":
|
||||
return self._parse_ollama_response(response_data)
|
||||
elif self.deployment_type == "tgi":
|
||||
return self._parse_tgi_response(response_data)
|
||||
else:
|
||||
return self._parse_openai_response(response_data)
|
||||
|
||||
def _parse_openai_response(
|
||||
self, response_data: Dict[str, Any]
|
||||
) -> ChatCompletionResponse:
|
||||
choices = []
|
||||
for choice_data in response_data.get("choices", []):
|
||||
message_data = choice_data.get("message", {})
|
||||
|
||||
message = ChatMessage(
|
||||
role=MessageRole(message_data.get("role", "assistant")),
|
||||
content=message_data.get("content", ""),
|
||||
tool_calls=message_data.get("tool_calls"),
|
||||
)
|
||||
|
||||
choice = Choice(
|
||||
index=choice_data.get("index", 0),
|
||||
message=message,
|
||||
finish_reason=choice_data.get("finish_reason"),
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
usage_data = response_data.get("usage", {})
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
prompt_tokens=usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=usage_data.get("completion_tokens", 0),
|
||||
total_tokens=usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=response_data.get("id", ""),
|
||||
object=response_data.get("object", "chat.completion"),
|
||||
created=response_data.get("created", int(time.time())),
|
||||
model=response_data.get("model", ""),
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _parse_ollama_response(
|
||||
self, response_data: Dict[str, Any]
|
||||
) -> ChatCompletionResponse:
|
||||
content = response_data.get("message", {}).get("content", "")
|
||||
|
||||
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
|
||||
|
||||
choice = Choice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason="stop" if response_data.get("done", False) else None,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=f"ollama-{int(time.time())}",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model=response_data.get("model", ""),
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
def _parse_tgi_response(
|
||||
self, response_data: Dict[str, Any]
|
||||
) -> ChatCompletionResponse:
|
||||
if isinstance(response_data, list) and len(response_data) > 0:
|
||||
response_data = response_data[0]
|
||||
|
||||
content = response_data.get("generated_text", "")
|
||||
|
||||
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
|
||||
|
||||
choice = Choice(
|
||||
index=0,
|
||||
message=message,
|
||||
finish_reason=response_data.get("finish_reason", "stop"),
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=f"tgi-{int(time.time())}",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model="tgi-model",
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
|
||||
if self.deployment_type == "ollama":
|
||||
try:
|
||||
chunk_data = json.loads(line)
|
||||
return self._parse_ollama_chunk(chunk_data)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
else:
|
||||
# Standard SSE format
|
||||
if line.startswith("data: "):
|
||||
data = line[6:]
|
||||
if data.strip() == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
chunk_data = json.loads(data)
|
||||
return self._parse_stream_chunk(chunk_data)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _parse_ollama_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
content = chunk_data.get("message", {}).get("content", "")
|
||||
done = chunk_data.get("done", False)
|
||||
|
||||
choices = [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": content} if content else {},
|
||||
"finish_reason": "stop" if done else None,
|
||||
}
|
||||
]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id=f"ollama-{int(time.time())}",
|
||||
object="chat.completion.chunk",
|
||||
created=int(time.time()),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=choices,
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
self, chunk_data: Dict[str, Any]
|
||||
) -> Optional[ChatCompletionChunk]:
|
||||
return ChatCompletionChunk(
|
||||
id=chunk_data.get("id", ""),
|
||||
object=chunk_data.get("object", "chat.completion.chunk"),
|
||||
created=chunk_data.get("created", int(time.time())),
|
||||
model=chunk_data.get("model", ""),
|
||||
choices=chunk_data.get("choices", []),
|
||||
)
|
||||
43
src/managers/log/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Log Management Module
|
||||
|
||||
Provides log management functionality, organized by timestamp directory and log level
|
||||
|
||||
Usage example:
|
||||
from src.managers.log import Logger, create_logger
|
||||
|
||||
# Basic Usage
|
||||
logger = Logger("logs", "my_app")
|
||||
logger.info("This is an info")
|
||||
logger.error("This is an error")
|
||||
logger.close()
|
||||
|
||||
# Use the context manager
|
||||
with Logger("logs", "my_app") as logger:
|
||||
logger.info("Automatically manage resources")
|
||||
|
||||
# Convenient function
|
||||
logger = create_logger("logs", "my_app")
|
||||
logger.info("Convenient create")
|
||||
logger.close()
|
||||
"""
|
||||
|
||||
from .logger import (
|
||||
Logger,
|
||||
create_logger,
|
||||
init_global_logger,
|
||||
get_global_logger,
|
||||
set_global_logger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Logger",
|
||||
"create_logger",
|
||||
"init_global_logger",
|
||||
"get_global_logger",
|
||||
"set_global_logger",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__author__ = "Tokfinity Team"
|
||||
__description__ = "Timestamp:Directory and Level-Based Log Manager "
|
||||
357
src/managers/log/logger.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
|
||||
Log Manager
|
||||
Create dir according to timestamp, store level-based log files
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
"""
|
||||
Module-level custom NOTICE log level (between INFO and WARNING)
|
||||
"""
|
||||
NOTICE_LEVEL = 25
|
||||
if not hasattr(logging, "NOTICE"):
|
||||
logging.addLevelName(NOTICE_LEVEL, "NOTICE")
|
||||
|
||||
def notice(self, message, *args, **kwargs):
|
||||
if self.isEnabledFor(NOTICE_LEVEL):
|
||||
self._log(NOTICE_LEVEL, message, args, **kwargs)
|
||||
|
||||
logging.Logger.notice = notice # type: ignore[attr-defined]
|
||||
|
||||
class ImageBuilderLogger:
|
||||
"""
|
||||
ImageBuilder log manager
|
||||
|
||||
Function:
|
||||
- Create sub dir based on timestamp
|
||||
- Create debug.log, info.log, warning.log and error.log
|
||||
- Provide a standard log format that records the calling code location two stack levels up
|
||||
- Console and file outputs
|
||||
"""
|
||||
|
||||
def __init__(self, log_base_path: str, console_output: bool = True):
|
||||
self.log_base_path = Path(log_base_path)
|
||||
self.console_output = console_output
|
||||
|
||||
self.log_dir = self._create_log_dir()
|
||||
|
||||
self.file_handlers = {}
|
||||
|
||||
self.logger = self._setup_logger()
|
||||
|
||||
def _create_log_dir(self) -> Path:
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M")
|
||||
log_dir = self.log_base_path / f"build_images/{timestamp}"
|
||||
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return log_dir
|
||||
|
||||
def _setup_logger(self) -> logging.Logger:
|
||||
logger = logging.getLogger("image_builder_logger")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
logger.handlers.clear()
|
||||
|
||||
log_format = (
|
||||
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
|
||||
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
|
||||
)
|
||||
formatter = logging.Formatter(log_format)
|
||||
|
||||
if self.console_output:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
self._add_file_handlers(logger, formatter)
|
||||
|
||||
return logger
|
||||
|
||||
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
|
||||
log_levels = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
}
|
||||
|
||||
for level_name, level_value in log_levels.items():
|
||||
log_file = self.log_dir / f"{level_name}.log"
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
if level_name == "debug":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
|
||||
elif level_name == "info":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
|
||||
elif level_name == "warning":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
|
||||
elif level_name == "error":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
self.file_handlers[level_name] = file_handler
|
||||
|
||||
all_log_file = self.log_dir / "all.log"
|
||||
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
|
||||
all_file_handler.setFormatter(formatter)
|
||||
all_file_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(all_file_handler)
|
||||
self.file_handlers["all"] = all_file_handler
|
||||
|
||||
def debug(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.debug(message, *args, **kwargs)
|
||||
|
||||
def info(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.info(message, *args, **kwargs)
|
||||
|
||||
def warning(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.warning(message, *args, **kwargs)
|
||||
|
||||
def error(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.error(message, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def log_file(self) -> str:
|
||||
return str(self.log_dir / "all.log")
|
||||
|
||||
def get_log_dir(self) -> str:
|
||||
return str(self.log_dir.absolute())
|
||||
|
||||
def get_log_files(self) -> dict:
|
||||
return {
|
||||
"debug": str(self.log_dir / "debug.log"),
|
||||
"info": str(self.log_dir / "info.log"),
|
||||
"warning": str(self.log_dir / "warning.log"),
|
||||
"error": str(self.log_dir / "error.log"),
|
||||
"all": str(self.log_dir / "all.log"),
|
||||
}
|
||||
|
||||
def close(self):
|
||||
for handler in self.file_handlers.values():
|
||||
handler.close()
|
||||
|
||||
for handler in self.logger.handlers[:]:
|
||||
self.logger.removeHandler(handler)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ImageBuilderLogger(dir={self.log_dir})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"ImageBuilderLogger("
|
||||
f"base_path='{self.log_base_path}', "
|
||||
f"log_dir='{self.log_dir}', "
|
||||
f"console_output={self.console_output})"
|
||||
)
|
||||
|
||||
class Logger:
|
||||
"""
|
||||
Log manager
|
||||
|
||||
Function:
|
||||
- Create sub dir based on timestamp
|
||||
- Create debug.log, info.log, notice.log, warning.log and error.log
|
||||
- Provide standard log format
|
||||
- Console and file outputs
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_base_path: str,
|
||||
logger_name: str = "tokfinity_logger",
|
||||
console_output: bool = True,
|
||||
log_format: Optional[str] = None,
|
||||
instance_id: Optional[str] = None,
|
||||
):
|
||||
self.log_base_path = Path(log_base_path)
|
||||
self.logger_name = logger_name
|
||||
self.console_output = console_output
|
||||
self.instance_id = instance_id
|
||||
|
||||
self.log_format = log_format or (
|
||||
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
|
||||
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
|
||||
)
|
||||
|
||||
self.log_dir = self._create_log_dir()
|
||||
|
||||
self.file_handlers = {}
|
||||
|
||||
self.logger = self._setup_logger()
|
||||
|
||||
def _create_log_dir(self) -> Path:
|
||||
if self.instance_id:
|
||||
log_dir = self.log_base_path / self.instance_id
|
||||
else:
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M")
|
||||
log_dir = self.log_base_path / timestamp
|
||||
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return log_dir
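# Illustrative note (not in the original source): with log_base_path="logs" and
# no instance_id, a run at 2024-01-02 15:04 creates the directory
# logs/202401021504/, which _add_file_handlers then fills with debug.log,
# info.log, notice.log, warning.log, error.log and all.log.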
|
||||
|
||||
def _setup_logger(self) -> logging.Logger:
|
||||
logger = logging.getLogger(self.logger_name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
logger.handlers.clear()
|
||||
|
||||
formatter = logging.Formatter(self.log_format)
|
||||
|
||||
if self.console_output:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
self._add_file_handlers(logger, formatter)
|
||||
|
||||
return logger
|
||||
|
||||
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
|
||||
log_levels = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"notice": NOTICE_LEVEL,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
}
|
||||
|
||||
for level_name, level_value in log_levels.items():
|
||||
log_file = self.log_dir / f"{level_name}.log"
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
if level_name == "debug":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
|
||||
elif level_name == "info":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
|
||||
elif level_name == "notice":
|
||||
file_handler.addFilter(lambda record: record.levelno == NOTICE_LEVEL)
|
||||
elif level_name == "warning":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
|
||||
elif level_name == "error":
|
||||
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
self.file_handlers[level_name] = file_handler
|
||||
|
||||
all_log_file = self.log_dir / "all.log"
|
||||
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
|
||||
all_file_handler.setFormatter(formatter)
|
||||
all_file_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(all_file_handler)
|
||||
self.file_handlers["all"] = all_file_handler
|
||||
|
||||
def debug(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.debug(message, *args, **kwargs)
|
||||
|
||||
def info(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.info(message, *args, **kwargs)
|
||||
|
||||
def notice(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.log(NOTICE_LEVEL, message, *args, **kwargs)
|
||||
|
||||
def warning(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.warning(message, *args, **kwargs)
|
||||
|
||||
def error(self, message: str, *args, **kwargs):
|
||||
kwargs.setdefault("stacklevel", 2)
|
||||
self.logger.error(message, *args, **kwargs)
|
||||
|
||||
def get_log_dir(self) -> str:
|
||||
return str(self.log_dir.absolute())
|
||||
|
||||
def get_log_files(self) -> dict:
|
||||
return {
|
||||
"debug": str(self.log_dir / "debug.log"),
|
||||
"info": str(self.log_dir / "info.log"),
|
||||
"notice": str(self.log_dir / "notice.log"),
|
||||
"warning": str(self.log_dir / "warning.log"),
|
||||
"error": str(self.log_dir / "error.log"),
|
||||
"all": str(self.log_dir / "all.log"),
|
||||
}
|
||||
|
||||
@property
|
||||
def log_file(self) -> str:
|
||||
return str(self.log_dir / "all.log")
|
||||
|
||||
def close(self):
|
||||
for handler in self.file_handlers.values():
|
||||
handler.close()
|
||||
for handler in self.logger.handlers[:]:
|
||||
self.logger.removeHandler(handler)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Logger(name={self.logger_name}, dir={self.log_dir})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Logger("
|
||||
f"name='{self.logger_name}', "
|
||||
f"base_path='{self.log_base_path}', "
|
||||
f"log_dir='{self.log_dir}', "
|
||||
f"console_output={self.console_output})"
|
||||
)
|
||||
|
||||
|
||||
def create_logger(
|
||||
log_base_path: str,
|
||||
logger_name: str = "tokfinity_logger",
|
||||
console_output: bool = True,
|
||||
) -> Logger:
|
||||
return Logger(
|
||||
log_base_path=log_base_path,
|
||||
logger_name=logger_name,
|
||||
console_output=console_output,
|
||||
)
|
||||
|
||||
|
||||
# Global log manager instance (optional)
|
||||
_global_logger: Optional[Logger] = None
|
||||
|
||||
|
||||
def get_global_logger() -> Optional[Logger]:
|
||||
return _global_logger
|
||||
|
||||
|
||||
def set_global_logger(logger: Logger):
|
||||
global _global_logger
|
||||
_global_logger = logger
|
||||
|
||||
|
||||
def init_global_logger(
|
||||
log_base_path: str, logger_name: str = "global_logger"
|
||||
) -> Logger:
|
||||
global _global_logger
|
||||
_global_logger = Logger(log_base_path, logger_name)
|
||||
return _global_logger
|
||||
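A minimal usage sketch for the Logger above (illustrative only, not part of the committed file; it assumes the module is importable as src.managers.log.logger, which matches the imports in src/managers/loop/base.py, and that NOTICE_LEVEL is registered earlier in this module, as notice() implies):

from src.managers.log.logger import Logger

# Context-manager usage: file handlers are closed and detached on exit.
with Logger(log_base_path="./logs", logger_name="demo", instance_id="demo-001") as log:
    log.debug("resolved config")            # -> debug.log and all.log
    log.info("run started")                 # -> info.log, all.log, and the console
    log.notice("milestone reached")         # -> notice.log via the custom NOTICE_LEVEL
    log.warning("retrying request")         # -> warning.log
    print(log.get_log_dir())                # absolute path of ./logs/demo-001
    print(log.get_log_files()["error"])     # path of error.log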
275
src/managers/loop/base.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
from src.managers.llm_api.api_manager import LLMAPIManager
|
||||
from src.managers.prompts.prompts_manager import PromptsManager
|
||||
from src.tools.base import (
|
||||
ToolExecutor,
|
||||
BASH_TOOL_NAME,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
from src.managers.loop.types import ToolStats, LLMUsage
|
||||
|
||||
|
||||
|
||||
class BaseLoop:
|
||||
def __init__(self, instance_id: str, instance_data: Dict[str, Any], logger: Logger, prompts_manager: PromptsManager | None, llm_manager: LLMAPIManager | None, tool_executor: ToolExecutor, config: Dict[str, Any] | None = None):
|
||||
self.instance_id = instance_id
|
||||
self.instance_data = instance_data
|
||||
self.logger = logger
|
||||
self.prompts_manager = prompts_manager
|
||||
self.llm_manager = llm_manager
|
||||
self.tool_executor = tool_executor
|
||||
self.config = config or {}
|
||||
self.component_name = self.__class__.__name__
|
||||
|
||||
|
||||
def _make_assistant(
|
||||
self, content: str | None, tool_calls: Any, messages: List[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""
|
||||
Construct an assistant message based on the current content and tool calls, and append it to the messages.
|
||||
"""
|
||||
safe_content = content or ""
|
||||
if not safe_content and not tool_calls:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] Assistant returned an empty message with no tool calls; skipping this message and prompting to continue"
|
||||
)
|
||||
messages.append(
|
||||
{"role": "user", "content": "请继续分析问题并使用工具来解决问题。"}
|
||||
)
|
||||
return False
|
||||
assistant_message: Dict[str, Any] = {"role": "assistant"}
|
||||
if tool_calls and not safe_content:
|
||||
assistant_message["content"] = ""
|
||||
elif safe_content:
|
||||
assistant_message["content"] = safe_content
|
||||
if tool_calls:
|
||||
assistant_message["tool_calls"] = tool_calls
|
||||
messages.append(assistant_message)
|
||||
return True
|
||||
|
||||
def _make_tool_response(
|
||||
self, tool_results: List[Any], messages: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Convert tool execution results into standard tool messages (role=tool) and append them to the messages.
|
||||
|
||||
- Generate content per result: use prompts_manager.tool_response_prompts([{...}]) to produce the content
|
||||
- Set tool_call_id: prefer ToolResult.id; fallback to ToolResult.call_id
|
||||
"""
|
||||
if not tool_results:
|
||||
return
|
||||
for result in tool_results:
|
||||
single_dict = [
|
||||
{
|
||||
"name": getattr(result, "name", "unknown"),
|
||||
"success": getattr(result, "success", False),
|
||||
"result": getattr(result, "result", None) or "",
|
||||
"error": getattr(result, "error", None) or "",
|
||||
}
|
||||
]
|
||||
content_text = (
|
||||
self.prompts_manager.tool_response_prompts(single_dict)
|
||||
if self.prompts_manager
|
||||
else ""
|
||||
)
|
||||
tool_call_id = getattr(result, "id", None) or getattr(
|
||||
result, "call_id", None
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": content_text,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
)
|
||||
|
||||
def _response_log(
|
||||
self, response: Any, first_content: str, first_tool_calls: Any, total_turns: int
|
||||
) -> None:
|
||||
"""notice log for the current turn's LLM output"""
|
||||
try:
|
||||
response_log: Dict[str, Any] = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
response_log["usage"] = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", None),
|
||||
"completion_tokens": getattr(response.usage, "completion_tokens", None),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", None),
|
||||
}
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
response_log["choice"] = {
|
||||
"message": {
|
||||
"content": first_content,
|
||||
"tool_calls": first_tool_calls,
|
||||
}
|
||||
}
|
||||
if response_log:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] The {total_turns}th turn output: {json.dumps(response_log, ensure_ascii=False)}"
|
||||
)
|
||||
else:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] The {total_turns}th turn output: {str(response)}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] 第 {total_turns} 轮: LLM 输出序列化失败,使用字符串表示: {str(response)}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
def _debug_messages(
|
||||
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
|
||||
) -> None:
|
||||
"""debug log for the messages to be sent to the model"""
|
||||
try:
|
||||
self.logger.debug(f"[{self.component_name}] msg:")
|
||||
recent_messages = messages[-2:] if len(messages) > 2 else messages
|
||||
base_index = len(messages) - len(recent_messages)
|
||||
for offset, msg in enumerate(recent_messages):
|
||||
idx = base_index + offset
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
content_str = content if isinstance(content, str) else ""
|
||||
preview = content_str[:prefix_len]
|
||||
content_len = len(content_str)
|
||||
extra = ""
|
||||
if role == "assistant":
|
||||
tool_calls = msg.get("tool_calls")
|
||||
has_tool = tool_calls is not None and tool_calls != []
|
||||
try:
|
||||
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] In debug_messages function, fail: {format_exc()}, tool calls: {tool_calls}."
|
||||
)
|
||||
tool_calls_json = str(tool_calls)
|
||||
extra = f", has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
|
||||
elif role == "tool":
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
extra = f", tool_call_id={tool_call_id}"
|
||||
self.logger.debug(
|
||||
f"[{self.component_name}] {turn+1}th, msg#{idx}: role={role}, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}{extra}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] In debug_messages function, fail msg: {format_exc()}."
|
||||
)
|
||||
|
||||
def _debug_last_message(
|
||||
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
|
||||
) -> None:
|
||||
"""debug last turn msg"""
|
||||
try:
|
||||
if not messages:
|
||||
return
|
||||
last_assistant_idx = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if messages[i].get("role") == "assistant":
|
||||
last_assistant_idx = i
|
||||
break
|
||||
if last_assistant_idx is None:
|
||||
return
|
||||
msg = messages[last_assistant_idx]
|
||||
content = msg.get("content")
|
||||
content_str = content if isinstance(content, str) else ""
|
||||
preview = content_str[:prefix_len]
|
||||
content_len = len(content_str)
|
||||
tool_calls = msg.get("tool_calls")
|
||||
has_tool = tool_calls is not None and tool_calls != []
|
||||
try:
|
||||
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] In debug_last_message function, fail: {format_exc()}, tool calls: {tool_calls}."
|
||||
)
|
||||
tool_calls_json = str(tool_calls)
|
||||
self.logger.debug(
|
||||
f"[{self.component_name}] {turn+1}th turn, output_preview: role=assistant, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}, has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] In debug_last_message function, last turn fail: {format_exc()}."
|
||||
)
|
||||
|
||||
def _debug_tools(self, tools: List[Dict[str, Any]]) -> None:
|
||||
"""debug tools msg"""
|
||||
try:
|
||||
self.logger.debug(f"[{self.component_name}] tools num: {len(tools)}")
|
||||
for i, tool in enumerate(tools):
|
||||
try:
|
||||
tool_json = json.dumps(tool, ensure_ascii=False)
|
||||
self.logger.debug(f"[{self.component_name}] tool #{i+1}: {tool_json}")
|
||||
except Exception:
|
||||
self.logger.debug(
|
||||
f"[{self.component_name}] tool #{i+1} fail: {format_exc()}, string: {str(tool)}."
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] fail; traceback: {format_exc()}."
|
||||
)
|
||||
self.logger.warning(f"[{self.component_name}] tools string: {str(tools)}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_tools(self) -> List[Dict[str, Any]]:
|
||||
pass
|
||||
|
||||
def _is_bash_tool(self, tool_name: str) -> bool:
|
||||
return BASH_TOOL_NAME in tool_name
|
||||
|
||||
def _is_edit_tool(self, tool_name: str) -> bool:
|
||||
return "edit" in tool_name or "str_replace" in tool_name or STR_REPLACE_BASED_EDIT_TOOL_NAME in tool_name
|
||||
|
||||
def _is_search_tool(self, tool_name: str) -> bool:
|
||||
return SEARCH_TOOL_NAME in tool_name or "search" in tool_name
|
||||
|
||||
def _is_submit_result_tool(self, tool_name: str) -> bool:
|
||||
return SUBMIT_RESULT_TOOL_NAME in tool_name
|
||||
|
||||
def _update_usage(self, response: Any, usage_stats: LLMUsage) -> None:
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage_stats.prompt_tokens += int(getattr(response.usage, "prompt_tokens", 0) or 0)
|
||||
usage_stats.completion_tokens += int(
|
||||
getattr(response.usage, "completion_tokens", 0) or 0
|
||||
)
|
||||
usage_stats.total_tokens += int(getattr(response.usage, "total_tokens", 0) or 0)
|
||||
|
||||
def _init_usage_stats(self) -> LLMUsage:
|
||||
return LLMUsage()
|
||||
|
||||
def _init_tools_stats(self) -> ToolStats:
|
||||
return ToolStats()
|
||||
|
||||
def _update_tool_call_statistic(
|
||||
self, tool_results: List[Any], tool_stats: ToolStats
|
||||
) -> None:
|
||||
for result in tool_results:
|
||||
try:
|
||||
tool_name = getattr(result, "name", "")
|
||||
tool_name = tool_name.lower() if isinstance(tool_name, str) else ""
|
||||
success = bool(getattr(result, "success", False))
|
||||
|
||||
if self._is_bash_tool(tool_name):
|
||||
tool_stats.bash["count"] += 1
|
||||
if not success:
|
||||
tool_stats.bash["failed"] += 1
|
||||
|
||||
elif self._is_edit_tool(tool_name):
|
||||
tool_stats.edit["count"] += 1
|
||||
if not success:
|
||||
tool_stats.edit["failed"] += 1
|
||||
|
||||
elif self._is_search_tool(tool_name):
|
||||
tool_stats.search["count"] += 1
|
||||
if not success:
|
||||
tool_stats.search["failed"] += 1
|
||||
|
||||
elif self._is_submit_result_tool(tool_name):
|
||||
tool_stats.submit_result["count"] += 1
|
||||
if not success:
|
||||
tool_stats.submit_result["failed"] += 1
|
||||
except Exception:
|
||||
continue
|
||||
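For orientation, the message shapes that _make_assistant and _make_tool_response append to the conversation look roughly like the sketch below (OpenAI-style chat format; the id and argument values are illustrative, not taken from a real run):

# Assistant turn requesting one tool call (appended by _make_assistant).
assistant_message = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "id": "call_0",  # echoed back as tool_call_id in the tool turn below
            "type": "function",
            "function": {"name": "bash", "arguments": "{\"command\": \"ls /testbed\"}"},
        }
    ],
}

# Matching tool turn (appended by _make_tool_response for each ToolResult);
# the content string comes from prompts_manager.tool_response_prompts(...).
tool_message = {
    "role": "tool",
    "content": "[tool_response]\nTool 1: bash\nSuccess: True\nOutput:\n...",
    "tool_call_id": "call_0",
}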
339
src/managers/loop/patch_generator.py
Normal file
@@ -0,0 +1,339 @@
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
from src.managers.llm_api.api_manager import LLMAPIManager
|
||||
from src.managers.prompts.prompts_manager import PromptsManager
|
||||
from src.managers.loop.base import BaseLoop
|
||||
from src.tools.base import (
|
||||
ToolExecutor,
|
||||
ToolResult,
|
||||
SubmitToolResult,
|
||||
BASH_TOOL_NAME,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
|
||||
|
||||
class PatchGenerator(BaseLoop):
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
instance_data: Dict[str, Any],
|
||||
logger: Logger,
|
||||
prompts_manager: PromptsManager | None,
|
||||
llm_manager: LLMAPIManager | None,
|
||||
tool_executor: ToolExecutor,
|
||||
config: Dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
|
||||
|
||||
|
||||
async def _submit_all_tool_calls(
|
||||
self, other_tool_calls: List[Dict[str, Any]]
|
||||
) -> List[Any]:
|
||||
"""execute tool calls, return tool execution results list"""
|
||||
if not other_tool_calls:
|
||||
return []
|
||||
from src.tools.base import ToolCall
|
||||
|
||||
tool_call_objects = []
|
||||
for tool_call_dict in other_tool_calls:
|
||||
raw_args = tool_call_dict.get("function", {}).get("arguments", {})
|
||||
parsed_args = raw_args
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed_args = json.loads(raw_args)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"[{self.component_name}] In _submit_all_tool_calls function, fail: {e}, traceback: {format_exc()}, args: {raw_args}.")
|
||||
parsed_args = {}
|
||||
tool_call_obj = ToolCall(
|
||||
name=tool_call_dict.get("function", {}).get("name", ""),
|
||||
call_id=tool_call_dict.get("id", ""),
|
||||
arguments=parsed_args,
|
||||
id=tool_call_dict.get("id", ""),
|
||||
)
|
||||
tool_call_objects.append(tool_call_obj)
|
||||
return await self.tool_executor.container_sequential_tool_call(
|
||||
tool_call_objects
|
||||
)
|
||||
|
||||
def _process_submit_result_tool_result(
|
||||
self,
|
||||
submit_result: ToolResult,
|
||||
golden_patch: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""process submit_result tool call, fill golden_patch and log"""
|
||||
if not submit_result.success or not submit_result.result:
|
||||
self.logger.warning(f"[{self.component_name}] submit_result failed and no result.")
|
||||
return
|
||||
|
||||
try:
|
||||
submit_tool_result = SubmitToolResult.from_string(submit_result.result)
|
||||
|
||||
if submit_tool_result.output:
|
||||
patch_info = {
|
||||
"patch_content": submit_tool_result.output,
|
||||
"test_status": submit_tool_result.test_status,
|
||||
"reasoning": submit_tool_result.reasoning,
|
||||
}
|
||||
golden_patch.clear()
|
||||
golden_patch.append(patch_info)
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] patch len: {len(submit_tool_result.output)}."
|
||||
)
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] test status: {submit_tool_result.test_status}."
|
||||
)
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] reasoning: {submit_tool_result.reasoning[:100]}..."
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] submit_result success but no patch content."
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"[{self.component_name}] parse submit_result result fail: {e}, traceback: {format_exc()}.")
|
||||
|
||||
def _get_tools(self) -> List[Dict[str, Any]]:
|
||||
tools = []
|
||||
#use_openai_format = self._should_use_openai_format()
|
||||
use_openai_format = True
|
||||
|
||||
for tool in self.tool_executor.tools.values():
|
||||
if use_openai_format:
|
||||
tool_def = tool._definition_for_openai_fmt()
|
||||
else:
|
||||
tool_def = tool._definition_for_claude_fmt()
|
||||
tools.append(tool_def)
|
||||
|
||||
return tools
|
||||
|
||||
def _should_use_openai_format(self) -> bool:
|
||||
|
||||
if not self.llm_manager or not hasattr(self.llm_manager, "get_model_name"):
|
||||
return True # openAI format by default
|
||||
|
||||
model_name = self.llm_manager.get_model_name().lower()
|
||||
|
||||
return "claude" not in model_name
|
||||
|
||||
def _get_issue_prompt(self) -> str:
|
||||
"""generate issue prompt based on instance data"""
|
||||
if not self.prompts_manager:
|
||||
self.logger.warning("PromptsManager not initialized, cannot generate issue prompt.")
|
||||
return ""
|
||||
|
||||
#instance_id = self.instance_data.get("instance_id", "")
|
||||
#repo = self.instance_data.get("repo", "")
|
||||
created_at = self.instance_data.get("created_at", "")
|
||||
base_commit = self.instance_data.get("base_commit", "")
|
||||
environment_setup_commit = self.instance_data.get(
|
||||
"environment_setup_commit", ""
|
||||
)
|
||||
version = self.instance_data.get("version", "")
|
||||
problem_statement = self.instance_data.get("problem_statement", "")
|
||||
difficulty = self.instance_data.get("difficulty", "")
|
||||
|
||||
return self.prompts_manager.format_issue_prompt(
|
||||
created_at=created_at,
|
||||
base_commit=base_commit,
|
||||
environment_setup_commit=environment_setup_commit,
|
||||
version=version,
|
||||
problem_statement=problem_statement,
|
||||
difficulty=difficulty,
|
||||
)
|
||||
|
||||
async def _generate_patch(self) -> Dict[str, Any] | None:
|
||||
"""main loop logic for generating candidate patch"""
|
||||
|
||||
usage_stats = self._init_usage_stats()
|
||||
tool_stats = self._init_tools_stats()
|
||||
|
||||
if not self.llm_manager or not self.prompts_manager:
|
||||
self.logger.error(f"[{self.component_name}] LLM manager or prompts manager not initialized.")
|
||||
return {
|
||||
"success": False,
|
||||
"golden_patch": [],
|
||||
"llm_usage": usage_stats.to_dict(),
|
||||
"tool_stats": tool_stats.to_dict(),
|
||||
"total_turns": 0,
|
||||
}
|
||||
|
||||
tools = self._get_tools()
|
||||
|
||||
self._debug_tools(tools)
|
||||
|
||||
root_path = self.config.get("builder", {}).get("repo_root_path", "")
|
||||
max_turn = (
|
||||
self.config.get("runner", {}).get("generator_loop", {}).get("max_turn", 10)
|
||||
)
|
||||
temperature = (
|
||||
self.config.get("runner", {})
|
||||
.get("generator_loop", {})
|
||||
.get("temperature", 0.2)
|
||||
)
|
||||
|
||||
issue_prompt = self._get_issue_prompt()
|
||||
user_prompt = self.prompts_manager.get_generator_user(root_path, issue_prompt)
|
||||
system_prompt = self.prompts_manager.get_generator_system(root_path)
|
||||
|
||||
|
||||
total_turns = 0
|
||||
golden_patch = []
|
||||
|
||||
try:
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] {self.instance_id}: start generating candidate patch, max turn: {max_turn}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
|
||||
)
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
for turn in range(max_turn):
|
||||
total_turns = turn + 1
|
||||
self.logger.info(f"[{self.component_name}] The {total_turns}th turn started.")
|
||||
|
||||
try:
|
||||
current_input_msg = messages[-1] if messages else None
|
||||
if current_input_msg is not None:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] The {total_turns}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] {total_turns}th turn: LLM input fail: {messages[-1] if messages else None}, error: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
self._debug_messages(turn, messages)
|
||||
|
||||
response = self.llm_manager.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
first_content: str = ""
|
||||
first_tool_calls: Any = None
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
ch0 = response.choices[0]
|
||||
first_content = (
|
||||
getattr(getattr(ch0, "message", None), "content", None) or ""
|
||||
)
|
||||
first_tool_calls = getattr(
|
||||
getattr(ch0, "message", None), "tool_calls", None
|
||||
)
|
||||
|
||||
self._response_log(
|
||||
response, first_content, first_tool_calls, total_turns
|
||||
)
|
||||
|
||||
self._update_usage(response, usage_stats)
|
||||
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
content = first_content
|
||||
tool_calls = first_tool_calls
|
||||
|
||||
if not self._make_assistant(content, tool_calls, messages):
|
||||
continue
|
||||
|
||||
if tool_calls:
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] {total_turns}th turn: call {len(tool_calls)} tools."
|
||||
)
|
||||
|
||||
tool_results = await self._submit_all_tool_calls(tool_calls)
|
||||
|
||||
self._update_tool_call_statistic(tool_results, tool_stats)
|
||||
|
||||
if tool_results:
|
||||
submit_result = None
|
||||
other_tool_results = []
|
||||
|
||||
for tool_result in tool_results:
|
||||
tool_name = getattr(tool_result, "name", "")
|
||||
if tool_name == SUBMIT_RESULT_TOOL_NAME:
|
||||
submit_result = tool_result
|
||||
else:
|
||||
other_tool_results.append(tool_result)
|
||||
|
||||
if submit_result:
|
||||
self.logger.debug(
|
||||
f"[{self.component_name}] {total_turns}th turn: got submit_result tool call."
|
||||
)
|
||||
self.logger.debug(f"[{self.component_name}] {total_turns}th turn: submit_result result: {submit_result}")
|
||||
|
||||
self._process_submit_result_tool_result(
|
||||
submit_result, golden_patch
|
||||
)
|
||||
|
||||
self._debug_last_message(turn, messages)
|
||||
break
|
||||
|
||||
if other_tool_results:
|
||||
self._make_tool_response(other_tool_results, messages)
|
||||
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "请继续分析问题并使用工具来解决问题。",
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.debug(f"[{self.component_name}] final golden_patch: {golden_patch}")
|
||||
|
||||
success = (
|
||||
len(golden_patch) > 0 and golden_patch[0].get("patch_content", "") != ""
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] status={success}, total_turns={total_turns}, tools_stats={tool_stats}"
|
||||
)
|
||||
|
||||
result_payload = {
|
||||
"success": success,
|
||||
"golden_patch": golden_patch,
|
||||
"llm_usage": usage_stats.to_dict(),
|
||||
"tool_stats": tool_stats.to_dict(),
|
||||
"total_turns": total_turns,
|
||||
}
|
||||
try:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] final output: {json.dumps(result_payload, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] output: {str(result_payload)}, error: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
return result_payload
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"[{self.component_name}] fail: {e}, traceback: {format_exc()}.")
|
||||
result_payload = {
|
||||
"success": False,
|
||||
"golden_patch": [],
|
||||
"llm_usage": usage_stats.to_dict(),
|
||||
"tool_stats": tool_stats.to_dict(),
|
||||
"total_turns": total_turns,
|
||||
}
|
||||
try:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] 最终返回数据(失败): {json.dumps(result_payload, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] 最终返回数据(失败, 字符串回退): {str(result_payload)}, error: {e}, traceback: {format_exc()}."
|
||||
)
|
||||
return result_payload
|
||||
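The dictionary returned by _generate_patch has the shape sketched below. The keys mirror the code above; the values are illustrative, and the tool_stats keys assume the tool-name constants resolve to the names used in the prompts (bash, str_replace_based_edit_tool, search_tool, submit_result):

example_result_payload = {
    "success": True,
    "golden_patch": [
        {
            "patch_content": "diff --git a/pkg/mod.py b/pkg/mod.py\n...",
            "test_status": "reproduction script and targeted tests pass",
            "reasoning": "root cause was ...",
        }
    ],
    "llm_usage": {"prompt_tokens": 1200, "completion_tokens": 350, "total_tokens": 1550},
    "tool_stats": {
        "bash": {"count": 4, "failed": 0},
        "str_replace_based_edit_tool": {"count": 2, "failed": 1},
        "search_tool": {"count": 3, "failed": 0},
        "submit_result": {"count": 1, "failed": 0},
    },
    "total_turns": 7,
}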
338
src/managers/loop/patch_selector.py
Normal file
@@ -0,0 +1,338 @@
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
from traceback import format_exc
|
||||
from src.managers.log.logger import Logger
|
||||
from src.managers.llm_api.api_manager import LLMAPIManager
|
||||
from src.managers.prompts.prompts_manager import PromptsManager
|
||||
from src.managers.loop.types import GeneratorResult, SelectorResult, LLMUsage, ToolStats, PatchInfo
|
||||
from src.tools.base import ToolExecutor, ToolCall, ToolResult
|
||||
from src.managers.loop.base import BaseLoop
|
||||
|
||||
SELECTOR_SUBMIT_TOOL_NAME = "submit_result"
|
||||
|
||||
class PatchSelector(BaseLoop):
|
||||
def __init__(
|
||||
self,
|
||||
instance_id: str,
|
||||
instance_data: Dict[str, Any],
|
||||
logger: Logger,
|
||||
prompts_manager: PromptsManager | None,
|
||||
llm_manager: LLMAPIManager | None,
|
||||
tool_executor: ToolExecutor,
|
||||
config: Dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
|
||||
|
||||
def _get_submit_result_tool_name(self):
|
||||
return SELECTOR_SUBMIT_TOOL_NAME
|
||||
|
||||
def _definition_for_submit_tool(self, use_openai_format: bool) -> Dict[str, Any]:
|
||||
"""submit_result tool"""
|
||||
if use_openai_format:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self._get_submit_result_tool_name(),
|
||||
"description": "Submit the final selected patch index and reasoning.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index": {
|
||||
"type": "integer",
|
||||
"description": "The chosen patch index (0-based).",
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Detailed reasoning for the selection.",
|
||||
},
|
||||
},
|
||||
"required": ["index", "reason"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self._get_submit_result_tool_name(),
|
||||
"description": "Submit the final selected patch index and reasoning.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index": {
|
||||
"type": "integer",
|
||||
"description": "The chosen patch index (0-based).",
|
||||
},
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Detailed reasoning for the selection.",
|
||||
},
|
||||
},
|
||||
"required": ["index", "reason"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _build_user_prompt(self, candidates: List[GeneratorResult], root_path: str) -> str:
|
||||
|
||||
if not self.prompts_manager:
|
||||
return ""
|
||||
return self.prompts_manager.get_selector_user(self.instance_data, candidates, root_path)
|
||||
|
||||
def _get_system_prompt(self, patches_count: int, root_path: str) -> str:
|
||||
if not self.prompts_manager:
|
||||
return ""
|
||||
return self.prompts_manager.get_selector_system(patches_count, root_path)
|
||||
|
||||
def _get_tools(self) -> List[Dict[str, Any]]:
|
||||
|
||||
tool_defs: List[Dict[str, Any]] = []
|
||||
try:
|
||||
for tool in self.tool_executor.tools.values():
|
||||
try:
|
||||
tool_defs.append(tool._definition_for_openai_fmt())
|
||||
except Exception:
|
||||
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tool_defs.append(self._definition_for_submit_tool(True))
|
||||
return tool_defs
|
||||
|
||||
def _extract_submit_choice(self, tool_call: Dict[str, Any]) -> Dict[str, Any] | None:
|
||||
|
||||
if not tool_call:
|
||||
return None
|
||||
fn = tool_call.get("function", {})
|
||||
if fn.get("name") != self._get_submit_result_tool_name():
|
||||
return None
|
||||
raw_args = fn.get("arguments", {})
|
||||
try:
|
||||
args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
|
||||
except Exception:
|
||||
args = {}
|
||||
index = args.get("index")
|
||||
reason = args.get("reason")
|
||||
if isinstance(index, int) and index >= 0:
|
||||
return {"index": index, "reason": reason or ""}
|
||||
return None
|
||||
|
||||
async def _submit_other_tool_calls(
|
||||
self, tool_calls: List[Dict[str, Any]]
|
||||
) -> List[ToolResult]:
|
||||
|
||||
if not tool_calls:
|
||||
return []
|
||||
to_run: List[ToolCall] = []
|
||||
for tool_call_dict in tool_calls:
|
||||
fn = tool_call_dict.get("function", {})
|
||||
name = fn.get("name", "")
|
||||
if name == SELECTOR_SUBMIT_TOOL_NAME:
|
||||
continue
|
||||
raw_args = fn.get("arguments", {})
|
||||
parsed_args = raw_args
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
parsed_args = json.loads(raw_args)
|
||||
except Exception:
|
||||
parsed_args = {}
|
||||
to_run.append(
|
||||
ToolCall(
|
||||
name=name,
|
||||
call_id=tool_call_dict.get("id", ""),
|
||||
arguments=parsed_args,
|
||||
id=tool_call_dict.get("id", ""),
|
||||
)
|
||||
)
|
||||
if not to_run:
|
||||
return []
|
||||
results: List[ToolResult] = await self.tool_executor.container_sequential_tool_call(to_run)
|
||||
return results
|
||||
|
||||
async def _select_patch(self, candidates: List[GeneratorResult]) -> SelectorResult:
|
||||
|
||||
if not candidates:
|
||||
raise ValueError("No candidates provided")
|
||||
if not self.llm_manager:
|
||||
raise ValueError("LLM manager is not initialized")
|
||||
|
||||
|
||||
tools = self._get_tools()
|
||||
|
||||
self._debug_tools(tools)
|
||||
|
||||
root_path = self.config.get("builder", {}).get("repo_root_path", "")
|
||||
system_prompt = self._get_system_prompt(len(candidates), root_path)
|
||||
user_prompt = self._build_user_prompt(candidates, root_path)
|
||||
|
||||
messages: List[Dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
|
||||
try:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
|
||||
)
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] Initial fail in selector loop: SP={str(messages[0])}, UP={str(messages[1])}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
max_turn = int(
|
||||
self.config.get("runner", {})
|
||||
.get("selector_loop", {})
|
||||
.get("max_turn", 200)
|
||||
)
|
||||
temperature = (
|
||||
self.config.get("runner", {})
|
||||
.get("selector_loop", {})
|
||||
.get("temperature", 0.2)
|
||||
)
|
||||
|
||||
usage_stats = self._init_usage_stats()
|
||||
tool_stats = self._init_tools_stats()
|
||||
|
||||
total_turns = 0
|
||||
|
||||
chosen_index: int | None = None
|
||||
select_reason: str = ""
|
||||
for turn in range(max_turn):
|
||||
try:
|
||||
try:
|
||||
current_input_msg = messages[-1] if messages else None
|
||||
if current_input_msg is not None:
|
||||
self.logger.notice(
|
||||
f"[{self.component_name}] The {turn+1}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
|
||||
)
|
||||
except Exception:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] {turn+1}th turn fail: {messages[-1] if messages else None}, traceback: {format_exc()}."
|
||||
)
|
||||
|
||||
self._debug_messages(turn, messages)
|
||||
|
||||
response = self.llm_manager.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
first_tool_calls = None
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
ch0 = response.choices[0]
|
||||
first_tool_calls = getattr(getattr(ch0, "message", None), "tool_calls", None)
|
||||
first_content = getattr(getattr(ch0, "message", None), "content", None) or ""
|
||||
else:
|
||||
first_content = ""
|
||||
|
||||
total_turns = turn + 1
|
||||
|
||||
self._response_log(response, first_content, first_tool_calls, turn + 1)
|
||||
|
||||
self._update_usage(response, usage_stats)
|
||||
|
||||
|
||||
if first_tool_calls:
|
||||
|
||||
if not self._make_assistant(first_content, first_tool_calls, messages):
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
|
||||
}
|
||||
)
|
||||
continue
|
||||
submit_found = False
|
||||
for tc in first_tool_calls:
|
||||
choice = self._extract_submit_choice(tc)
|
||||
if choice is not None:
|
||||
chosen_index = choice["index"]
|
||||
reason = choice.get("reason", "")
|
||||
self.logger.info(
|
||||
f"[{self.component_name}] choose: index={chosen_index}, reason={reason}"
|
||||
)
|
||||
select_reason = reason or ""
|
||||
submit_found = True
|
||||
|
||||
self._debug_last_message(turn, messages)
|
||||
break
|
||||
|
||||
|
||||
if not submit_found:
|
||||
|
||||
results = await self._submit_other_tool_calls(first_tool_calls)
|
||||
|
||||
self._make_tool_response(results, messages)
|
||||
|
||||
self._update_tool_call_statistic(results, tool_stats)
|
||||
|
||||
else:
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
|
||||
}
|
||||
)
|
||||
|
||||
if chosen_index is not None:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(
|
||||
f"[{self.component_name}] fail: {e}, traceback: {format_exc()}"
|
||||
)
|
||||
break
|
||||
|
||||
if chosen_index is None:
|
||||
# If the model provides no choice, fallback: pick the first successful one; otherwise the first
|
||||
for i, r in enumerate(candidates):
|
||||
try:
|
||||
if r.success:
|
||||
chosen_index = i
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if chosen_index is None:
|
||||
chosen_index = 0
|
||||
|
||||
if not (0 <= chosen_index < len(candidates)):
|
||||
chosen_index = 0
|
||||
|
||||
selected = candidates[chosen_index]
|
||||
|
||||
try:
|
||||
gp = selected.golden_patch[0] if selected.golden_patch else None
|
||||
if gp is None:
|
||||
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
|
||||
else:
|
||||
patch_info = PatchInfo(
|
||||
patch_content=gp.patch_content,
|
||||
test_status=gp.test_status,
|
||||
reasoning=gp.reasoning,
|
||||
)
|
||||
except Exception:
|
||||
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
|
||||
|
||||
selector_result = SelectorResult(
|
||||
instance_id=selected.instance_id,
|
||||
generator_id=selected.generator_id,
|
||||
image=selected.image,
|
||||
success=True,
|
||||
golden_patch=patch_info,
|
||||
llm_usage=usage_stats,
|
||||
tool_stats=tool_stats,
|
||||
total_turns=total_turns,
|
||||
select_reason=select_reason,
|
||||
error=None,
|
||||
)
|
||||
return selector_result
|
||||
|
||||
|
||||
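For reference, _extract_submit_choice expects a tool-call dictionary like the sketch below (illustrative values; the arguments field may arrive either as a JSON string, as shown, or as an already-parsed dict):

submit_tool_call = {
    "id": "call_3",
    "type": "function",
    "function": {
        "name": "submit_result",
        "arguments": "{\"index\": 1, \"reason\": \"Candidate #1 fixes the root cause and keeps the existing tests green.\"}",
    },
}
# _extract_submit_choice(submit_tool_call) would return:
# {"index": 1, "reason": "Candidate #1 fixes the root cause and keeps the existing tests green."}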
254
src/managers/loop/types.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
This module defines the GeneratorResult data structure for patch generation results.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Any, Optional
|
||||
from src.tools.base import (
|
||||
BASH_TOOL_NAME,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMUsage:
|
||||
"""LLM usage statistics."""
|
||||
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, int]:
|
||||
"""Serialize LLMUsage to a plain dictionary."""
|
||||
return {
|
||||
"prompt_tokens": int(self.prompt_tokens),
|
||||
"completion_tokens": int(self.completion_tokens),
|
||||
"total_tokens": int(self.total_tokens),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolStats:
|
||||
"""Tool usage statistics per tool.
|
||||
|
||||
Each tool is represented by a small map with two fields:
|
||||
- count: total invocation count
|
||||
- failed: failed invocation count
|
||||
"""
|
||||
|
||||
bash: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
|
||||
edit: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
|
||||
search: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
|
||||
submit_result: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, int]]:
|
||||
"""Serialize ToolStats to a plain dictionary."""
|
||||
return {
|
||||
BASH_TOOL_NAME: {"count": int(self.bash.get("count", 0)), "failed": int(self.bash.get("failed", 0))},
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: {"count": int(self.edit.get("count", 0)), "failed": int(self.edit.get("failed", 0))},
|
||||
SEARCH_TOOL_NAME: {"count": int(self.search.get("count", 0)), "failed": int(self.search.get("failed", 0))},
|
||||
SUBMIT_RESULT_TOOL_NAME: {"count": int(self.submit_result.get("count", 0)), "failed": int(self.submit_result.get("failed", 0))},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchInfo:
|
||||
"""Information about a generated patch."""
|
||||
|
||||
patch_content: str
|
||||
test_status: str
|
||||
reasoning: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratorResult:
|
||||
"""Result from a patch generator."""
|
||||
|
||||
instance_id: str
|
||||
generator_id: int
|
||||
image: str
|
||||
success: bool
|
||||
golden_patch: List[PatchInfo]
|
||||
llm_usage: LLMUsage
|
||||
tool_stats: ToolStats
|
||||
total_turns: int
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GeneratorResult":
|
||||
"""Create GeneratorResult from dictionary."""
|
||||
# Handle golden_patch conversion
|
||||
golden_patch = []
|
||||
if data.get("golden_patch"):
|
||||
for patch_data in data["golden_patch"]:
|
||||
if isinstance(patch_data, dict):
|
||||
golden_patch.append(
|
||||
PatchInfo(
|
||||
patch_content=patch_data.get("patch_content", ""),
|
||||
test_status=patch_data.get("test_status", ""),
|
||||
reasoning=patch_data.get("reasoning", ""),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Legacy format: just patch content string
|
||||
golden_patch.append(
|
||||
PatchInfo(
|
||||
patch_content=str(patch_data), test_status="", reasoning=""
|
||||
)
|
||||
)
|
||||
|
||||
# Handle LLM usage
|
||||
llm_usage_data = data.get("llm_usage", {})
|
||||
llm_usage = LLMUsage(
|
||||
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=llm_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=llm_usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
# Handle tool stats
|
||||
tool_stats_data = data.get("tool_stats", {})
|
||||
tool_stats = ToolStats(
|
||||
bash=tool_stats_data.get(BASH_TOOL_NAME, {"count": 0, "failed": 0}),
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, {"count": 0, "failed": 0}),
search=tool_stats_data.get(SEARCH_TOOL_NAME, {"count": 0, "failed": 0}),
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, {"count": 0, "failed": 0}),
|
||||
)
|
||||
|
||||
return cls(
|
||||
instance_id=data.get("instance_id", ""),
|
||||
generator_id=data.get("generator_id", 0),
|
||||
image=data.get("image", ""),
|
||||
success=data.get("success", False),
|
||||
golden_patch=golden_patch,
|
||||
llm_usage=llm_usage,
|
||||
tool_stats=tool_stats,
|
||||
total_turns=data.get("total_turns", 0),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert GeneratorResult to dictionary."""
|
||||
return {
|
||||
"instance_id": self.instance_id,
|
||||
"generator_id": self.generator_id,
|
||||
"image": self.image,
|
||||
"success": self.success,
|
||||
"golden_patch": [
|
||||
{
|
||||
"patch_content": patch.patch_content,
|
||||
"test_status": patch.test_status,
|
||||
"reasoning": patch.reasoning,
|
||||
}
|
||||
for patch in self.golden_patch
|
||||
],
|
||||
"llm_usage": {
|
||||
"prompt_tokens": self.llm_usage.prompt_tokens,
|
||||
"completion_tokens": self.llm_usage.completion_tokens,
|
||||
"total_tokens": self.llm_usage.total_tokens,
|
||||
},
|
||||
"tool_stats": {
|
||||
BASH_TOOL_NAME: self.tool_stats.bash,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
|
||||
SEARCH_TOOL_NAME: self.tool_stats.search,
|
||||
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
|
||||
},
|
||||
"total_turns": self.total_turns,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SelectorResult:
|
||||
"""Result from a patch selector.
|
||||
"""
|
||||
|
||||
instance_id: str
|
||||
generator_id: int
|
||||
image: str
|
||||
success: bool
|
||||
golden_patch: PatchInfo
|
||||
llm_usage: LLMUsage
|
||||
tool_stats: ToolStats
|
||||
total_turns: int
|
||||
select_reason: str
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SelectorResult":
|
||||
"""Create SelectorResult from dictionary."""
|
||||
|
||||
gp_data = data.get("golden_patch", {})
|
||||
if isinstance(gp_data, dict):
|
||||
golden_patch = PatchInfo(
|
||||
patch_content=gp_data.get("patch_content", ""),
|
||||
test_status=gp_data.get("test_status", ""),
|
||||
reasoning=gp_data.get("reasoning", ""),
|
||||
)
|
||||
else:
|
||||
golden_patch = PatchInfo(
|
||||
patch_content=str(gp_data) if gp_data is not None else "",
|
||||
test_status="",
|
||||
reasoning="",
|
||||
)
|
||||
|
||||
# LLM usage
|
||||
llm_usage_data = data.get("llm_usage", {})
|
||||
llm_usage = LLMUsage(
|
||||
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
|
||||
completion_tokens=llm_usage_data.get("completion_tokens", 0),
|
||||
total_tokens=llm_usage_data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
# Tool stats
|
||||
tool_stats_data = data.get("tool_stats", {})
|
||||
tool_stats = ToolStats(
|
||||
bash=tool_stats_data.get(BASH_TOOL_NAME, {"count": 0, "failed": 0}),
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, {"count": 0, "failed": 0}),
search=tool_stats_data.get(SEARCH_TOOL_NAME, {"count": 0, "failed": 0}),
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, {"count": 0, "failed": 0}),
|
||||
)
|
||||
|
||||
return cls(
|
||||
instance_id=data.get("instance_id", ""),
|
||||
generator_id=data.get("generator_id", 0),
|
||||
image=data.get("image", ""),
|
||||
success=data.get("success", False),
|
||||
golden_patch=golden_patch,
|
||||
llm_usage=llm_usage,
|
||||
tool_stats=tool_stats,
|
||||
total_turns=data.get("total_turns", 0),
|
||||
select_reason=data.get("select_reason", ""),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert SelectorResult to dictionary."""
|
||||
return {
|
||||
"instance_id": self.instance_id,
|
||||
"generator_id": self.generator_id,
|
||||
"image": self.image,
|
||||
"success": self.success,
|
||||
"golden_patch": {
|
||||
"patch_content": self.golden_patch.patch_content,
|
||||
"test_status": self.golden_patch.test_status,
|
||||
"reasoning": self.golden_patch.reasoning,
|
||||
},
|
||||
"llm_usage": {
|
||||
"prompt_tokens": self.llm_usage.prompt_tokens,
|
||||
"completion_tokens": self.llm_usage.completion_tokens,
|
||||
"total_tokens": self.llm_usage.total_tokens,
|
||||
},
|
||||
"tool_stats": {
|
||||
BASH_TOOL_NAME: self.tool_stats.bash,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
|
||||
SEARCH_TOOL_NAME: self.tool_stats.search,
|
||||
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
|
||||
},
|
||||
"total_turns": self.total_turns,
|
||||
"select_reason": self.select_reason,
|
||||
"error": self.error,
|
||||
}
|
||||
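A small round-trip sketch for the dataclasses above (illustrative values only; the tool_stats key assumes BASH_TOOL_NAME resolves to "bash"):

from src.managers.loop.types import GeneratorResult

raw = {
    "instance_id": "example-instance-001",
    "generator_id": 0,
    "image": "example/image:latest",
    "success": True,
    "golden_patch": [
        {"patch_content": "diff --git ...", "test_status": "passed", "reasoning": "..."}
    ],
    "llm_usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
    "tool_stats": {"bash": {"count": 2, "failed": 0}},  # assumes BASH_TOOL_NAME == "bash"
    "total_turns": 3,
}

result = GeneratorResult.from_dict(raw)
assert result.golden_patch[0].test_status == "passed"
assert result.to_dict()["llm_usage"]["total_tokens"] == 15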
268
src/managers/prompts/prompts_manager.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from typing import Any, List, Dict
|
||||
|
||||
|
||||
class PromptsManager:
|
||||
def __init__(self, config):
|
||||
self.candidate_length = config.get("runner", {}).get("generator_concurrency", 5)
|
||||
|
||||
def get_generator_system(self, root_path: str | None = None):
|
||||
return f"""
|
||||
# You are a highly skilled expert in software engineering focused on resolving complex GitHub issues by effectively analyzing codebases, implementing fixes, and ensuring code reliability through rigorous testing.
|
||||
|
||||
## Skills
|
||||
1. Code Analysis and Debugging
|
||||
- Issue Exploration: Ability to explore and comprehend codebases within repositories.
|
||||
- Workflow Tracing: Skilled in using debugging techniques to trace issues through code.
|
||||
- Root Cause Identification: Proficient in pinpointing the underlying causes of software issues.
|
||||
- Test Creation: Expertise in establishing tests that replicate and validate issues.
|
||||
|
||||
2. Solution Implementation and Testing
|
||||
- Fix Implementation: Experience in crafting precise and minimal code patches.
|
||||
- Comprehensive Testing: Skilled in running and analyzing both existing and newly created tests.
|
||||
- Regression Prevention: Ensures all changes maintain overall code stability.
|
||||
- Continuous Improvement: Iterates on solutions based on test results to achieve optimal functionality.
|
||||
|
||||
## Task
|
||||
Your task is to resolve the given GitHub issue by understanding the repository and the issue, implementing a fix, and checking your changes against existing tests and your own test(s).
|
||||
|
||||
Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project.
|
||||
|
||||
For example, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/testbed`.
|
||||
|
||||
Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and I activated the virtual environment for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
|
||||
|
||||
Follow these steps:
|
||||
|
||||
1. Problem Analysis:
|
||||
- Read the issue description carefully to fully grasp the issue and explore the repository (source code, tests, examples) to understand expected behavior of relevant components.
|
||||
- Identify the full scope. Does the issue mention multiple components, backends, or functions? Your solution must address all of them.
|
||||
|
||||
2. Reproduce the issue (IMPORTANT):
|
||||
- Create a test that reproduces the issue as a baseline for verification.
|
||||
- Check that the output of your test matches your understanding of the issue in step 1.
|
||||
|
||||
3. Identify the root cause:
|
||||
- Go through relevant files, and create debugging scripts with print statements or use other methods if necessary, to trace the workflow and exact cause of the issue.
|
||||
- Trace the problem to its root cause. Do not just patch the symptom where the error appears. Trace the data and execution flow upstream to find where the problem originates.
|
||||
|
||||
4. Implement a Fix:
|
||||
- Once you have identified the root cause, develop a precise and targeted fix and then apply it as a minimal patch using the `str_replace_based_edit_tool` tool.
|
||||
|
||||
5. Test comprehensively:
|
||||
- Verify the Fix: Run your initial reproduction script to confirm that the bug is resolved.
|
||||
- Prevent Regressions:
|
||||
--Identify the right tests: Once you have verified your fix, identify the most relevant tests within the project's existing test suite that correspond to your code changes.
|
||||
--Run the tests: Then you **must** run these tests to ensure that your fix does not introduce any new bugs.
|
||||
--Analyze failures carefully:
|
||||
---If tests fail, do not immediately assume your fix is wrong. Critically analyze the failure.
|
||||
---Is it a **regression**? Did your change break existing, valid functionality? If so, you must refine your fix.
|
||||
---Is it an **unrelated failure**? It could be an environmental issue (e.g., missing dependency, network error) or a pre-existing flaky test. If you suspect this, try to run a more focused test and note the issue in your final reasoning.
|
||||
---Is the **test now obsolete**? If your fix improves behavior in a way that makes an old test's assertions incorrect, you should **update the test** to match the new, correct behavior and explain why in your reasoning.
|
||||
- Write New Tests: Create new, specific test cases (e.g., using `pytest`) that cover the original bug scenario.
|
||||
- Consider Edge Cases: Think about and test potential edge cases related to your changes.
|
||||
|
||||
6. Revisit steps 1 through 5 if unexpected behavior occurs, then call `submit_result` to submit the reliable and verified solution patch after successful testing and validation.
|
||||
|
||||
**Mandatory Workflow** As a senior engineer, ensure solution correctness and safety. Upon successful verification, immediately conclude the task by calling `submit_result`.
|
||||
"""
|
||||
|
||||
def format_issue_prompt(
|
||||
self,
|
||||
created_at: str,
|
||||
base_commit: str,
|
||||
environment_setup_commit: str,
|
||||
version: str,
|
||||
problem_statement: str,
|
||||
difficulty: str,
|
||||
) -> str:
|
||||
|
||||
template = f"""
|
||||
[📝 Issue Description]
|
||||
|
||||
**Created at**: {created_at}
|
||||
**Base commit**: {base_commit}
|
||||
---
|
||||
|
||||
### 📌 Problem Statement
|
||||
{problem_statement}
|
||||
---
|
||||
|
||||
### ⚙️ Difficulty Level
|
||||
{difficulty}
|
||||
---
|
||||
"""
|
||||
return template.strip()
|
||||
|
||||
def get_generator_user(self, root_path: str, issue_text: str):
|
||||
return (
|
||||
f"""
|
||||
[Project root path]:
|
||||
{root_path}
|
||||
[Issue Information]:
|
||||
{issue_text}
|
||||
"""
|
||||
+ self.get_generator_notice()
|
||||
)
|
||||
|
||||
def get_generator_notice(self):
|
||||
return """
|
||||
[notice]
|
||||
1. Use the available tools to locate the root cause.
|
||||
2. Prioritize using the `search_tool` to retrieve and locate the precise location of key information in the project.
|
||||
3. Collect supporting evidence: stack traces, logs, configs, recent changes, related modules.
|
||||
"""
|
||||
|
||||
def get_selector_system(self, patches_count: int, root_path: str):
|
||||
return f"""
|
||||
# ROLE:
|
||||
|
||||
*You are a highly proficient software engineer tasked with evaluating and selecting optimal code patches to resolve specific issues within a given project.
|
||||
|
||||
*Your colleagues worked on {patches_count} potential patches for a GitHub issue. Select the ONE correct patch that solves the issue.
|
||||
|
||||
*Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and the virtual environment has been activated for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
|
||||
|
||||
*Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project. For instance, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/testbed`.
|
||||
|
||||
# WORKFLOWS:
|
||||
*Follow these steps without skipping any of them:
|
||||
1.Problem Analysis:
|
||||
- Read the issue description and the current code that needs to be fixed. Explore the repository (source code, tests, examples) to understand expected behavior of relevant components, and gather comprehensive information about the problem area
|
||||
|
||||
2.Conduct a thorough review of each patch:
|
||||
- Scrutinize all code modifications.
|
||||
- Decipher the core logic and problem-solving methodology.
|
||||
- Evaluate potential edge cases and unintended consequences.
|
||||
- Validate that each patch fully addresses the initial issue specifications.
|
||||
|
||||
3.Verify Your Analysis
|
||||
- Use the available tools to verify that your analysis of this issue is correct.
|
||||
- Test your conclusions against relevant code sections.
|
||||
- Ensure full contextual understanding.
|
||||
|
||||
4.Proceed with Your Decision
|
||||
- Upon completion of the preceding three steps, utilize the `submit_result` tool with your detailed reasoning.
|
||||
|
||||
#RULES:
|
||||
1.It is MANDATORY to utilize the available tools prior to finalizing any selection:
|
||||
-- Start with `bash` to explore the codebase structure;
|
||||
-- Employ the str_replace_based_edit_tool to inspect the current code;
|
||||
-- Use `search_tool` to search related code and file;
|
||||
2.You MUST first explore the codebase before using the `submit_result` tool.
|
||||
3.Substantiate your reasoning with evidence from your analysis.
|
||||
4.Only selections made after employing the tools will be accepted.
|
||||
|
||||
#FINAL DECISION:
|
||||
Upon completion of your tool-based analysis, finalize the process by submitting your choice via the `submit_result` tool.
|
||||
|
||||
#NOTICE:
|
||||
1. Tool usage is MANDATORY - do not skip this step.
|
||||
2. Failing to make a decision after completing your analysis is not permitted.
3. Never generate new patches on your own; only make the selection.
4. Always provide detailed reasoning for the selection based on your tool-based investigation.
|
||||
"""
|
||||
|
||||
def get_selector_user(
|
||||
self, instance_data: Dict[str, Any] | None = None, candidates: List[Any] | None = None, root_path: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate user prompt of selector, including issue information and the first golden patch of each candidate.
|
||||
|
||||
- instance_data: Current instance metadata (issue description etc.)
|
||||
- candidates: Candidate list (each supporting .to_dict()); only the fields of golden_patch[0] are used
|
||||
"""
|
||||
if not instance_data or not candidates:
|
||||
return ""
|
||||
|
||||
created_at = instance_data.get("created_at", "")
|
||||
base_commit = instance_data.get("base_commit", "")
|
||||
environment_setup_commit = instance_data.get("environment_setup_commit", "")
|
||||
version = instance_data.get("version", "")
|
||||
problem_statement = instance_data.get("problem_statement", "")
|
||||
difficulty = instance_data.get("difficulty", "")
|
||||
|
||||
issue_block = self.format_issue_prompt(
|
||||
created_at=created_at,
|
||||
base_commit=base_commit,
|
||||
environment_setup_commit=environment_setup_commit,
|
||||
version=version,
|
||||
problem_statement=problem_statement,
|
||||
difficulty=difficulty,
|
||||
)
|
||||
|
||||
root_path_block = f"""
|
||||
[Project root path]:
|
||||
{root_path or "/testbed"}
|
||||
"""
|
||||
|
||||
parts: List[str] = [root_path_block, issue_block, "\n[🔎 Candidates]\n"]
|
||||
for idx, r in enumerate(candidates):
|
||||
try:
|
||||
data = r.to_dict() if hasattr(r, "to_dict") else {}
|
||||
except Exception:
|
||||
data = {}
|
||||
golden_patch = data.get("golden_patch", [])
|
||||
patch_content = golden_patch[0].get("patch_content", "") if golden_patch else ""
|
||||
test_status = golden_patch[0].get("test_status", "") if golden_patch else ""
|
||||
reasoning = golden_patch[0].get("reasoning", "") if golden_patch else ""
|
||||
parts.append(self.format_selector_candidate(idx, patch_content, test_status, reasoning))
|
||||
|
||||
parts.append(
|
||||
"\nPlease analyze the candidates, then call the submit_result tool with the final index and reasoning."
|
||||
)
|
||||
return "\n".join(parts)
|
||||
|
||||
def get_terminal_response(self, exit_code: int, output: str, timeout_status: bool):
|
||||
if timeout_status:
|
||||
return """[Terminal response]
|
||||
Terminal timed out."""
|
||||
else:
|
||||
return f"""[Terminal response]
|
||||
Exit code: {exit_code}
|
||||
Output: {output}"""
|
||||
|
||||
def tool_response_prompts(self, tool_results: list) -> str:
|
||||
if not tool_results:
|
||||
return ""
|
||||
|
||||
response_parts = ["[tool_response]"]
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
output = result.get("result", "")
|
||||
error = result.get("error", "")
|
||||
|
||||
response_parts.append(f"Tool {i}: {tool_name}")
|
||||
response_parts.append(f"Success: {success}")
|
||||
|
||||
if success and output:
|
||||
response_parts.append(f"Output:\n{output}")
|
||||
elif error:
|
||||
response_parts.append(f"Error: {error}")
|
||||
else:
|
||||
response_parts.append("No output")
|
||||
|
||||
response_parts.append("")
|
||||
|
||||
return "\n".join(response_parts)
|
||||
|
||||
def format_selector_candidate(self, index: int, patch_content: str, test_status: str, reasoning: str) -> str:
|
||||
"""
|
||||
Generate the description of a selector candidate item, including key information from its first golden_patch
|
||||
|
||||
- index: Candidate index (0-based)
|
||||
- patch_content: golden_patch[0].patch_content
|
||||
- test_status: golden_patch[0].test_status, the test status from the generation stage
|
||||
- reasoning: golden_patch[0].reasoning, the model's reasoning from the generation stage
|
||||
"""
|
||||
header = f"- Candidate #{index}:"
|
||||
patch_block = patch_content or ""
|
||||
status_block = test_status or ""
|
||||
reasoning_block = reasoning or ""
|
||||
return (
|
||||
f"--{header}\n"
|
||||
f"--Patch content (the proposed fix):\n{patch_block}\n\n"
|
||||
f"--Test status during generation: {status_block}\n\n"
|
||||
f"--Reasoning during generation (model's logic):\n{reasoning_block}"
|
||||
)
|
||||
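For orientation, a hedged sketch of how the selector user prompt above might be driven. The `FakeCandidate` shape and the commented-out `prompt_manager` object are assumptions standing in for whatever class defines `get_selector_user`; they are not part of the repository's documented API.

# Hypothetical usage sketch for the selector prompt helpers defined above.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeCandidate:
    """Toy candidate mimicking an object with .to_dict() and a golden_patch list."""
    golden_patch: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        return {"golden_patch": self.golden_patch}


instance_data = {
    "created_at": "2025-01-01",
    "base_commit": "abc123",
    "problem_statement": "TypeError raised when calling foo() with no arguments",
}
candidates = [
    FakeCandidate(golden_patch=[{
        "patch_content": "--- a/foo.py\n+++ b/foo.py\n@@ ...",
        "test_status": "passed",
        "reasoning": "Adds a default argument to foo().",
    }]),
]

# Assumed: `prompt_manager` is an instance of the class that defines the methods above.
# text = prompt_manager.get_selector_user(
#     instance_data=instance_data, candidates=candidates, root_path="/testbed"
# )
# print(text)  # root-path block, issue block, then one "- Candidate #0:" section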
103
src/managers/result_builder/result_builder.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
import json
|
||||
|
||||
|
||||
class ResultBuilder:
|
||||
"""
|
||||
Build preds.json
|
||||
|
||||
Iterate through JSON files named by instance_id in the directory specified by `runner.selector_result_dump_path` in the configuration.
|
||||
- Parse the `golden_patch` field from each JSON file and extract the patch text as `model_patch`.
|
||||
- Read the first top-level field from providers and the first model under it, and concatenate them to form `model_name_or_path`.
|
||||
- The output location is `{workspace.path}/{result.preds.path}/preds.json`.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
self.config = config or {}
|
||||
|
||||
def _get_selector_dump_dir(self) -> Path:
|
||||
runner_cfg = self.config.get("runner", {}) if isinstance(self.config, dict) else {}
|
||||
dump_dir_str = runner_cfg.get(
|
||||
"selector_result_dump_path", "workspace/selector_result_dump"
|
||||
)
|
||||
return Path(dump_dir_str)
|
||||
|
||||
def _get_preds_output_dir(self) -> Path:
|
||||
workspace_cfg = self.config.get("workspace", {}) if isinstance(self.config, dict) else {}
|
||||
result_cfg = self.config.get("result", {}) if isinstance(self.config, dict) else {}
|
||||
preds_cfg = result_cfg.get("preds", {}) if isinstance(result_cfg, dict) else {}
|
||||
|
||||
workspace_path = workspace_cfg.get("path", "workspace")
|
||||
preds_path = preds_cfg.get("path", "result")
|
||||
return Path(workspace_path) / preds_path
|
||||
|
||||
def _get_model_name_or_path(self) -> str:
|
||||
providers = self.config.get("providers", {}) if isinstance(self.config, dict) else {}
|
||||
if not isinstance(providers, dict) or not providers:
|
||||
return ""
|
||||
first_provider_name = next(iter(providers.keys()))
|
||||
first_models = providers.get(first_provider_name, [])
|
||||
if isinstance(first_models, list) and first_models:
|
||||
first_model = first_models[0]
|
||||
else:
|
||||
first_model = ""
|
||||
return f"{first_provider_name}/{first_model}" if first_provider_name and first_model else ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_patch(golden_patch: Any) -> str:
|
||||
"""
|
||||
Extract patch content from golden_patch
|
||||
|
||||
Forms supported:
|
||||
- dict: prioritize extract 'patch_content', then attempt `model_patch`
|
||||
- string: Directly return
|
||||
- other: return empty string
|
||||
"""
|
||||
if isinstance(golden_patch, dict):
|
||||
if "patch_content" in golden_patch and isinstance(golden_patch["patch_content"], str):
|
||||
return golden_patch["patch_content"]
|
||||
if "model_patch" in golden_patch and isinstance(golden_patch["model_patch"], str):
|
||||
return golden_patch["model_patch"]
|
||||
return ""
|
||||
if isinstance(golden_patch, str):
|
||||
return golden_patch
|
||||
return ""
|
||||
|
||||
def build_preds(self) -> Path:
|
||||
dump_dir = self._get_selector_dump_dir()
|
||||
output_dir = self._get_preds_output_dir()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_file = output_dir / "preds.json"
|
||||
|
||||
model_name_or_path = self._get_model_name_or_path()
|
||||
|
||||
# SWE-bench evaluation expects: list[dict], each element includes instance_id / model_patch / model
|
||||
predictions: list[dict[str, str]] = []
|
||||
|
||||
if dump_dir.exists() and dump_dir.is_dir():
|
||||
for path in sorted(dump_dir.glob("*.json")):
|
||||
try:
|
||||
instance_id = path.stem
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
golden_patch = data.get("golden_patch", {}) if isinstance(data, dict) else {}
|
||||
model_patch = self._extract_model_patch(golden_patch)
|
||||
predictions.append(
|
||||
{
|
||||
"instance_id": instance_id,
|
||||
"model_patch": model_patch,
|
||||
"model_name_or_path": model_name_or_path,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(predictions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return output_file
|
||||
|
||||
|
||||
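As a usage illustration for the builder above, a minimal sketch. The config keys mirror the defaults read by the code (`runner.selector_result_dump_path`, `workspace.path`, `result.preds.path`, `providers`), while the provider and model names are placeholders, not values taken from the repository.

# Hypothetical usage of ResultBuilder; paths and provider names are examples only.
from src.managers.result_builder.result_builder import ResultBuilder

config = {
    "runner": {"selector_result_dump_path": "workspace/selector_result_dump"},
    "workspace": {"path": "workspace"},
    "result": {"preds": {"path": "result"}},
    "providers": {"example_provider": ["example-model"]},  # placeholder provider/model
}

builder = ResultBuilder(config)
preds_file = builder.build_preds()  # writes workspace/result/preds.json
print(preds_file)
# preds.json holds a list of {"instance_id", "model_patch", "model_name_or_path"} entries,
# with model_name_or_path rendered here as "example_provider/example-model".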
35
src/tools/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Tools module for Code Agent."""
|
||||
|
||||
from src.tools.base import (
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolExecutor,
|
||||
ToolResult,
|
||||
BASH_TOOL_NAME,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
from src.tools.bash_tool import BashTool
|
||||
from src.tools.edit_tool import TextEditorTool
|
||||
from src.tools.search_tool import SearchTool
|
||||
from src.tools.submit_result_tool import SubmitResultTool
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"ToolResult",
|
||||
"ToolCall",
|
||||
"ToolExecutor",
|
||||
"BashTool",
|
||||
"TextEditorTool",
|
||||
"JSONEditTool",
|
||||
"SearchTool",
|
||||
"SubmitResultTool",
|
||||
]
|
||||
|
||||
tools_registry: dict[str, type[Tool]] = {
|
||||
BASH_TOOL_NAME: BashTool,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME: TextEditorTool,
|
||||
SEARCH_TOOL_NAME: SearchTool,
|
||||
SUBMIT_RESULT_TOOL_NAME: SubmitResultTool,
|
||||
}
|
||||
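A small sketch of how tools_registry might be used to build a tool set by name. It assumes each registered tool can be constructed with the base Tool keyword arguments, which is shown for BashTool later in this dump but not verified here for the other tools.

# Hypothetical: instantiate tools by name from the registry defined above.
from src.tools import ToolExecutor, tools_registry

tool_names = ["bash", "search_tool", "submit_result"]
# Assumption: every registered tool accepts a no-argument constructor like Tool/BashTool.
tools = [tools_registry[name]() for name in tool_names]

executor = ToolExecutor(tools)
# schemas = [tool.json_definition() for tool in tools]  # e.g. to hand to an LLM client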
523
src/tools/base.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
||||
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
|
||||
#
|
||||
# Original file was released under MIT License, with the full license text
|
||||
# available at https://github.com/bytedance/trae-agent/blob/main/LICENSE
|
||||
#
|
||||
# This modified file is released under the same license.
|
||||
|
||||
|
||||
"""Base classes for tools and tool calling."""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import override
|
||||
from src.managers.log.logger import Logger
|
||||
from typing import Dict, Any
|
||||
from traceback import format_exc
|
||||
from pathlib import Path
|
||||
|
||||
ParamSchemaValue = str | list[str] | bool | dict[str, object]
|
||||
Property = dict[str, ParamSchemaValue]
|
||||
|
||||
BASH_TOOL_NAME = "bash"
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME = "str_replace_based_edit_tool"
|
||||
SEARCH_TOOL_NAME = "search_tool"
|
||||
SUBMIT_RESULT_TOOL_NAME = "submit_result"
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""Base class for tool errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
self.message: str = message
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExecResult:
|
||||
"""Intermediate result of a tool execution."""
|
||||
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
error_code: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result of a tool execution."""
|
||||
|
||||
call_id: str
|
||||
name: str # Gemini specific field
|
||||
success: bool
|
||||
result: str | None = None
|
||||
error: str | None = None
|
||||
id: str | None = None # OpenAI-specific field
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubmitToolResult:
|
||||
"""Structured result for submit_result tool."""
|
||||
|
||||
return_code: int
|
||||
output: str
|
||||
is_task_done: bool
|
||||
test_status: str
|
||||
reasoning: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Convert to JSON string for output."""
|
||||
import json
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"return_code": self.return_code,
|
||||
"output": self.output,
|
||||
"is_task_done": self.is_task_done,
|
||||
"test_status": self.test_status,
|
||||
"reasoning": self.reasoning,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, json_str: str) -> "SubmitToolResult":
|
||||
"""Create SubmitToolResult from JSON string."""
|
||||
import json
|
||||
|
||||
data = json.loads(json_str)
|
||||
return cls(
|
||||
return_code=data.get("return_code", 0),
|
||||
output=data.get("output", ""),
|
||||
is_task_done=data.get("is_task_done", False),
|
||||
test_status=data.get("test_status", "error"),
|
||||
reasoning=data.get("reasoning", ""),
|
||||
)
|
||||
|
||||
|
||||
ToolCallArguments = dict[
|
||||
str, str | int | float | dict[str, object] | list[object] | None
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""Represents a parsed tool call."""
|
||||
|
||||
name: str
|
||||
call_id: str
|
||||
arguments: ToolCallArguments = field(default_factory=dict)
|
||||
id: str | None = None
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return f"ToolCall(name={self.name}, arguments={self.arguments}, call_id={self.call_id}, id={self.id})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolParameter:
|
||||
"""Tool parameter definition."""
|
||||
|
||||
name: str
|
||||
type: str | list[str]
|
||||
description: str
|
||||
enum: list[str] | None = None
|
||||
items: dict[str, object] | None = None
|
||||
required: bool = True
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""Base class for all tools."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: str | None = None,
|
||||
logger: Logger | None = None,
|
||||
config: Dict[str, Any] | None = None,
|
||||
):
|
||||
self._model_provider = model_provider
|
||||
self.logger = logger
|
||||
self.config = config
|
||||
|
||||
@cached_property
|
||||
def model_provider(self) -> str | None:
|
||||
return self.get_model_provider()
|
||||
|
||||
@cached_property
|
||||
def name(self) -> str:
|
||||
return self.get_name()
|
||||
|
||||
@cached_property
|
||||
def description(self) -> str:
|
||||
return self.get_description()
|
||||
|
||||
@cached_property
|
||||
def parameters(self) -> list[ToolParameter]:
|
||||
return self.get_parameters()
|
||||
|
||||
def get_model_provider(self) -> str | None:
|
||||
"""Get the model provider."""
|
||||
return self._model_provider
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""Get the tool name."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str:
|
||||
"""Get the tool description."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_parameters(self) -> list[ToolParameter]:
|
||||
"""Get the tool parameters."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
|
||||
"""Execute the tool with given parameters."""
|
||||
pass
|
||||
|
||||
# Optional container execution hooks (to be overridden by tools that support containers)
|
||||
async def container_execute(self, arguments: ToolCallArguments) -> ToolExecResult:
|
||||
"""Execute the tool inside a container shell (optional)."""
|
||||
raise ToolError(
|
||||
f"Tool '{self.get_name()}' does not support container execution"
|
||||
)
|
||||
|
||||
def container_search(
|
||||
self, arguments: ToolCallArguments, session_id: str = "0"
|
||||
) -> ToolExecResult:
|
||||
"""Execute a search-like operation inside container (optional)."""
|
||||
raise ToolError(f"Tool '{self.get_name()}' does not support container search")
|
||||
|
||||
# Optional container file editing hooks used by edit tools
|
||||
def container_read_file(self, path) -> str:
|
||||
"""Read a file inside container (optional)."""
|
||||
raise ToolError(
|
||||
f"Tool '{self.get_name()}' does not support container_read_file"
|
||||
)
|
||||
|
||||
def container_write_file(self, path, content: str) -> None:
|
||||
"""Write a file inside container (optional)."""
|
||||
raise ToolError(
|
||||
f"Tool '{self.get_name()}' does not support container_write_file"
|
||||
)
|
||||
|
||||
def container_str_replace(
|
||||
self, path, old_str: str, new_str: str | None
|
||||
) -> ToolExecResult:
|
||||
"""String replace inside a file in container (optional)."""
|
||||
raise ToolError(
|
||||
f"Tool '{self.get_name()}' does not support container_str_replace"
|
||||
)
|
||||
|
||||
def container_insert(self, path, insert_line: int, new_str: str) -> ToolExecResult:
|
||||
"""Insert text into a file in container (optional)."""
|
||||
raise ToolError(f"Tool '{self.get_name()}' does not support container_insert")
|
||||
|
||||
def view_handler_container(
|
||||
self, arguments: ToolCallArguments, path: Path
|
||||
) -> ToolExecResult:
|
||||
"""View handler in container (optional)."""
|
||||
raise ToolError(f"Tool '{self.get_name()}' does not support view_handler_container")
|
||||
|
||||
|
||||
def json_definition(self) -> dict[str, object]:
|
||||
"""Default return Claude format (backward compatibility)"""
|
||||
return self._definition_for_claude_fmt()
|
||||
|
||||
def _definition_for_claude_fmt(self) -> dict[str, object]:
|
||||
"""Return Claude format tool definition (Anthropic Messages API)"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"input_schema": self.get_input_schema(),
|
||||
}
|
||||
|
||||
def _definition_for_openai_fmt(self) -> dict[str, object]:
|
||||
"""Return OpenAI format tool definition"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.get_input_schema(),
|
||||
},
|
||||
}
|
||||
|
||||
def get_input_schema(self) -> dict[str, object]:
|
||||
"""Get the input schema for the tool."""
|
||||
schema: dict[str, object] = {
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
properties: dict[str, Property] = {}
|
||||
required: list[str] = []
|
||||
|
||||
for param in self.parameters:
|
||||
param_schema: Property = {
|
||||
"type": param.type,
|
||||
"description": param.description,
|
||||
}
|
||||
|
||||
# For OpenAI strict mode, all params must be in 'required'.
|
||||
# Optional params are made "nullable" to be compliant.
|
||||
if self.model_provider == "openai":
|
||||
required.append(param.name)
|
||||
if not param.required:
|
||||
current_type = param_schema["type"]
|
||||
if isinstance(current_type, str):
|
||||
param_schema["type"] = [current_type, "null"]
|
||||
elif isinstance(current_type, list) and "null" not in current_type:
|
||||
param_schema["type"] = list(current_type) + ["null"]
|
||||
elif param.required:
|
||||
required.append(param.name)
|
||||
|
||||
if param.enum:
|
||||
param_schema["enum"] = param.enum
|
||||
|
||||
if param.items:
|
||||
param_schema["items"] = param.items
|
||||
|
||||
# For OpenAI, nested objects also need additionalProperties: false
|
||||
if self.model_provider == "openai" and param.type == "object":
|
||||
param_schema["additionalProperties"] = False
|
||||
|
||||
properties[param.name] = param_schema
|
||||
|
||||
schema["properties"] = properties
|
||||
if len(required) > 0:
|
||||
schema["required"] = required
|
||||
|
||||
# For OpenAI, the top-level schema needs additionalProperties: false
|
||||
if self.model_provider == "openai":
|
||||
schema["additionalProperties"] = False
|
||||
|
||||
return schema
|
||||
|
||||
async def close(self):
|
||||
"""Ensure proper tool resource deallocation before task completion."""
|
||||
return None # Using "pass" will trigger a Ruff check error: B027
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""Tool executor that manages tool execution."""
|
||||
|
||||
def __init__(self, tools: list[Tool], logger: Logger | None = None):
|
||||
self._tools = tools
|
||||
self._tool_map: dict[str, Tool] | None = None
|
||||
self.logger = logger
|
||||
|
||||
async def close_tools(self):
|
||||
"""Ensure all tool resources are properly released."""
|
||||
tasks = [tool.close() for tool in self._tools if hasattr(tool, "close")]
|
||||
res = await asyncio.gather(*tasks)
|
||||
return res
|
||||
|
||||
def _normalize_name(self, name: str) -> str:
|
||||
"""Normalize tool name by making it lowercase and removing underscores."""
|
||||
return name.lower().replace("_", "")
|
||||
|
||||
@property
|
||||
def tools(self) -> dict[str, Tool]:
|
||||
if self._tool_map is None:
|
||||
self._tool_map = {
|
||||
self._normalize_name(tool.name): tool for tool in self._tools
|
||||
}
|
||||
return self._tool_map
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
|
||||
"""Execute a tool call locally."""
|
||||
normalized_name = self._normalize_name(tool_call.name)
|
||||
if normalized_name not in self.tools:
|
||||
return ToolResult(
|
||||
name=tool_call.name,
|
||||
success=False,
|
||||
error=f"Tool '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}",
|
||||
call_id=tool_call.call_id,
|
||||
id=tool_call.id,
|
||||
)
|
||||
|
||||
tool = self.tools[normalized_name]
|
||||
|
||||
try:
|
||||
tool_exec_result = await tool.execute(tool_call.arguments)
|
||||
return ToolResult(
|
||||
name=tool_call.name,
|
||||
success=tool_exec_result.error_code == 0,
|
||||
result=tool_exec_result.output,
|
||||
error=tool_exec_result.error,
|
||||
call_id=tool_call.call_id,
|
||||
id=tool_call.id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
name=tool_call.name,
|
||||
success=False,
|
||||
error=f"Error executing tool '{tool_call.name}': {str(e)}, traceback: {format_exc()}.",
|
||||
call_id=tool_call.call_id,
|
||||
id=tool_call.id,
|
||||
)
|
||||
|
||||
async def container_execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
|
||||
"""Execute a tool call in container."""
|
||||
normalized_name = self._normalize_name(tool_call.name)
|
||||
if normalized_name not in self.tools:
|
||||
self.logger.warning(
|
||||
f"[ToolExecutor] '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}"
|
||||
)
|
||||
return ToolResult(
|
||||
name=tool_call.name,
|
||||
success=False,
|
||||
error=f"Tool '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}",
|
||||
call_id=tool_call.call_id,
|
||||
id=tool_call.id,
|
||||
)
|
||||
|
||||
tool = self.tools[normalized_name]
|
||||
|
||||
try:
|
||||
tool_exec_result = await self._container_execute_tool_by_name(
|
||||
tool, tool_call
|
||||
)
|
||||
return ToolResult(
|
||||
name=tool_call.name,
|
||||
success=tool_exec_result.error_code == 0,
|
||||
result=tool_exec_result.output,
|
||||
error=tool_exec_result.error,
|
||||
call_id=tool_call.call_id,
|
||||
id=tool_call.id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
name=tool_call.name,
|
||||
success=False,
|
||||
error=f"Error executing tool '{tool_call.name}': {str(e)}, traceback: {format_exc()}.",
|
||||
call_id=tool_call.call_id,
|
||||
id=tool_call.id,
|
||||
)
|
||||
|
||||
async def _container_execute_tool_by_name(
|
||||
self, tool: Tool, tool_call: ToolCall
|
||||
) -> ToolExecResult:
|
||||
tool_name = tool.get_name()
|
||||
|
||||
if tool_name == BASH_TOOL_NAME:
|
||||
# BashTool: execute through container
|
||||
if hasattr(tool, "container_execute"):
|
||||
return await tool.container_execute(tool_call.arguments)
|
||||
else:
|
||||
raise ToolError(
|
||||
f"Tool '{tool_name}' does not support container execution"
|
||||
)
|
||||
|
||||
elif tool_name == STR_REPLACE_BASED_EDIT_TOOL_NAME:
|
||||
# TextEditorTool: execute through container
|
||||
if hasattr(tool, "container_read_file"):
|
||||
return await self._execute_edit_tool_in_container(
|
||||
tool, tool_call.arguments
|
||||
)
|
||||
else:
|
||||
raise ToolError(
|
||||
f"Tool '{tool_name}' does not support container execution"
|
||||
)
|
||||
|
||||
elif tool_name == SEARCH_TOOL_NAME:
|
||||
# SearchTool: execute through container
|
||||
if hasattr(tool, "container_search"):
|
||||
return tool.container_search(tool_call.arguments)
|
||||
else:
|
||||
raise ToolError(
|
||||
f"Tool '{tool_name}' does not support container execution"
|
||||
)
|
||||
elif tool_name == SUBMIT_RESULT_TOOL_NAME:
|
||||
# SubmitResultTool: execute through container
|
||||
if hasattr(tool, "container_execute"):
|
||||
return await tool.container_execute(tool_call.arguments)
|
||||
else:
|
||||
raise ToolError(
|
||||
f"Tool '{tool_name}' does not support container execution"
|
||||
)
|
||||
else:
|
||||
# Other tools: container execution not supported
|
||||
raise ToolError(f"Tool '{tool_name}' does not support container execution")
|
||||
|
||||
async def _execute_edit_tool_in_container(
|
||||
self, tool: Tool, arguments: ToolCallArguments
|
||||
) -> ToolExecResult:
|
||||
command = str(arguments.get("command", ""))
|
||||
path_str = str(arguments.get("path", ""))
|
||||
|
||||
if not path_str:
|
||||
return ToolExecResult(
|
||||
error="No path provided for the edit tool", error_code=-1
|
||||
)
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
path = Path(path_str)
|
||||
|
||||
try:
|
||||
if command == "view":
|
||||
return tool.view_handler_container(arguments, path)
|
||||
#return ToolExecResult(output=tool._make_output(content, str(path)))
|
||||
|
||||
elif command == "create":
|
||||
file_text = str(arguments.get("file_text", ""))
|
||||
tool.container_write_file(path, file_text)
|
||||
return ToolExecResult(output=f"File created successfully at: {path}")
|
||||
|
||||
elif command == "str_replace":
|
||||
old_str = str(arguments.get("old_str", ""))
|
||||
new_str = arguments.get("new_str")
|
||||
if new_str is not None:
|
||||
new_str = str(new_str)
|
||||
return tool.container_str_replace(path, old_str, new_str)
|
||||
|
||||
elif command == "insert":
|
||||
insert_line = int(arguments.get("insert_line", 0))
|
||||
new_str = str(arguments.get("new_str", ""))
|
||||
return tool.container_insert(path, insert_line, new_str)
|
||||
|
||||
else:
|
||||
return ToolExecResult(
|
||||
error=f"Unsupported command '{command}' for container execution",
|
||||
error_code=-1,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ToolExecResult(
|
||||
error=f"Container edit tool error: {str(e)}.", error_code=-1
|
||||
)
|
||||
|
||||
async def parallel_tool_call(self, tool_calls: list[ToolCall]) -> list[ToolResult]:
|
||||
"""Execute tool calls in parallel locally"""
|
||||
return await asyncio.gather(
|
||||
*[self.execute_tool_call(call) for call in tool_calls]
|
||||
)
|
||||
|
||||
async def sequential_tool_call(
|
||||
self, tool_calls: list[ToolCall]
|
||||
) -> list[ToolResult]:
|
||||
"""Execute tool calls in sequential locally"""
|
||||
return [await self.execute_tool_call(call) for call in tool_calls]
|
||||
|
||||
async def container_parallel_tool_call(
|
||||
self, tool_calls: list[ToolCall]
|
||||
) -> list[ToolResult]:
|
||||
"""Execute tool calls in parallel in container"""
|
||||
return await asyncio.gather(
|
||||
*[self.container_execute_tool_call(call) for call in tool_calls]
|
||||
)
|
||||
|
||||
async def container_sequential_tool_call(
|
||||
self, tool_calls: list[ToolCall]
|
||||
) -> list[ToolResult]:
|
||||
"""Execute tool calls in sequential in container"""
|
||||
return [await self.container_execute_tool_call(call) for call in tool_calls]
|
||||
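To illustrate the provider-dependent schema generation in Tool.get_input_schema (OpenAI strict mode marks optional parameters as nullable and forces them into the required list), here is a hedged sketch with an invented EchoTool; it is a demonstration only, not a tool shipped by the project.

# Hypothetical tool used only to show how get_input_schema differs per provider.
from typing import override

from src.tools.base import Tool, ToolCallArguments, ToolExecResult, ToolParameter


class EchoTool(Tool):
    @override
    def get_name(self) -> str:
        return "echo"

    @override
    def get_description(self) -> str:
        return "Echo back the provided text."

    @override
    def get_parameters(self) -> list[ToolParameter]:
        return [
            ToolParameter(name="text", type="string", description="Text to echo.", required=True),
            ToolParameter(name="uppercase", type="boolean", description="Uppercase the output.", required=False),
        ]

    @override
    async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
        text = str(arguments.get("text", ""))
        if arguments.get("uppercase"):
            text = text.upper()
        return ToolExecResult(output=text)


# Claude format: only "text" ends up in "required".
print(EchoTool().json_definition()["input_schema"])
# OpenAI strict mode: both parameters are required and "uppercase" becomes ["boolean", "null"],
# with additionalProperties set to false on the top-level schema.
print(EchoTool(model_provider="openai").get_input_schema())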
314
src/tools/bash_tool.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# Copyright (c) 2023 Anthropic
|
||||
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
||||
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
|
||||
#
|
||||
# Original file was released under MIT License, with the full license text
|
||||
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
|
||||
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
|
||||
#
|
||||
# This modified file is released under the same license.
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import override
|
||||
|
||||
from src.tools.base import (
|
||||
Tool,
|
||||
ToolCallArguments,
|
||||
ToolError,
|
||||
ToolExecResult,
|
||||
ToolParameter,
|
||||
BASH_TOOL_NAME,
|
||||
)
|
||||
from src.tools.executor import Executor
|
||||
from src.managers.log.logger import Logger
|
||||
from typing import Dict, Any
|
||||
from traceback import format_exc
|
||||
|
||||
|
||||
class _BashSession:
|
||||
"""A session of a bash shell."""
|
||||
|
||||
_started: bool
|
||||
_timed_out: bool
|
||||
|
||||
command: str = "/bin/bash"
|
||||
_output_delay: float = 0.2 # seconds
|
||||
_timeout: float = 120.0 # seconds
|
||||
_sentinel: str = (
|
||||
",,,,bash-command-exit-__ERROR_CODE__-banner,,,," # `__ERROR_CODE__` will be replaced by `$?` or `!errorlevel!` later
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._started = False
|
||||
self._timed_out = False
|
||||
self._process: asyncio.subprocess.Process | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._started:
|
||||
return
|
||||
|
||||
# Windows compatibility: os.setsid not available
|
||||
|
||||
if os.name != "nt": # Unix-like systems
|
||||
self._process = await asyncio.create_subprocess_shell(
|
||||
self.command,
|
||||
shell=True,
|
||||
bufsize=0,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
else:
|
||||
self._process = await asyncio.create_subprocess_shell(
|
||||
"cmd.exe /v:on", # enable delayed expansion to allow `echo !errorlevel!`
|
||||
shell=True,
|
||||
bufsize=0,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
self._started = True
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Terminate the bash shell."""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process is None:
|
||||
return
|
||||
if self._process.returncode is not None:
|
||||
return
|
||||
self._process.terminate()
|
||||
|
||||
# Wait until the process has truly terminated.
|
||||
stdout, stderr = await self._process.communicate()
|
||||
|
||||
async def run(self, command: str) -> ToolExecResult:
|
||||
"""Execute a command in the bash shell."""
|
||||
if not self._started or self._process is None:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return ToolExecResult(
|
||||
error=f"bash has exited with returncode {self._process.returncode}. tool must be restarted.",
|
||||
error_code=-1,
|
||||
)
|
||||
if self._timed_out:
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
)
|
||||
|
||||
# we know these are not None because we created the process with PIPEs
|
||||
assert self._process.stdin
|
||||
assert self._process.stdout
|
||||
assert self._process.stderr
|
||||
|
||||
error_code = 0
|
||||
|
||||
sentinel_before, pivot, sentinel_after = self._sentinel.partition(
|
||||
"__ERROR_CODE__"
|
||||
)
|
||||
assert pivot == "__ERROR_CODE__"
|
||||
|
||||
errcode_retriever = "!errorlevel!" if os.name == "nt" else "$?"
|
||||
command_sep = "&" if os.name == "nt" else ";"
|
||||
|
||||
# send command to the process
|
||||
self._process.stdin.write(
|
||||
b"(\n"
|
||||
+ command.encode()
|
||||
+ f"\n){command_sep} echo {self._sentinel.replace('__ERROR_CODE__', errcode_retriever)}\n".encode()
|
||||
)
|
||||
await self._process.stdin.drain()
|
||||
|
||||
# read output from the process, until the sentinel is found
|
||||
try:
|
||||
async with asyncio.timeout(self._timeout):
|
||||
while True:
|
||||
await asyncio.sleep(self._output_delay)
|
||||
# if we read directly from stdout/stderr, it will wait forever for
|
||||
# EOF. use the StreamReader buffer directly instead.
|
||||
output: str = self._process.stdout._buffer.decode() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownVariableType]
|
||||
if sentinel_before in output:
|
||||
# strip the sentinel from output
|
||||
output, pivot, exit_banner = output.rpartition(sentinel_before)
|
||||
assert pivot
|
||||
|
||||
# get error code inside banner
|
||||
error_code_str, pivot, _ = exit_banner.partition(sentinel_after)
|
||||
if not pivot or not error_code_str.isdecimal():
|
||||
continue
|
||||
|
||||
error_code = int(error_code_str)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
self._timed_out = True
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
) from None
|
||||
|
||||
if output.endswith("\n"): # pyright: ignore[reportUnknownMemberType]
|
||||
output = output[:-1] # pyright: ignore[reportUnknownVariableType]
|
||||
|
||||
error: str = self._process.stderr._buffer.decode() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue]
|
||||
if error.endswith("\n"): # pyright: ignore[reportUnknownMemberType]
|
||||
error = error[:-1] # pyright: ignore[reportUnknownVariableType]
|
||||
|
||||
# clear the buffers so that the next output can be read correctly
|
||||
self._process.stdout._buffer.clear() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
|
||||
self._process.stderr._buffer.clear() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
|
||||
|
||||
return ToolExecResult(
|
||||
output=output, error=error, error_code=error_code
|
||||
) # pyright: ignore[reportUnknownArgumentType]
|
||||
|
||||
|
||||
class BashTool(Tool):
|
||||
"""
|
||||
A tool that allows the agent to run bash commands.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: str | None = None,
|
||||
executor: Executor | None = None,
|
||||
logger: Logger | None = None,
|
||||
config: Dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(model_provider, logger, config)
|
||||
self._session: _BashSession | None = None
|
||||
self.executor = executor
|
||||
|
||||
@override
|
||||
def get_model_provider(self) -> str | None:
|
||||
return self._model_provider
|
||||
|
||||
@override
|
||||
def get_name(self) -> str:
|
||||
return BASH_TOOL_NAME
|
||||
|
||||
@override
|
||||
def get_description(self) -> str:
|
||||
return """Execute commands within a bash shell environment, either on the local system or inside a container.
|
||||
* When providing the "command" parameter, its contents must be provided as-is without any XML escaping.
|
||||
* You have access to a mirrored repository of common Linux (via apt) and Python (via pip) packages for installation.
|
||||
* State is persisted across all command executions and throughout our conversation session.
|
||||
* Avoid executing commands that are likely to generate excessively large outputs.
|
||||
* Avoid executing interactive commands that require user input (e.g., password prompts, confirmation messages).
|
||||
* For Git commands, always prefer non-interactive forms. For example, use git --no-pager diff instead of git diff to prevent opening a pager.
|
||||
* To inspect a specific range of lines in a file (e.g., lines 5-10), you can use a command like: sed -n '5,10p' /path/to/file
|
||||
"""
|
||||
|
||||
@override
|
||||
def get_parameters(self) -> list[ToolParameter]:
|
||||
# For OpenAI models, all parameters must be required=True
|
||||
# For other providers, optional parameters can have required=False
|
||||
restart_required = self.model_provider == "openai"
|
||||
|
||||
return [
|
||||
ToolParameter(
|
||||
name="command",
|
||||
type="string",
|
||||
description="The exact bash command string to be executed.",
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="restart",
|
||||
type="boolean",
|
||||
description="If true, terminates the current shell session and starts a new one before executing the command. This clears the session state.",
|
||||
required=restart_required,
|
||||
),
|
||||
]
|
||||
|
||||
@override
|
||||
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
|
||||
if arguments.get("restart"):
|
||||
if self._session:
|
||||
await self._session.stop()
|
||||
self._session = _BashSession()
|
||||
await self._session.start()
|
||||
|
||||
return ToolExecResult(output="tool has been restarted.")
|
||||
|
||||
if self._session is None:
|
||||
try:
|
||||
self._session = _BashSession()
|
||||
await self._session.start()
|
||||
except Exception as e:
|
||||
return ToolExecResult(
|
||||
error=f"Error starting bash session: {e}",
|
||||
error_code=-1,
|
||||
)
|
||||
|
||||
command = str(arguments["command"]) if "command" in arguments else None
|
||||
if command is None:
|
||||
return ToolExecResult(
|
||||
error=f"No command provided for the {self.get_name()} tool",
|
||||
error_code=-1,
|
||||
)
|
||||
try:
|
||||
return await self._session.run(command)
|
||||
except Exception as e:
|
||||
return ToolExecResult(
|
||||
error=f"Error running bash command: {e}",
|
||||
error_code=-1,
|
||||
)
|
||||
|
||||
async def container_execute(
|
||||
self, arguments: ToolCallArguments, session_id: str = "0"
|
||||
) -> ToolExecResult:
|
||||
"""Execute a command in a container bash shell."""
|
||||
if not self.executor:
|
||||
return ToolExecResult(
|
||||
error="Container execution requires an executor to be provided during tool initialization",
|
||||
error_code=-1,
|
||||
)
|
||||
|
||||
if arguments.get("restart"):
|
||||
# Close the existing session if it exists
|
||||
self.executor.close_session("0")
|
||||
# The executor will automatically recreate session '0' when needed
|
||||
return ToolExecResult(output="Container session has been restarted.")
|
||||
|
||||
command = str(arguments["command"]) if "command" in arguments else None
|
||||
if command is None:
|
||||
return ToolExecResult(
|
||||
error=f"No command provided for container execution",
|
||||
error_code=-1,
|
||||
)
|
||||
# command_with_init = f"source /opt/miniconda3/bin/activate && conda activate testbed && {command}"
|
||||
# Check if the session is alive before executing the command
|
||||
if not self.executor.check_session():
|
||||
return ToolExecResult(
|
||||
error="Container session is not alive and could not be restarted",
|
||||
error_code=-1,
|
||||
)
|
||||
|
||||
try:
|
||||
return_code, output = self.executor.execute(session_id, command)
|
||||
# return_code, output = self.executor.execute_once(command_with_init)
|
||||
|
||||
# The executor returns (return_code, output) tuple
|
||||
# We'll treat any non-zero return code as an error
|
||||
error = None
|
||||
if return_code != 0:
|
||||
error = f"Command failed with exit code {return_code}, output: {output}"
|
||||
|
||||
return ToolExecResult(output=output, error=error, error_code=return_code)
|
||||
except Exception as e:
|
||||
return ToolExecResult(
|
||||
error=f"Error running container bash command: {e}", error_code=-1
|
||||
)
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
"""Properly close self._process."""
|
||||
if self._session:
|
||||
await self._session.stop()
|
||||
self._session = None
|
||||
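A short, hedged sketch of driving BashTool locally through ToolExecutor; the echoed command is arbitrary and asyncio.run is used only to show the call shape.

# Hypothetical local (non-container) usage of BashTool.
import asyncio

from src.tools import BashTool, ToolCall, ToolExecutor


async def main() -> None:
    executor = ToolExecutor([BashTool()])
    call = ToolCall(name="bash", call_id="1", arguments={"command": "echo hello"})
    result = await executor.execute_tool_call(call)
    print(result.success, result.result)  # True, "hello" (stderr, if any, lands in result.error)
    await executor.close_tools()          # stops the persistent bash session


asyncio.run(main())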
735
src/tools/edit_tool.py
Normal file
@@ -0,0 +1,735 @@
|
||||
# Copyright (c) 2023 Anthropic
|
||||
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
||||
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
|
||||
#
|
||||
# Original file was released under MIT License, with the full license text
|
||||
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
|
||||
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
|
||||
#
|
||||
# This modified file is released under the same license.
|
||||
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Optional, override
|
||||
import shlex
|
||||
|
||||
from src.tools.base import (
|
||||
Tool,
|
||||
ToolCallArguments,
|
||||
ToolError,
|
||||
ToolExecResult,
|
||||
ToolParameter,
|
||||
STR_REPLACE_BASED_EDIT_TOOL_NAME,
|
||||
)
|
||||
from src.tools.run import maybe_truncate, run
|
||||
from src.tools.executor import Executor
|
||||
from src.managers.log.logger import Logger
|
||||
from typing import Dict, Any
|
||||
from traceback import format_exc
|
||||
|
||||
EditToolSubCommands = [
|
||||
"view",
|
||||
"create",
|
||||
"str_replace",
|
||||
"insert",
|
||||
]
|
||||
SNIPPET_LINES: int = 4
|
||||
|
||||
|
||||
class TextEditorTool(Tool):
|
||||
"""Tool to replace a string in a file."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: str | None = None,
|
||||
executor: Executor | None = None,
|
||||
logger: Logger | None = None,
|
||||
config: Dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(model_provider, logger, config)
|
||||
self.executor = executor
|
||||
|
||||
@override
|
||||
def get_model_provider(self) -> str | None:
|
||||
return self._model_provider
|
||||
|
||||
@override
|
||||
def get_name(self) -> str:
|
||||
return STR_REPLACE_BASED_EDIT_TOOL_NAME
|
||||
|
||||
@override
|
||||
def get_description(self) -> str:
|
||||
|
||||
return """This tool provides capabilities for viewing, creating and editing files
|
||||
* This tool is stateless. No context is retained between individual command invocations.
|
||||
* Content Examination `view`:
|
||||
** For a file: Executing `view` on a file path will output the file's full content with sequential line numbers prefixed (using cat -n).
|
||||
** For a directory: Executing view on a directory path will recursively list all non-hidden items, displaying contents up to two directory levels deep.
|
||||
|
||||
* File Creation `create`:
|
||||
** The `create` operation is strictly prohibited if a file already exists at the specified `path`.
|
||||
** Mandatory Pre-action: You must explicitly remove any existing file at the target `path` before proceeding with the creation of a new file.
|
||||
|
||||
* Output Handling:
|
||||
** Should the output generated by a `command` exceed a certain length threshold, it will be automatically shortened and clearly marked with the indicator: <response clipped>.
|
||||
|
||||
* String Replacement `str_replace` Operational Rules:
|
||||
** Precision Targeting: The `old_str` parameter must be an exact, character-for-character match of one or more complete lines from the source file. Special attention must be paid to invisible characters like spaces and tabs.
|
||||
** Match Uniqueness: The replacement will be canceled if the specified `old_str` pattern is not absolutely unique within the file. To ensure a single match, expand the `old_str` scope to include sufficient preceding or following context lines.
|
||||
** Content Insertion: The `new_str` parameter defines the complete set of lines that will be inserted into the file, directly replacing the content matched by `old_str`.
|
||||
"""
|
||||
|
||||
@override
|
||||
def get_parameters(self) -> list[ToolParameter]:
|
||||
"""Get the parameters for the str_replace_based_edit_tool."""
|
||||
return [
|
||||
ToolParameter(
|
||||
name="command",
|
||||
type="string",
|
||||
description=f"Operation to execute. Supported commands: {', '.join(EditToolSubCommands)}.",
|
||||
required=True,
|
||||
enum=EditToolSubCommands,
|
||||
),
|
||||
ToolParameter(
|
||||
name="file_text",
|
||||
type="string",
|
||||
description="Required for `create` command. Specifies the textual content for the new file.",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
name="insert_line",
|
||||
type="integer",
|
||||
description="Required for `insert` command. The line number AFTER which the `new_str` will be inserted.",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
name="new_str",
|
||||
type="string",
|
||||
description="For `str_replace`: the replacement text (optional, defaults to empty). For `insert`: the text to insert (required).",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
name="old_str",
|
||||
type="string",
|
||||
description="Required for `str_replace` command. The exact text segment in the file to be replaced.",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
name="path",
|
||||
type="string",
|
||||
description="Absolute filesystem path to the target file or directory. Example: `/workspace/script.py` or `/workspace`.",
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="view_range",
|
||||
type="array",
|
||||
description="Optional for `view` command on files. Defines the line range to display. Examples: `[5, 10]` shows lines 5-10; `[15, -1]` shows from line 15 to EOF. Line numbering starts at 1.",
|
||||
items={"type": "integer"},
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
@override
|
||||
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
|
||||
"""Execute the str_replace_editor tool."""
|
||||
command = str(arguments["command"]) if "command" in arguments else None
|
||||
if command is None:
|
||||
return ToolExecResult(
|
||||
error=f"No command provided for the {self.get_name()} tool",
|
||||
error_code=-1,
|
||||
)
|
||||
path = str(arguments["path"]) if "path" in arguments else None
|
||||
if path is None:
|
||||
return ToolExecResult(
|
||||
error=f"No path provided for the {self.get_name()} tool", error_code=-1
|
||||
)
|
||||
_path = Path(path)
|
||||
try:
|
||||
self.validate_path(command, _path)
|
||||
match command:
|
||||
case "view":
|
||||
return await self._view_handler(arguments, _path)
|
||||
case "create":
|
||||
return self._create_handler(arguments, _path)
|
||||
case "str_replace":
|
||||
return self._str_replace_handler(arguments, _path)
|
||||
case "insert":
|
||||
return self._insert_handler(arguments, _path)
|
||||
case _:
|
||||
return ToolExecResult(
|
||||
error=f"Unrecognized command {command}. The allowed commands for the {self.name} tool are: {', '.join(EditToolSubCommands)}",
|
||||
error_code=-1,
|
||||
)
|
||||
except ToolError as e:
|
||||
return ToolExecResult(error=str(e), error_code=-1)
|
||||
|
||||
def validate_path(self, command: str, path: Path):
|
||||
"""Validate the path for the str_replace_editor tool."""
|
||||
if not path.is_absolute():
|
||||
suggested_path = Path("/") / path
|
||||
raise ToolError(
|
||||
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
|
||||
)
|
||||
# Check if path exists
|
||||
if not path.exists() and command != "create":
|
||||
raise ToolError(
|
||||
f"The path {path} does not exist. Please provide a valid path."
|
||||
)
|
||||
if path.exists() and command == "create":
|
||||
raise ToolError(
|
||||
f"File already exists at: {path}. Cannot overwrite files using command `create`."
|
||||
)
|
||||
# Check if the path points to a directory
|
||||
if path.is_dir() and command != "view":
|
||||
raise ToolError(
|
||||
f"The path {path} is a directory and only the `view` command can be used on directories"
|
||||
)
|
||||
|
||||
async def _view(
|
||||
self, path: Path, view_range: list[int] | None = None
|
||||
) -> ToolExecResult:
|
||||
"""Implement the view command"""
|
||||
if path.is_dir():
|
||||
if view_range:
|
||||
raise ToolError(
|
||||
"The `view_range` parameter is not allowed when `path` points to a directory."
|
||||
)
|
||||
|
||||
return_code, stdout, stderr = await run(
|
||||
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
|
||||
)
|
||||
if not stderr:
|
||||
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
|
||||
return ToolExecResult(error_code=return_code, output=stdout, error=stderr)
|
||||
|
||||
file_content = self.read_file(path)
|
||||
init_line = 1
|
||||
if view_range:
|
||||
if len(view_range) != 2 or not all(
|
||||
isinstance(i, int) for i in view_range
|
||||
): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
raise ToolError(
|
||||
"Invalid `view_range`. It should be a list of two integers."
|
||||
)
|
||||
file_lines = file_content.split("\n")
|
||||
n_lines_file = len(file_lines)
|
||||
init_line, final_line = view_range
|
||||
if init_line < 1 or init_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
|
||||
)
|
||||
if final_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
|
||||
)
|
||||
if final_line != -1 and final_line < init_line:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
|
||||
)
|
||||
|
||||
if final_line == -1:
|
||||
file_content = "\n".join(file_lines[init_line - 1 :])
|
||||
else:
|
||||
file_content = "\n".join(file_lines[init_line - 1 : final_line])
|
||||
|
||||
return ToolExecResult(
|
||||
output=self._make_output(file_content, str(path), init_line=init_line)
|
||||
)
|
||||
|
||||
def _view_container(
|
||||
self, path: Path, view_range: list[int] | None = None
|
||||
) -> ToolExecResult:
|
||||
"""Implement the view command"""
|
||||
if path.is_dir():
|
||||
raise ToolError("The `path` parameter is not allowed be a directory.")
|
||||
|
||||
file_content = self.container_read_file(path)
|
||||
init_line = 1
|
||||
make_out_max_lines = None
|
||||
if view_range:
|
||||
if len(view_range) != 2 or not all(
|
||||
isinstance(i, int) for i in view_range
|
||||
): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
raise ToolError(
|
||||
"Invalid `view_range`. It should be a list of two integers."
|
||||
)
|
||||
file_lines = file_content.split("\n")
|
||||
n_lines_file = len(file_lines)
|
||||
init_line, final_line = view_range
|
||||
# The initial line must be at least 1 and cannot exceed the number of lines in the file
|
||||
if init_line < 1 or init_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
|
||||
)
|
||||
# When the end line takes effect, the end line cannot be less than the start line.
|
||||
if final_line != -1 and final_line < init_line:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
|
||||
)
|
||||
|
||||
if final_line == -1:
|
||||
file_content = "\n".join(file_lines[init_line - 1 :])
|
||||
elif final_line > n_lines_file:
|
||||
file_content = "\n".join(file_lines[init_line - 1 : n_lines_file])
|
||||
make_out_max_lines = n_lines_file
|
||||
pass
|
||||
else:
|
||||
file_content = "\n".join(file_lines[init_line - 1 : final_line])
|
||||
|
||||
return ToolExecResult(
|
||||
output=self._make_output(
|
||||
file_content,
|
||||
str(path),
|
||||
init_line=init_line,
|
||||
max_lines=make_out_max_lines,
|
||||
)
|
||||
)
|
||||
|
||||
def str_replace(
|
||||
self, path: Path, old_str: str, new_str: str | None
|
||||
) -> ToolExecResult:
|
||||
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
|
||||
# Read the file content
|
||||
file_content = self.read_file(path).expandtabs()
|
||||
old_str = old_str.expandtabs()
|
||||
new_str = new_str.expandtabs() if new_str is not None else ""
|
||||
|
||||
# Check if old_str is unique in the file
|
||||
occurrences = file_content.count(old_str)
|
||||
if occurrences == 0:
|
||||
raise ToolError(
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
|
||||
)
|
||||
elif occurrences > 1:
|
||||
file_content_lines = file_content.split("\n")
|
||||
lines = [
|
||||
idx + 1
|
||||
for idx, line in enumerate(file_content_lines)
|
||||
if old_str in line
|
||||
]
|
||||
raise ToolError(
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
|
||||
)
|
||||
|
||||
# Replace old_str with new_str
|
||||
new_file_content = file_content.replace(old_str, new_str)
|
||||
|
||||
# Write the new content to the file
|
||||
self.write_file(path, new_file_content)
|
||||
|
||||
# Create a snippet of the edited section
|
||||
replacement_line = file_content.split(old_str)[0].count("\n")
|
||||
start_line = max(0, replacement_line - SNIPPET_LINES)
|
||||
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
|
||||
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(
|
||||
snippet, f"a snippet of {path}", start_line + 1
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
|
||||
|
||||
return ToolExecResult(
|
||||
output=success_msg,
|
||||
)
|
||||
|
||||
def _insert(self, path: Path, insert_line: int, new_str: str) -> ToolExecResult:
|
||||
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
|
||||
file_text = self.read_file(path).expandtabs()
|
||||
new_str = new_str.expandtabs()
|
||||
file_text_lines = file_text.split("\n")
|
||||
n_lines_file = len(file_text_lines)
|
||||
|
||||
if insert_line < 0 or insert_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
|
||||
)
|
||||
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_file_text_lines = (
|
||||
file_text_lines[:insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line:]
|
||||
)
|
||||
snippet_lines = (
|
||||
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
)
|
||||
|
||||
new_file_text = "\n".join(new_file_text_lines)
|
||||
snippet = "\n".join(snippet_lines)
|
||||
|
||||
self.write_file(path, new_file_text)
|
||||
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(
|
||||
snippet,
|
||||
"a snippet of the edited file",
|
||||
max(1, insert_line - SNIPPET_LINES + 1),
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
|
||||
return ToolExecResult(
|
||||
output=success_msg,
|
||||
)
|
||||
|
||||
# Note: undo_edit method is not implemented in this version as it was removed
|
||||
|
||||
def read_file(self, path: Path):
|
||||
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
|
||||
try:
|
||||
return path.read_text()
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"In edit_tool, read_file command error, ran into {e} while trying to read {path}, traceback: {format_exc()}."
|
||||
)
|
||||
raise ToolError(f"Ran into {e} while trying to read {path}.") from None
|
||||
|
||||
def write_file(self, path: Path, file: str):
|
||||
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
|
||||
try:
|
||||
_ = path.write_text(file)
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"In edit_tool, write_file command error, ran into {e} while trying to write to {path}, traceback: {format_exc()}."
|
||||
)
|
||||
raise ToolError(f"Ran into {e} while trying to write to {path}.") from None
|
||||
|
||||
def container_read_file(self, path: Path, session_id: str = "0") -> str:
|
||||
"""Read the content of a file from a container using cat command."""
|
||||
if not self.executor:
|
||||
raise ToolError("No executor provided for container operations")
|
||||
|
||||
try:
|
||||
# Check if session is alive and restart if needed
|
||||
if not self.executor.check_session(session_id):
|
||||
raise ToolError(
|
||||
"Container session is not alive and could not be restarted"
|
||||
)
|
||||
|
||||
# Use cat command to read file content
|
||||
command = f"cat {path}"
|
||||
# return_code, output = self.executor.execute(session_id, command)
|
||||
return_code, output = self.executor.execute_once(command)
|
||||
|
||||
if return_code != 0:
|
||||
raise ToolError(
|
||||
f"Failed to read file {path} from container. Exit code: {return_code}, Output: {output}"
|
||||
)
|
||||
|
||||
# Strip a single trailing newline from the captured output, preserving the file content otherwise
|
||||
# lines = output.split("\n")
|
||||
# Remove the first line if it contains the command echo
|
||||
# if lines and f"cat {path}" in lines[0]:
|
||||
# lines = lines[2:-1]
|
||||
final = output[:-1] if output.endswith("\n") else output
|
||||
|
||||
# return "\n".join(lines)
|
||||
return final
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"In edit_tool, container_read_file command error, ran into {e} while trying to read {path} from container, traceback: {format_exc()}."
|
||||
)
|
||||
raise ToolError(
|
||||
f"Ran into {e} while trying to read {path} from container."
|
||||
) from None
|
||||
|
||||
def container_write_file(
|
||||
self, path: Path, content: str, session_id: str = "0"
|
||||
) -> None:
|
||||
"""Write content to a file in a container using cat with here document."""
|
||||
if not self.executor:
|
||||
raise ToolError("No executor provided for container operations")
|
||||
|
||||
try:
|
||||
# Check if session is alive and restart if needed
|
||||
if not self.executor.check_session():
|
||||
raise ToolError(
|
||||
"Container session is not alive and could not be restarted"
|
||||
)
|
||||
# Create the parent directory first
|
||||
return_code, output = self.executor.execute_once(f"mkdir -p {path.parent}")
|
||||
if return_code != 0:
|
||||
raise ToolError(
|
||||
f"Failed to create dir {path.parent} in container. Exit code: {return_code}, Output: {output}"
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w+", delete=False, encoding="utf-8"
|
||||
) as temp_file:
|
||||
temp_file.write(content)
|
||||
temp_file_path = temp_file.name
|
||||
return_code, output = self.executor.cpfile_host_to_container(
|
||||
temp_file_path, path
|
||||
)
|
||||
os.remove(temp_file_path)
|
||||
|
||||
if return_code != 0:
|
||||
raise ToolError(
|
||||
f"Failed to write to file {path} in container. Exit code: {return_code}, Output: {output}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"In edit_tool, container_write_file command error, ran into {e} while trying to write to {path} in container, traceback: {format_exc()}."
|
||||
)
|
||||
raise ToolError(
|
||||
f"Ran into {e} while trying to write to {path} in container."
|
||||
) from None
|
||||
|
||||
def container_str_replace(
|
||||
self, path: Path, old_str: str, new_str: str | None, session_id: str = "0"
|
||||
) -> ToolExecResult:
|
||||
"""Replace old_str with new_str in a file in a container using sed command."""
|
||||
if not self.executor:
|
||||
raise ToolError("No executor provided for container operations")
|
||||
|
||||
try:
|
||||
# Check if session is alive and restart if needed
|
||||
if not self.executor.check_session():
|
||||
raise ToolError(
|
||||
"Container session is not alive and could not be restarted"
|
||||
)
|
||||
|
||||
# First, read the file to check if old_str exists
|
||||
file_content = self.container_read_file(path, session_id)
|
||||
|
||||
# Check if old_str is unique in the file
|
||||
occurrences = file_content.count(old_str)
|
||||
if occurrences == 0:
|
||||
raise ToolError(
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
|
||||
)
|
||||
elif occurrences > 1:
|
||||
# Note: the commented-out code below is incorrect; old_str may span multiple lines, so per-line matching cannot locate every occurrence
|
||||
# file_content_lines = file_content.split("\n")
|
||||
# lines = [
|
||||
# idx + 1
|
||||
# for idx, line in enumerate(file_content_lines)
|
||||
# if old_str in line
|
||||
# ]
|
||||
raise ToolError(
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}`. Total occurrences: {occurrences}. Please ensure it is unique"
|
||||
)
|
||||
|
||||
updated_content = file_content.replace(old_str, new_str if new_str is not None else "")
|
||||
self.container_write_file(path=path, content=updated_content)
|
||||
|
||||
# Read the file to show a snippet of the changes
|
||||
try:
|
||||
file_content = self.container_read_file(path, session_id)
|
||||
# Create a simple snippet showing the change
|
||||
lines = file_content.split("\n")
|
||||
snippet_lines = lines[: min(10, len(lines))] # Show first 10 lines
|
||||
snippet = "\n".join(snippet_lines)
|
||||
|
||||
success_msg = f"The file {path} has been edited in container. "
|
||||
success_msg += self._make_output(
|
||||
snippet, f"a snippet of {path}", init_line=1
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
|
||||
|
||||
return ToolExecResult(output=success_msg)
|
||||
except Exception:
|
||||
# If we can't read the file for snippet, just return success
|
||||
return ToolExecResult(
|
||||
output=f"Successfully replaced string in file {path} in container."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"In edit_tool, container_str_replace command error, ran into {e} while trying to replace string in {path} in container, traceback: {format_exc()}."
|
||||
)
|
||||
raise ToolError(
|
||||
f"Ran into {e} while trying to replace string in {path} in container."
|
||||
) from None
|
||||
|
||||
def container_insert(
|
||||
self, path: Path, insert_line: int, new_str: str, session_id: str = "0"
|
||||
) -> ToolExecResult:
|
||||
if not self.executor:
|
||||
raise ToolError("No executor provided for container operations")
|
||||
|
||||
try:
|
||||
if not self.executor.check_session():
|
||||
raise ToolError(
|
||||
"Container session is not alive and could not be restarted"
|
||||
)
|
||||
|
||||
file_content = self.container_read_file(path, session_id)
|
||||
file_text_lines = file_content.split("\n")
|
||||
n_lines_file = len(file_text_lines)
|
||||
|
||||
if insert_line < 0 or insert_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
|
||||
)
|
||||
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_file_text_lines = (
|
||||
file_text_lines[:insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line:]
|
||||
)
|
||||
snippet_lines = (
|
||||
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
)
|
||||
|
||||
new_file_text = "\n".join(new_file_text_lines)
|
||||
snippet = "\n".join(snippet_lines)
|
||||
|
||||
self.container_write_file(path, new_file_text, session_id)
|
||||
|
||||
success_msg = f"The file {path} has been edited in container. "
|
||||
success_msg += self._make_output(
|
||||
snippet,
|
||||
"a snippet of the edited file",
|
||||
max(1, insert_line - SNIPPET_LINES + 1),
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
|
||||
return ToolExecResult(
|
||||
output=success_msg,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(
|
||||
f"In edit_tool, container_insert command error, ran into {e} while trying to insert content in {path} in container, traceback: {format_exc()}."
|
||||
)
|
||||
raise ToolError(
|
||||
f"Ran into {e} while trying to insert content in {path} in container."
|
||||
) from None
|
||||
|
||||
def _escape_sed(self, text: str) -> str:
|
||||
"""Escape special characters in text for use with sed command."""
|
||||
# Escape sed special characters: / \ &
|
||||
escaped = text.replace("\\", "\\\\") # Escape backslashes first
|
||||
escaped = escaped.replace("/", "\\/") # Escape forward slashes
|
||||
escaped = escaped.replace("&", "\\&") # Escape ampersands
|
||||
escaped = escaped.replace("\n", "\\n") # Handle newlines
|
||||
escaped = escaped.replace("\t", "\\t") # Handle tabs
|
||||
return escaped
|
||||
|
||||
def _make_output(
|
||||
self,
|
||||
file_content: str,
|
||||
file_descriptor: str,
|
||||
init_line: int = 1,
|
||||
expand_tabs: bool = True,
|
||||
max_lines: Optional[int] = None,
|
||||
):
|
||||
"""Generate output for the CLI based on the content of a file."""
|
||||
file_content = maybe_truncate(file_content)
|
||||
if expand_tabs:
|
||||
file_content = file_content.expandtabs()
|
||||
file_content = "\n".join(
|
||||
[
|
||||
f"{i + init_line:6}\t{line}"
|
||||
for i, line in enumerate(file_content.split("\n"))
|
||||
]
|
||||
)
|
||||
if max_lines:
|
||||
return (
|
||||
f"Here's the result of running `cat -n` on {file_descriptor}(The file is only {max_lines} lines):\n"
|
||||
+ file_content
|
||||
+ "\n"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
|
||||
+ file_content
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
async def _view_handler(
|
||||
self, arguments: ToolCallArguments, _path: Path
|
||||
) -> ToolExecResult:
|
||||
view_range = arguments.get("view_range", None)
|
||||
if view_range is None:
|
||||
return await self._view(_path, None)
|
||||
if not (
|
||||
isinstance(view_range, list) and all(isinstance(i, int) for i in view_range)
|
||||
):
|
||||
return ToolExecResult(
|
||||
error="Parameter `view_range` should be a list of integers.",
|
||||
error_code=-1,
|
||||
)
|
||||
view_range_int: list[int] = [i for i in view_range if isinstance(i, int)]
|
||||
return await self._view(_path, view_range_int)
|
||||
|
||||
def view_handler_container(
|
||||
self, arguments: ToolCallArguments, path: Path
|
||||
) -> ToolExecResult:
|
||||
view_range = arguments.get("view_range", None)
|
||||
if view_range is None:
|
||||
return self._view_container(path, None)
|
||||
if not (
|
||||
isinstance(view_range, list) and all(isinstance(i, int) for i in view_range)
|
||||
):
|
||||
return ToolExecResult(
|
||||
error="Parameter `view_range` should be a list of integers.",
|
||||
error_code=-1,
|
||||
)
|
||||
view_range_int: list[int] = [i for i in view_range if isinstance(i, int)]
|
||||
return self._view_container(path, view_range_int)
|
||||
|
||||
def _create_handler(
|
||||
self, arguments: ToolCallArguments, _path: Path
|
||||
) -> ToolExecResult:
|
||||
file_text = arguments.get("file_text", None)
|
||||
if not isinstance(file_text, str):
|
||||
return ToolExecResult(
|
||||
error="Parameter `file_text` is required and must be a string for command: create",
|
||||
error_code=-1,
|
||||
)
|
||||
self.write_file(_path, file_text)
|
||||
return ToolExecResult(output=f"File created successfully at: {_path}")
|
||||
|
||||
def _str_replace_handler(
|
||||
self, arguments: ToolCallArguments, _path: Path
|
||||
) -> ToolExecResult:
|
||||
old_str = arguments.get("old_str") if "old_str" in arguments else None
|
||||
if not isinstance(old_str, str):
|
||||
return ToolExecResult(
|
||||
error="Parameter `old_str` is required and should be a string for command: str_replace",
|
||||
error_code=-1,
|
||||
)
|
||||
new_str = arguments.get("new_str") if "new_str" in arguments else None
|
||||
if not (new_str is None or isinstance(new_str, str)):
|
||||
return ToolExecResult(
|
||||
error="Parameter `new_str` should be a string or null for command: str_replace",
|
||||
error_code=-1,
|
||||
)
|
||||
return self.str_replace(_path, old_str, new_str)
|
||||
|
||||
def _insert_handler(
|
||||
self, arguments: ToolCallArguments, _path: Path
|
||||
) -> ToolExecResult:
|
||||
insert_line = (
|
||||
arguments.get("insert_line") if "insert_line" in arguments else None
|
||||
)
|
||||
if not isinstance(insert_line, int):
|
||||
return ToolExecResult(
|
||||
error="Parameter `insert_line` is required and should be integer for command: insert",
|
||||
error_code=-1,
|
||||
)
|
||||
new_str_to_insert = arguments.get("new_str") if "new_str" in arguments else None
|
||||
if not isinstance(new_str_to_insert, str):
|
||||
return ToolExecResult(
|
||||
error="Parameter `new_str` is required for command: insert",
|
||||
error_code=-1,
|
||||
)
|
||||
return self._insert(_path, insert_line, new_str_to_insert)
|
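# --- Illustrative sketch (not part of the tool) -------------------------------
# A minimal, standalone version of the line-splice that `_insert` performs above.
# `SNIPPET_LINES` here is a stand-in for the module-level constant of the same name.
SNIPPET_LINES = 4

def splice_after_line(file_text: str, insert_line: int, new_str: str) -> tuple[str, str]:
    """Insert new_str after line `insert_line` (0 = prepend) and return (new_text, snippet)."""
    lines = file_text.split("\n")
    new_lines = new_str.split("\n")
    merged = lines[:insert_line] + new_lines + lines[insert_line:]
    window = (
        lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
        + new_lines
        + lines[insert_line : insert_line + SNIPPET_LINES]
    )
    return "\n".join(merged), "\n".join(window)

# Example: inserting "x" after line 1 of a two-line file.
assert splice_after_line("a\nb", 1, "x")[0] == "a\nx\nb"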
||||
198
src/tools/executor.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import subprocess
|
||||
import uuid
|
||||
import docker
|
||||
import pexpect
|
||||
import re
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from src.managers.log.logger import Logger
|
||||
|
||||
|
||||
class Executor:
|
||||
def __init__(self, image: str, logger: Logger | None = None):
|
||||
self.image = image
|
||||
self.container = None
|
||||
self.sessions: dict[str, pexpect.spawn] = {}
|
||||
self.client = docker.from_env()
|
||||
self.logger = logger
|
||||
|
||||
try:
|
||||
self.client.images.get(self.image)
|
||||
except ImageNotFound:
|
||||
raise DockerException(
|
||||
f"Image '{self.image}' not found. Please build the image first."
|
||||
)
|
||||
|
||||
try:
|
||||
self.container = self.client.containers.run(
|
||||
self.image,
|
||||
command="sleep infinity",
|
||||
detach=True,
|
||||
working_dir="/workspace",
|
||||
)
|
||||
self.logger.info(f"Created container {self.container.id}")
|
||||
except DockerException as e:
|
||||
raise DockerException(
|
||||
f"Failed to create container with image '{self.image}': {e}"
|
||||
)
|
||||
|
||||
session_id = self.init_session()
|
||||
if session_id is None:
|
||||
raise DockerException("Failed to initialize default session")
|
||||
if session_id in self.sessions:
|
||||
self.sessions["0"] = self.sessions.pop(session_id)
|
||||
|
||||
def init_session(self) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
command = f"docker exec -it {self.container.id} /bin/bash"
|
||||
|
||||
for attempt in range(3): # Retry up to 3 times
|
||||
try:
|
||||
shell = pexpect.spawn(command, encoding="utf-8", timeout=120)
|
||||
shell.expect([r"\$.*", r"#.*"], timeout=120)
|
||||
|
||||
# Source conda and activate testbed environment
|
||||
shell.sendline("source /opt/miniconda3/bin/activate")
|
||||
shell.expect([r"\$.*", r"#.*"], timeout=30)
|
||||
|
||||
shell.sendline("conda activate testbed")
|
||||
shell.expect([r"\$.*", r"#.*"], timeout=30)
|
||||
|
||||
shell.sendline("export NO_COLOR=1 && export PAGER=cat")
|
||||
shell.expect([r"\$.*", r"#.*"], timeout=30)
|
||||
|
||||
# Verify conda environment is alive by checking the full output
|
||||
# The output should contain (testbed) if the environment is activated
|
||||
# We can check this by looking at the full output from the conda activate command
|
||||
output = shell.before
|
||||
if "(testbed)" not in output:
|
||||
# Environment not properly activated, retry
|
||||
if attempt < 2: # Not the last attempt
|
||||
shell.close(force=True)
|
||||
continue
|
||||
else:
|
||||
shell.close(force=True)
|
||||
raise DockerException(
|
||||
"Failed to activate conda environment 'testbed' after 3 attempts"
|
||||
)
|
||||
|
||||
self.sessions[session_id] = shell
|
||||
return session_id
|
||||
|
||||
except pexpect.exceptions.TIMEOUT:
|
||||
if attempt < 2: # Not the last attempt
|
||||
if "shell" in locals() and shell.isalive():
|
||||
shell.close(force=True)
|
||||
continue
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
if attempt < 2: # Not the last attempt
|
||||
if "shell" in locals() and shell.isalive():
|
||||
shell.close(force=True)
|
||||
continue
|
||||
else:
|
||||
raise DockerException(
|
||||
f"Failed to initialize session after 3 attempts: {e}"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def execute(
|
||||
self, session_id: str, command: str, timeout: int = 300
|
||||
) -> tuple[int, str]:
|
||||
shell = self.sessions.get(session_id)
|
||||
if not shell or not shell.isalive():
|
||||
return -1, "Session not found or is dead."
|
||||
|
||||
full_command = command.strip()
|
||||
shell.sendline(full_command)
|
||||
marker = f"---CMD_DONE---"
|
||||
marker_command = f"echo {marker}$?"
|
||||
shell.sendline(marker_command)
|
||||
try:
|
||||
shell.expect(marker + r"(\d+).*[\n](.*)", timeout=timeout)
|
||||
except pexpect.exceptions.TIMEOUT:
|
||||
return (
|
||||
-1,
|
||||
f"Error: Command '{command}' timed out after {timeout} seconds. Partial output:\n{shell.before}",
|
||||
)
|
||||
exit_code = int(shell.match.group(1))
|
||||
p = str(shell.match.group(2))
|
||||
all_lines: str = p + shell.before
|
||||
# delete all \r
|
||||
all_lines = re.sub(r"\r", "", all_lines)
|
||||
# Remove some non-color-related terminal control characters.
|
||||
# \x1b[?2004h - tell the terminal to enable bracketed paste mode
|
||||
# \x1b[?2004l - tell the terminal to disable bracketed paste mode
|
||||
all_lines = re.sub(r"\x1B\[\?2004[l|h]", "", all_lines)
|
||||
# Strip the last line's echo.
|
||||
all_lines = re.sub(r"\n[^\n]+---CMD_DONE---.*", "", all_lines)
|
||||
# self.logger.info(f"'{[all_lines]}'")
|
||||
return exit_code, all_lines
|
||||
|
||||
def execute_once(self, command: str, timeout: int = 300) -> tuple[int, str]:
|
||||
# cmd = ["docker", "exec", self.container.id, "bash", "-c", command]
|
||||
cmd = ["docker", "exec", "-i", self.container.id, "bash", "-s"]
|
||||
sub = subprocess.run(
|
||||
cmd,
|
||||
encoding="utf-8",
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
input=f"{command}\n",
|
||||
)
|
||||
if sub.returncode != 0:
|
||||
return sub.returncode, sub.stderr
|
||||
return sub.returncode, sub.stdout
|
||||
|
||||
def cpfile_host_to_container(self, source: str, dest: str) -> tuple[int, str]:
|
||||
cmd = ["docker", "cp", source, f"{self.container.id}:{dest}"]
|
||||
sub = subprocess.run(
|
||||
cmd,
|
||||
encoding="utf-8",
|
||||
text=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
self.execute_once(f"chmod 0777 {dest}")
|
||||
if sub.returncode != 0:
|
||||
return sub.returncode, sub.stderr
|
||||
return sub.returncode, sub.stdout
|
||||
|
||||
def check_session(self, session_id: str = "0") -> bool:
|
||||
"""
|
||||
Check whether the session identified by session_id (default "0") is alive and restart it if not.
|
||||
"""
|
||||
if session_id in self.sessions:
|
||||
session = self.sessions[session_id]
|
||||
if session and session.isalive():
|
||||
return True
|
||||
else:
|
||||
self.sessions.pop(session_id)
|
||||
|
||||
new_session_id = self.init_session()
|
||||
if new_session_id is None:
|
||||
return False
|
||||
if new_session_id != session_id:
|
||||
self.sessions[session_id] = self.sessions.pop(new_session_id)
|
||||
|
||||
return True
|
||||
|
||||
def close_session(self, session_id: str):
|
||||
if session_id in self.sessions:
|
||||
session = self.sessions.pop(session_id)
|
||||
if session and session.isalive():
|
||||
session.close(force=True)
|
||||
# Session not found - this is not an error condition
|
||||
|
||||
def shutdown(self):
|
||||
for session_id in list(self.sessions.keys()):
|
||||
self.close_session(session_id)
|
||||
|
||||
if self.container:
|
||||
try:
|
||||
self.container.stop()
|
||||
self.container.remove()
|
||||
except DockerException:
|
||||
pass # Silently handle cleanup errors
|
||||
self.container = None
|
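# --- Illustrative sketch (not part of the executor) ---------------------------
# How Executor.execute() recovers an exit code: after the real command it sends
# `echo ---CMD_DONE---$?` and matches the marker in the captured buffer. The
# buffer and regex below are simplified stand-ins for demonstration only.
import re

_marker = "---CMD_DONE---"
_captured = "hello world\n---CMD_DONE---0\n"  # hypothetical shell output
_match = re.search(_marker + r"(\d+)", _captured)
_exit_code = int(_match.group(1)) if _match else -1
assert _exit_code == 0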
||||
57
src/tools/run.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2023 Anthropic
|
||||
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
|
||||
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
|
||||
#
|
||||
# Original file was released under MIT License, with the full license text
|
||||
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
|
||||
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
|
||||
#
|
||||
# This modified file is released under the same license.
|
||||
|
||||
|
||||
"""Utility to run shell commands asynchronously with a timeout."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
|
||||
TRUNCATED_MESSAGE: str = (
|
||||
"<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||
)
|
||||
MAX_RESPONSE_LEN: int = 16000
|
||||
|
||||
|
||||
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
|
||||
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||
return (
|
||||
content
|
||||
if not truncate_after or len(content) <= truncate_after
|
||||
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
async def run(
|
||||
cmd: str,
|
||||
timeout: float | None = 120.0, # seconds
|
||||
truncate_after: int | None = MAX_RESPONSE_LEN,
|
||||
):
|
||||
"""Run a shell command asynchronously with a timeout."""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
return (
|
||||
process.returncode or 0,
|
||||
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
|
||||
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
process.kill()
|
||||
raise TimeoutError(
|
||||
f"Command '{cmd}' timed out after {timeout} seconds"
|
||||
) from exc
|
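# --- Illustrative usage (not part of the module) ------------------------------
# A small sketch of calling run(); the command string is an example only.
import asyncio

async def _demo() -> None:
    return_code, stdout, stderr = await run("echo hello", timeout=10.0)
    print(return_code, stdout.strip(), stderr)

# asyncio.run(_demo())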
||||
404
src/tools/search_tool.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""Search tool for finding files based on text content using ripgrep (rg)."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
import shlex
|
||||
from typing import override
|
||||
from traceback import format_exc
|
||||
|
||||
from src.tools.base import (
|
||||
Tool,
|
||||
ToolCallArguments,
|
||||
ToolError,
|
||||
ToolExecResult,
|
||||
ToolParameter,
|
||||
SEARCH_TOOL_NAME,
|
||||
)
|
||||
from src.tools.run import run
|
||||
from src.tools.executor import Executor
|
||||
from src.managers.log.logger import Logger
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class SearchTool(Tool):
|
||||
"""Tool for searching files based on text content using ripgrep."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: str | None = None,
|
||||
executor: Executor | None = None,
|
||||
logger: Logger | None = None,
|
||||
config: Dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(model_provider, logger, config)
|
||||
self._executor = executor
|
||||
|
||||
@override
|
||||
def get_model_provider(self) -> str | None:
|
||||
return self._model_provider
|
||||
|
||||
@override
|
||||
def get_name(self) -> str:
|
||||
return SEARCH_TOOL_NAME
|
||||
|
||||
@override
|
||||
def get_description(self) -> str:
|
||||
return """Search tool for finding files based on text content
|
||||
* Searches for text patterns in files and directories recursively
|
||||
* Returns file paths, line numbers, and surrounding context
|
||||
* Supports regex patterns and various search options
|
||||
* Provides fast and efficient content searching
|
||||
|
||||
Features:
|
||||
- Pattern matching with full regular expression support
|
||||
- Line number display for all matches
|
||||
- Configurable context lines surrounding each match (before and after).
|
||||
- Filtering by file type.
|
||||
- Control over case sensitivity
|
||||
- Option to include hidden files in searches
|
||||
- Handling of binary files
|
||||
|
||||
Example patterns (all patterns must be valid regular expressions):
|
||||
- Simple text: "function main"
|
||||
- Regex: "def\\s+\\w+\\s*\\("
|
||||
"""
|
||||
|
||||
@override
|
||||
def get_parameters(self) -> list[ToolParameter]:
|
||||
"""Get the parameters for the search tool."""
|
||||
params = [
|
||||
ToolParameter(
|
||||
name="pattern",
|
||||
type="string",
|
||||
description=(
|
||||
"The regular expression pattern to search for within the file content. "
|
||||
"To match literal characters that are also regex metacharacters (e.g., '.', '*', '+', '?', '(', ')', '[', ']', '{', '}', '|', '^', '$', '\\'), "
|
||||
"they must be escaped with a backslash. "
|
||||
"Examples: To find the literal string '(some_value)': '\\(some_value\\)'; To find Python function definitions: 'def\\s+[a-zA-Z_]\\w*\\s*\\('. "
|
||||
),
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="search_path",
|
||||
type="string",
|
||||
description="The directory or file path to search in. Must be an absolute path.",
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="context_lines",
|
||||
type="integer",
|
||||
description="Number of context lines to show before and after each match. Default: 2.",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
type="boolean",
|
||||
name="case_insensitive",
|
||||
description="Whether to perform case-insensitive search. Default: false.",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
type="boolean",
|
||||
name="include_hidden",
|
||||
description="Whether to include hidden files and directories. Default: false.",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
type="boolean",
|
||||
name="include_binary",
|
||||
description="Whether to search in binary files. Default: false.",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
type="string",
|
||||
name="file_types",
|
||||
description="Comma-separated list of file types to search (e.g., 'py,js,md'). Optional.",
|
||||
required=False,
|
||||
),
|
||||
ToolParameter(
|
||||
type="integer",
|
||||
name="max_results",
|
||||
description="Maximum number of results to return per file. Default: 100.",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
return params
|
||||
|
||||
@override
|
||||
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
|
||||
"""Execute the search operation."""
|
||||
try:
|
||||
pattern = str(arguments.get("pattern", ""))
|
||||
if not pattern:
|
||||
return ToolExecResult(
|
||||
error="Pattern parameter is required", error_code=-1
|
||||
)
|
||||
|
||||
search_path_str = str(arguments.get("search_path", ""))
|
||||
if not search_path_str:
|
||||
return ToolExecResult(
|
||||
error="search_path parameter is required", error_code=-1
|
||||
)
|
||||
|
||||
search_path = Path(search_path_str)
|
||||
if not search_path.is_absolute():
|
||||
return ToolExecResult(
|
||||
error=f"Search path must be absolute: {search_path}", error_code=-1
|
||||
)
|
||||
|
||||
if not search_path.exists():
|
||||
return ToolExecResult(
|
||||
error=f"Search path does not exist: {search_path}", error_code=-1
|
||||
)
|
||||
|
||||
# Parse optional parameters
|
||||
context_lines = int(arguments.get("context_lines", 2))
|
||||
case_insensitive = bool(arguments.get("case_insensitive", False))
|
||||
include_hidden = bool(arguments.get("include_hidden", False))
|
||||
include_binary = bool(arguments.get("include_binary", False))
|
||||
file_types = arguments.get("file_types")
|
||||
max_results = int(arguments.get("max_results", 100))
|
||||
|
||||
# Build ripgrep command
|
||||
cmd_parts = ["rg"]
|
||||
|
||||
# Add context lines
|
||||
if context_lines > 0:
|
||||
cmd_parts.extend(["-C", str(context_lines)])
|
||||
|
||||
# Add case sensitivity
|
||||
if case_insensitive:
|
||||
cmd_parts.append("-i")
|
||||
|
||||
# Add hidden files
|
||||
if include_hidden:
|
||||
cmd_parts.append("--hidden")
|
||||
|
||||
# Add binary files
|
||||
if include_binary:
|
||||
cmd_parts.append("--binary")
|
||||
else:
|
||||
cmd_parts.append("--no-binary")
|
||||
|
||||
# Add file types
|
||||
if file_types and isinstance(file_types, str):
|
||||
for file_type in file_types.split(","):
|
||||
file_type = file_type.strip()
|
||||
if file_type:
|
||||
cmd_parts.extend(["-g", f'"*.{file_type}"'])
|
||||
|
||||
# Add line numbers and filename
|
||||
cmd_parts.extend(["-n", "-H"])
|
||||
|
||||
# Add max results
|
||||
cmd_parts.extend(["-m", str(max_results)])
|
||||
|
||||
# Add pattern and search path (quote pattern to handle spaces)
|
||||
cmd_parts.extend([f'"{pattern}"', str(search_path)])
|
||||
|
||||
# Execute the command
|
||||
return_code, stdout, stderr = await run(" ".join(cmd_parts))
|
||||
|
||||
if return_code == 0:
|
||||
# Parse and format results
|
||||
results = self._parse_rg_output(stdout)
|
||||
formatted_output = self._format_results(results, max_results)
|
||||
return ToolExecResult(output=formatted_output)
|
||||
elif return_code == 1:
|
||||
# No matches found
|
||||
return ToolExecResult(output=f"No matches found for pattern: {pattern}")
|
||||
else:
|
||||
# Error occurred
|
||||
error_msg = (
|
||||
stderr if stderr else f"ripgrep exited with code {return_code}"
|
||||
)
|
||||
return ToolExecResult(error=error_msg, error_code=return_code)
|
||||
|
||||
except Exception as e:
|
||||
return ToolExecResult(
|
||||
error=f"Search tool error: {str(e)}",
|
||||
error_code=-1,
|
||||
)
|
||||
|
||||
def container_search(
|
||||
self, arguments: ToolCallArguments, session_id: str = "0"
|
||||
) -> ToolExecResult:
|
||||
if not self._executor:
|
||||
return ToolExecResult(
|
||||
error="No executor provided for container search", error_code=-1
|
||||
)
|
||||
|
||||
try:
|
||||
pattern = str(arguments.get("pattern", ""))
|
||||
if not pattern:
|
||||
return ToolExecResult(
|
||||
error="Pattern parameter is required", error_code=-1
|
||||
)
|
||||
|
||||
search_path_str = str(arguments.get("search_path", ""))
|
||||
if not search_path_str:
|
||||
return ToolExecResult(
|
||||
error="search_path parameter is required", error_code=-1
|
||||
)
|
||||
|
||||
context_lines = int(arguments.get("context_lines", 2))
|
||||
case_insensitive = bool(arguments.get("case_insensitive", False))
|
||||
include_hidden = bool(arguments.get("include_hidden", False))
|
||||
include_binary = bool(arguments.get("include_binary", False))
|
||||
file_types = arguments.get("file_types")
|
||||
max_results = int(arguments.get("max_results", 100))
|
||||
|
||||
cmd_parts = ["rg"]
|
||||
|
||||
if context_lines > 0:
|
||||
cmd_parts.extend(["-C", str(context_lines)])
|
||||
|
||||
if case_insensitive:
|
||||
cmd_parts.append("-i")
|
||||
|
||||
if include_hidden:
|
||||
cmd_parts.append("--hidden")
|
||||
|
||||
if include_binary:
|
||||
cmd_parts.append("--binary")
|
||||
else:
|
||||
cmd_parts.append("--no-binary")
|
||||
|
||||
if file_types and isinstance(file_types, str):
|
||||
for file_type in file_types.split(","):
|
||||
file_type = file_type.strip()
|
||||
if file_type:
|
||||
cmd_parts.extend(["-g", f'"*.{file_type}"'])
|
||||
|
||||
cmd_parts.extend(["-n", "-H"])
|
||||
cmd_parts.extend(["-m", str(max_results * 2)])
|
||||
cmd_parts.extend(["--color=never", "-U"])
|
||||
cmd_parts.extend(["--", shlex.quote(pattern), search_path_str])
|
||||
command = " ".join(cmd_parts)
|
||||
|
||||
return_code, output = self._executor.execute_once(command)
|
||||
if self.logger:
|
||||
self.logger.debug(f"search_tool cmd: {command}")
|
||||
# self.logger.debug(f"DEBUG: SearchTool result - Return code: {return_code}, Output: \n{output}")
|
||||
if return_code == 0:
|
||||
results = self._parse_rg_output(output)
|
||||
# self.logger.debug(f"DEBUG: SearchTool _parse_rg_output results: {results}")
|
||||
formatted_output = self._format_results(results, max_results)
|
||||
# self.logger.debug(f"DEBUG: SearchTool _format_results formatted_output: {formatted_output}")
|
||||
return ToolExecResult(output=formatted_output)
|
||||
elif return_code == 1:
|
||||
return ToolExecResult(output=f"No matches found for pattern: {pattern}")
|
||||
else:
|
||||
return ToolExecResult(
|
||||
error=f"ripgrep exited with code {return_code}. Output: {output}",
|
||||
error_code=return_code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ToolExecResult(
|
||||
error=f"Container search error: {str(e)}",
|
||||
error_code=-1,
|
||||
)
|
||||
|
||||
def _parse_rg_output(self, output: str) -> list[dict]:
|
||||
"""Parse ripgrep output into structured results."""
|
||||
|
||||
|
||||
# Remove ANSI escape codes
|
||||
ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
|
||||
clean_output = ansi_escape.sub("", output)
|
||||
|
||||
results = []
|
||||
current_file = None
|
||||
|
||||
for line in clean_output.split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Check if this is a file path line (no colon, just a path)
|
||||
if ":" not in line and "/" in line and not line.strip().startswith("-"):
|
||||
# This is a file path line
|
||||
current_file = line.strip()
|
||||
continue
|
||||
|
||||
# Parse ripgrep output format: file:line:content or file:line-content
|
||||
if ":" in line:
|
||||
# Split by colon to get file, line info, and content
|
||||
parts = line.split(":", 2)
|
||||
if len(parts) >= 3:
|
||||
file_path = parts[0].strip()
|
||||
line_info = parts[1].strip()
|
||||
content = parts[2].strip()
|
||||
|
||||
# Use current_file if file_path is empty or just a dash
|
||||
if not file_path or file_path == "-":
|
||||
file_path = current_file
|
||||
|
||||
# Check if line_info is a number (match line) or contains dash (context line)
|
||||
if line_info.isdigit():
|
||||
# This is a match line
|
||||
line_num = int(line_info)
|
||||
results.append(
|
||||
{
|
||||
"file": file_path,
|
||||
"line": line_num,
|
||||
"content": content,
|
||||
"full_line": line,
|
||||
"is_match": True,
|
||||
}
|
||||
)
|
||||
elif "-" in line_info:
|
||||
# This is a context line (before/after match)
|
||||
# Extract line number from context line format like "12-15" or "12-"
|
||||
try:
|
||||
line_num = int(line_info.split("-")[0])
|
||||
results.append(
|
||||
{
|
||||
"file": file_path,
|
||||
"line": line_num,
|
||||
"content": content,
|
||||
"full_line": line,
|
||||
"is_match": False,
|
||||
}
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _format_results(self, results: list[dict], max_results: int) -> str:
|
||||
"""Format search results for display."""
|
||||
if not results:
|
||||
return "No matches found."
|
||||
|
||||
# Filter only match lines for counting
|
||||
match_results = [r for r in results if r.get("is_match", True)]
|
||||
limited_results = results[:max_results]
|
||||
|
||||
output_lines = [f"Found {len(match_results)} matches:"]
|
||||
output_lines.append("=" * 50)
|
||||
|
||||
current_file = None
|
||||
for result in limited_results:
|
||||
file_path = result["file"]
|
||||
line_num = result["line"]
|
||||
content = result["content"]
|
||||
is_match = result.get("is_match", True)
|
||||
|
||||
# Add file header if this is a new file
|
||||
if current_file != file_path:
|
||||
current_file = file_path
|
||||
output_lines.append(f"\n📁 {file_path}")
|
||||
output_lines.append("-" * (len(file_path) + 4))
|
||||
|
||||
# Add line with appropriate prefix
|
||||
prefix = " " if is_match else " " # Match lines get no special prefix
|
||||
marker = "▶" if is_match else " " # Mark actual matches
|
||||
output_lines.append(f"{marker} {line_num:4d}: {content}")
|
||||
|
||||
if len(results) > max_results:
|
||||
output_lines.append(f"\n... and {len(results) - max_results} more lines")
|
||||
|
||||
return "\n".join(output_lines)
|
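# --- Illustrative sketch (not part of the tool) -------------------------------
# The kind of ripgrep command line execute() assembles for a case-insensitive,
# Python-only search with two context lines; the pattern and path are examples only.
_cmd_parts = [
    "rg", "-C", "2", "-i",
    "-g", '"*.py"',
    "-n", "-H",
    "-m", "100",
    '"def\\s+main"', "/workspace/project",
]
_command = " ".join(_cmd_parts)
# -> rg -C 2 -i -g "*.py" -n -H -m 100 "def\s+main" /workspace/project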
||||
127
src/tools/submit_result_tool.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Search tool for finding files based on text content using ripgrep (rg)."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from logging import Logger
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, override
|
||||
from traceback import format_exc
|
||||
|
||||
from src.tools.base import (
|
||||
Tool,
|
||||
ToolCallArguments,
|
||||
ToolError,
|
||||
ToolExecResult,
|
||||
ToolParameter,
|
||||
SubmitToolResult,
|
||||
SUBMIT_RESULT_TOOL_NAME,
|
||||
)
|
||||
from src.tools.run import run
|
||||
from src.tools.executor import Executor
|
||||
|
||||
|
||||
class SubmitResultTool(Tool):
|
||||
"""Tool for git diff, not for model to invoke"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: str | None = None,
|
||||
executor: Executor | None = None,
|
||||
logger: Logger | None = None,
|
||||
config: Dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(model_provider, logger, config)
|
||||
self._executor = executor
|
||||
|
||||
@override
|
||||
def get_model_provider(self) -> str | None:
|
||||
return self._model_provider
|
||||
|
||||
@override
|
||||
def get_name(self) -> str:
|
||||
return SUBMIT_RESULT_TOOL_NAME
|
||||
|
||||
@override
|
||||
def get_description(self) -> str:
|
||||
return """
|
||||
Submit the final result to complete the task.
|
||||
|
||||
This tool should be called when you are confident that the issue has been resolved. Simply indicate that you are ready to submit the result - the system will automatically capture the git diff and generate the final patch.
|
||||
|
||||
You don't need to provide the actual patch content manually. Just call this tool to signal completion, and the system will handle the rest.
|
||||
"""
|
||||
|
||||
@override
|
||||
def get_parameters(self) -> list[ToolParameter]:
|
||||
params = [
|
||||
ToolParameter(
|
||||
name="is_task_done",
|
||||
type="boolean",
|
||||
description="Whether the task is done",
|
||||
required=True,
|
||||
),
|
||||
ToolParameter(
|
||||
name="test_status",
|
||||
type="string",
|
||||
description="The status of test execution after applying the patch",
|
||||
required=True,
|
||||
enum=["passed", "failed", "skipped", "error"],
|
||||
),
|
||||
ToolParameter(
|
||||
name="reasoning",
|
||||
type="string",
|
||||
description="Detailed explanation of the logic behind the patch, including root cause analysis and solution approach",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
||||
return params
|
||||
|
||||
@override
|
||||
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
|
||||
"""Execute the tool locally (not supported for submit_result tool)."""
|
||||
return ToolExecResult(
|
||||
error="SubmitResultTool only supports container execution", error_code=-1
|
||||
)
|
||||
|
||||
@override
|
||||
async def container_execute(self, arguments: ToolCallArguments) -> ToolExecResult:
|
||||
if not self._executor:
|
||||
return ToolExecResult(
|
||||
error="No executor provided for git diff tool", error_code=-1
|
||||
)
|
||||
try:
|
||||
is_task_done = arguments.get("is_task_done", False)
|
||||
test_status = arguments.get("test_status", "error")
|
||||
reasoning = arguments.get("reasoning", "")
|
||||
root_path = self.config.get("builder", {}).get("repo_root_path", "/")
|
||||
cmd_parts = ["cd", str(root_path), "&&", "git", "--no-pager", "diff"]
|
||||
command = " ".join(cmd_parts)
|
||||
self.logger.debug(
|
||||
f"DEBUG: GitDiffTool executing command: {command}"
|
||||
) # Debug output
|
||||
return_code, output = self._executor.execute_once(command)
|
||||
self.logger.debug(
|
||||
f"DEBUG: GitDiffTool result - Return code: {return_code}, Output: \n{output}"
|
||||
) # Debug output
|
||||
|
||||
if return_code == 0:
|
||||
submit_result = SubmitToolResult(
|
||||
return_code=return_code,
|
||||
output=output,
|
||||
is_task_done=is_task_done,
|
||||
test_status=test_status,
|
||||
reasoning=reasoning,
|
||||
)
|
||||
return ToolExecResult(output=str(submit_result))
|
||||
else:
|
||||
return ToolExecResult(
|
||||
error=f"GitDiffTool exited with code {return_code}. Output: {output}",
|
||||
error_code=return_code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ToolExecResult(
|
||||
error=f"Container search error: {str(e)}", error_code=-1
|
||||
)
|
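# --- Illustrative example (not part of the tool) ------------------------------
# The argument payload an agent might send when invoking the submit_result tool;
# field names come from get_parameters above, the values are examples only.
_example_arguments = {
    "is_task_done": True,
    "test_status": "passed",
    "reasoning": "Root cause identified and patched; the affected tests now pass.",
}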
||||