Add files via upload

This commit is contained in:
公明
2025-11-08 20:32:50 +08:00
committed by GitHub
parent add33e1cf7
commit 2bba007295
18 changed files with 1372 additions and 299 deletions
+45
View File
@@ -0,0 +1,45 @@
package main
import (
"cyberstrike-ai/internal/config"
"fmt"
"os"
)
func main() {
cfg, err := config.Load("config.yaml")
if err != nil {
fmt.Printf("❌ 加载配置失败: %v\n", err)
os.Exit(1)
}
fmt.Printf("✅ 配置加载成功\n")
fmt.Printf(" 工具目录: %s\n", cfg.Security.ToolsDir)
fmt.Printf(" 工具数量: %d\n", len(cfg.Security.Tools))
if len(cfg.Security.Tools) > 0 {
fmt.Printf("\n 已加载的工具:\n")
for _, tool := range cfg.Security.Tools {
status := "❌ 禁用"
if tool.Enabled {
status = "✅ 启用"
}
shortDesc := tool.ShortDescription
if shortDesc == "" {
shortDesc = "(无简短描述,将自动提取)"
}
fmt.Printf(" %s %s\n", status, tool.Name)
fmt.Printf(" 简短描述: %s\n", shortDesc)
if len(tool.Description) > 100 {
fmt.Printf(" 详细描述: %s...\n", tool.Description[:100])
} else {
fmt.Printf(" 详细描述: %s\n", tool.Description)
}
fmt.Printf(" 参数数量: %d\n", len(tool.Parameters))
fmt.Println()
}
} else {
fmt.Printf(" ⚠️ 未加载任何工具\n")
}
}
+8 -151
View File
@@ -20,155 +20,12 @@ database:
path: "data/conversations.db"
security:
tools:
# 示例1: 使用参数定义的工具(推荐方式)
- name: "nmap"
command: "nmap"
args: ["-sT", "-sV", "-sC"] # 固定参数
description: "网络扫描工具,用于发现网络主机和服务"
enabled: true
parameters:
- name: "target"
type: "string"
description: "目标IP地址或域名"
required: true
position: 0 # 位置参数,放在最后
format: "positional"
- name: "ports"
type: "string"
description: "端口范围,例如: 1-1000, 80,443,8080"
required: false
flag: "-p"
format: "flag"
# 示例2: 标志参数工具
- name: "sqlmap"
command: "sqlmap"
description: "SQL注入检测和利用工具"
enabled: true
parameters:
- name: "url"
type: "string"
description: "目标URL,例如: http://example.com/page?id=1"
required: true
flag: "-u"
format: "flag"
- name: "batch"
type: "bool"
description: "非交互模式"
required: false
default: true
flag: "--batch"
format: "flag"
- name: "level"
type: "int"
description: "测试级别 (1-5)"
required: false
default: 3
flag: "--level"
format: "combined" # --level=3
# 示例3: 位置参数工具
- name: "nikto"
command: "nikto"
description: "Web服务器扫描工具"
enabled: true
parameters:
- name: "target"
type: "string"
description: "目标URL或IP地址"
required: true
flag: "-h"
format: "flag"
# 示例4: 简单位置参数
- name: "dirb"
command: "dirb"
description: "Web目录扫描工具"
enabled: true
parameters:
- name: "url"
type: "string"
description: "目标URL"
required: true
position: 0
format: "positional"
- name: "wordlist"
type: "string"
description: "字典文件路径"
required: false
flag: "-w"
format: "flag"
# 示例5: 执行系统命令
- name: "exec"
command: "sh"
args: ["-c"]
description: "执行系统命令(谨慎使用)"
enabled: true
parameters:
- name: "command"
type: "string"
description: "要执行的系统命令"
required: true
position: 0
format: "positional"
- name: "shell"
type: "string"
description: "使用的shell(可选,默认为sh"
required: false
default: "sh"
- name: "workdir"
type: "string"
description: "工作目录"
required: false
# 示例6: 自定义工具 - 使用模板格式
- name: "custom_scanner"
command: "my-scanner"
description: "自定义扫描工具示例"
enabled: false # 默认禁用,需要时启用
parameters:
- name: "target"
type: "string"
description: "扫描目标"
required: true
flag: "--target"
format: "flag"
- name: "mode"
type: "string"
description: "扫描模式"
required: false
default: "normal"
options: ["normal", "aggressive", "stealth"] # 枚举值
flag: "--mode"
format: "combined" # --mode=normal
- name: "threads"
type: "int"
description: "线程数"
required: false
default: 10
flag: "-t"
format: "flag"
- name: "output"
type: "string"
description: "输出文件路径"
required: false
flag: "-o"
format: "template"
template: "-o {value}" # 自定义模板
- name: "verbose"
type: "bool"
description: "详细输出"
required: false
default: false
flag: "-v"
format: "flag" # 布尔值:如果为true,只添加-v,不添加值
# 示例7: 向后兼容 - 不定义parameters,使用旧的硬编码逻辑
# - name: "legacy_tool"
# command: "legacy"
# args: ["--option"]
# description: "旧工具(使用硬编码逻辑)"
# enabled: false
# 工具配置文件目录(推荐方式)
# 系统会自动加载 tools/ 目录下的所有 .yaml 和 .yml 文件
# 每个工具一个配置文件,便于维护和管理
tools_dir: "tools"
# 向后兼容:也可以在主配置文件中直接定义工具
# 如果 tools_dir 和 tools 都配置了,tools_dir 中的工具优先
# tools: []
+222 -110
View File
@@ -193,12 +193,40 @@ type AgentLoopResult struct {
MCPExecutionIDs []string
}
// ProgressCallback 进度回调函数类型
type ProgressCallback func(eventType, message string, data interface{})
// AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, nil)
}
// AgentLoopWithProgress 执行Agent循环(带进度回调)
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, callback ProgressCallback) (*AgentLoopResult, error) {
// 发送进度更新
sendProgress := func(eventType, message string, data interface{}) {
if callback != nil {
callback(eventType, message, data)
}
}
// 系统提示词,指导AI如何处理工具错误
systemPrompt := `你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。
重要:当工具调用失败时,请遵循以下原则:
1. 仔细分析错误信息,理解失败的具体原因
2. 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标
3. 如果参数错误,根据错误提示修正参数后重试
4. 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析
5. 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作
6. 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。`
messages := []ChatMessage{
{
Role: "system",
Content: "你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。当需要执行工具时,使用提供的工具函数。",
Content: systemPrompt,
},
}
@@ -248,6 +276,13 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages
// 获取可用工具
tools := a.getAvailableTools()
// 发送进度更新
if i == 0 {
sendProgress("progress", "正在分析请求并制定测试策略...", nil)
} else {
sendProgress("progress", fmt.Sprintf("正在继续分析(第 %d 轮迭代)...", i+1), nil)
}
// 记录每次调用OpenAI
if i == 0 {
a.logger.Info("调用OpenAI",
@@ -277,6 +312,7 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages
}
// 调用OpenAI
sendProgress("progress", "正在调用AI模型...", nil)
response, err := a.callOpenAI(ctx, messages, tools)
if err != nil {
result.Response = ""
@@ -304,17 +340,46 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages
ToolCalls: choice.Message.ToolCalls,
})
// 发送工具调用进度
sendProgress("progress", fmt.Sprintf("检测到 %d 个工具调用,开始执行...", len(choice.Message.ToolCalls)), nil)
// 执行所有工具调用
for _, toolCall := range choice.Message.ToolCalls {
for idx, toolCall := range choice.Message.ToolCalls {
// 发送工具调用开始事件
toolArgsJSON, _ := json.Marshal(toolCall.Function.Arguments)
sendProgress("tool_call", fmt.Sprintf("正在调用工具: %s", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"arguments": string(toolArgsJSON),
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
})
// 执行工具
execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil {
// 构建详细的错误信息,帮助AI理解问题并做出决策
errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err)
messages = append(messages, ChatMessage{
Role: "tool",
ToolCallID: toolCall.ID,
Content: fmt.Sprintf("工具执行失败: %v", err),
Content: errorMsg,
})
// 发送工具执行失败事件
sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"success": false,
"error": err.Error(),
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
})
a.logger.Warn("工具执行失败,已返回详细错误信息",
zap.String("tool", toolCall.Function.Name),
zap.Error(err),
)
} else {
// 即使工具返回了错误结果(IsError=true),也继续处理,让AI决定下一步
messages = append(messages, ChatMessage{
Role: "tool",
ToolCallID: toolCall.ID,
@@ -324,6 +389,29 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages
if execResult.ExecutionID != "" {
result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID)
}
// 发送工具执行成功事件
resultPreview := execResult.Result
if len(resultPreview) > 200 {
resultPreview = resultPreview[:200] + "..."
}
sendProgress("tool_result", fmt.Sprintf("工具 %s 执行完成", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"success": !execResult.IsError,
"isError": execResult.IsError,
"result": resultPreview,
"executionId": execResult.ExecutionID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
})
// 如果工具返回了错误,记录日志但不中断流程
if execResult.IsError {
a.logger.Warn("工具返回错误结果,但继续处理",
zap.String("tool", toolCall.Function.Name),
zap.String("result", execResult.Result),
)
}
}
}
continue
@@ -337,6 +425,7 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages
// 如果完成,返回结果
if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content
return result, nil
}
@@ -347,117 +436,98 @@ func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages
}
// getAvailableTools 获取可用工具
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗
func (a *Agent) getAvailableTools() []Tool {
// 从MCP服务器获取工具列表
executions := a.mcpServer.GetAllExecutions()
toolNames := make(map[string]bool)
for _, exec := range executions {
toolNames[exec.ToolName] = true
// 从MCP服务器获取所有已注册的工具
mcpTools := a.mcpServer.GetAllTools()
// 转换为OpenAI格式的工具定义
tools := make([]Tool, 0, len(mcpTools))
for _, mcpTool := range mcpTools {
// 使用简短描述(如果存在),否则使用详细描述
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
}
// 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema)
tools = append(tools, Tool{
Type: "function",
Function: FunctionDefinition{
Name: mcpTool.Name,
Description: description, // 使用简短描述减少token消耗
Parameters: convertedSchema,
},
})
}
tools := []Tool{
{
Type: "function",
Function: FunctionDefinition{
Name: "nmap",
Description: "使用nmap进行网络扫描,发现开放端口和服务。支持IP地址、域名或URL(会自动提取域名)。使用TCP连接扫描,不需要root权限。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"target": map[string]interface{}{
"type": "string",
"description": "目标IP地址、域名或URL(如 https://example.com)。如果是URL,会自动提取域名部分。",
},
"ports": map[string]interface{}{
"type": "string",
"description": "要扫描的端口范围,例如: 1-1000 或 80,443,8080。如果不指定,将扫描常用端口。",
},
},
"required": []string{"target"},
},
},
},
{
Type: "function",
Function: FunctionDefinition{
Name: "sqlmap",
Description: "使用sqlmap检测SQL注入漏洞",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"url": map[string]interface{}{
"type": "string",
"description": "目标URL",
},
},
"required": []string{"url"},
},
},
},
{
Type: "function",
Function: FunctionDefinition{
Name: "nikto",
Description: "使用nikto扫描Web服务器漏洞",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"target": map[string]interface{}{
"type": "string",
"description": "目标URL",
},
},
"required": []string{"target"},
},
},
},
{
Type: "function",
Function: FunctionDefinition{
Name: "dirb",
Description: "使用dirb进行目录扫描",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"url": map[string]interface{}{
"type": "string",
"description": "目标URL",
},
},
"required": []string{"url"},
},
},
},
{
Type: "function",
Function: FunctionDefinition{
Name: "exec",
Description: "执行系统命令(谨慎使用,仅用于必要的系统操作)",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"command": map[string]interface{}{
"type": "string",
"description": "要执行的系统命令",
},
"shell": map[string]interface{}{
"type": "string",
"description": "使用的shell(可选,默认为sh",
},
"workdir": map[string]interface{}{
"type": "string",
"description": "工作目录(可选)",
},
},
"required": []string{"command"},
},
},
},
}
a.logger.Debug("获取可用工具列表",
zap.Int("count", len(tools)),
)
return tools
}
// convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型
func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} {
if schema == nil {
return schema
}
// 创建新的schema副本
converted := make(map[string]interface{})
for k, v := range schema {
converted[k] = v
}
// 转换properties中的类型
if properties, ok := converted["properties"].(map[string]interface{}); ok {
convertedProperties := make(map[string]interface{})
for propName, propValue := range properties {
if prop, ok := propValue.(map[string]interface{}); ok {
convertedProp := make(map[string]interface{})
for pk, pv := range prop {
if pk == "type" {
// 转换类型
if typeStr, ok := pv.(string); ok {
convertedProp[pk] = a.convertToOpenAIType(typeStr)
} else {
convertedProp[pk] = pv
}
} else {
convertedProp[pk] = pv
}
}
convertedProperties[propName] = convertedProp
} else {
convertedProperties[propName] = propValue
}
}
converted["properties"] = convertedProperties
}
return converted
}
// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型
func (a *Agent) convertToOpenAIType(configType string) string {
switch configType {
case "bool":
return "boolean"
case "int", "integer":
return "number"
case "float", "double":
return "number"
case "string", "array", "object":
return configType
default:
// 默认返回原类型
return configType
}
}
// callOpenAI 调用OpenAI API
func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) {
reqBody := OpenAIRequest{
@@ -546,9 +616,11 @@ func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) {
type ToolExecutionResult struct {
Result string
ExecutionID string
IsError bool // 标记是否为错误结果
}
// executeToolViaMCP 通过MCP执行工具
// 即使工具执行失败,也返回结果而不是错误,让AI能够处理错误情况
func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map[string]interface{}) (*ToolExecutionResult, error) {
a.logger.Info("通过MCP执行工具",
zap.String("tool", toolName),
@@ -557,8 +629,30 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
// 通过MCP服务器调用工具
result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args)
// 如果调用失败(如工具不存在),返回友好的错误信息而不是抛出异常
if err != nil {
return nil, fmt.Errorf("工具执行失败: %w", err)
errorMsg := fmt.Sprintf(`工具调用失败
工具名称: %s
错误类型: 系统错误
错误详情: %v
可能的原因:
- 工具 "%s" 不存在或未启用
- 系统配置问题
- 网络或权限问题
建议:
- 检查工具名称是否正确
- 尝试使用其他替代工具
- 如果这是必需的工具,请向用户说明情况`, toolName, err, toolName)
return &ToolExecutionResult{
Result: errorMsg,
ExecutionID: executionID,
IsError: true,
}, nil // 返回 nil 错误,让调用者处理结果
}
// 格式化结果
@@ -571,6 +665,24 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
return &ToolExecutionResult{
Result: resultText.String(),
ExecutionID: executionID,
IsError: result != nil && result.IsError,
}, nil
}
// formatToolError 格式化工具错误信息,提供更友好的错误描述
func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string {
errorMsg := fmt.Sprintf(`工具执行失败
工具名称: %s
调用参数: %v
错误信息: %v
请分析错误原因并采取以下行动之一:
1. 如果参数错误,请修正参数后重试
2. 如果工具不可用,请尝试使用替代工具
3. 如果这是系统问题,请向用户说明情况并提供建议
4. 如果错误信息中包含有用信息,可以基于这些信息继续分析`, toolName, args, err)
return errorMsg
}
+2
View File
@@ -115,6 +115,8 @@ func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitor
{
// Agent Loop
api.POST("/agent-loop", agentHandler.AgentLoop)
// Agent Loop 流式输出
api.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
// 对话历史
api.POST("/conversations", conversationHandler.CreateConversation)
+107 -9
View File
@@ -3,6 +3,8 @@ package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
)
@@ -39,7 +41,8 @@ type OpenAIConfig struct {
}
type SecurityConfig struct {
Tools []ToolConfig `yaml:"tools"`
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
}
type DatabaseConfig struct {
@@ -47,13 +50,14 @@ type DatabaseConfig struct {
}
type ToolConfig struct {
Name string `yaml:"name"`
Command string `yaml:"command"`
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
Description string `yaml:"description"`
Enabled bool `yaml:"enabled"`
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
Name string `yaml:"name"`
Command string `yaml:"command"`
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗)
Description string `yaml:"description"` // 详细描述(用于工具文档)
Enabled bool `yaml:"enabled"`
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
}
// ParameterConfig 参数配置
@@ -81,9 +85,102 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
// 如果配置了工具目录,从目录加载工具配置
if cfg.Security.ToolsDir != "" {
configDir := filepath.Dir(path)
toolsDir := cfg.Security.ToolsDir
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(toolsDir) {
toolsDir = filepath.Join(configDir, toolsDir)
}
tools, err := LoadToolsFromDir(toolsDir)
if err != nil {
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
}
// 合并工具配置:目录中的工具优先,主配置中的工具作为补充
existingTools := make(map[string]bool)
for _, tool := range tools {
existingTools[tool.Name] = true
}
// 添加主配置中不存在于目录中的工具(向后兼容)
for _, tool := range cfg.Security.Tools {
if !existingTools[tool.Name] {
tools = append(tools, tool)
}
}
cfg.Security.Tools = tools
}
return &cfg, nil
}
// LoadToolsFromDir 从目录加载所有工具配置文件
func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
var tools []ToolConfig
// 检查目录是否存在
if _, err := os.Stat(dir); os.IsNotExist(err) {
return tools, nil // 目录不存在时返回空列表,不报错
}
// 读取目录中的所有 .yaml 和 .yml 文件
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("读取工具目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
continue
}
filePath := filepath.Join(dir, name)
tool, err := LoadToolFromFile(filePath)
if err != nil {
// 记录错误但继续加载其他文件
fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err)
continue
}
tools = append(tools, *tool)
}
return tools, nil
}
// LoadToolFromFile 从单个文件加载工具配置
func LoadToolFromFile(path string) (*ToolConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
var tool ToolConfig
if err := yaml.Unmarshal(data, &tool); err != nil {
return nil, fmt.Errorf("解析工具配置失败: %w", err)
}
// 验证必需字段
if tool.Name == "" {
return nil, fmt.Errorf("工具名称不能为空")
}
if tool.Command == "" {
return nil, fmt.Errorf("工具命令不能为空")
}
return &tool, nil
}
func Default() *Config {
return &Config{
Server: ServerConfig{
@@ -104,7 +201,8 @@ func Default() *Config {
Model: "gpt-4",
},
Security: SecurityConfig{
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 加载,不在此硬编码
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
ToolsDir: "tools", // 默认工具目录
},
Database: DatabaseConfig{
Path: "data/conversations.db",
+119
View File
@@ -1,6 +1,8 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"time"
@@ -132,3 +134,120 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
})
}
// StreamEvent 流式事件
type StreamEvent struct {
Type string `json:"type"` // progress, tool_call, tool_result, response, error, done
Message string `json:"message"` // 显示消息
Data interface{} `json:"data,omitempty"`
}
// AgentLoopStream 处理Agent Loop流式请求
func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
// 对于流式请求,也发送SSE格式的错误
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
event := StreamEvent{
Type: "error",
Message: "请求参数错误: " + err.Error(),
}
eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
c.Writer.Flush()
return
}
h.logger.Info("收到Agent Loop流式请求",
zap.String("message", req.Message),
zap.String("conversationId", req.ConversationID),
)
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no") // 禁用nginx缓冲
// 发送初始事件
sendEvent := func(eventType, message string, data interface{}) {
event := StreamEvent{
Type: eventType,
Message: message,
Data: data,
}
eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
c.Writer.Flush()
}
// 如果没有对话ID,创建新对话
conversationID := req.ConversationID
if conversationID == "" {
title := req.Message
if len(title) > 50 {
title = title[:50] + "..."
}
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
sendEvent("error", "创建对话失败: "+err.Error(), nil)
return
}
conversationID = conv.ID
}
// 获取历史消息
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
h.logger.Warn("获取历史消息失败", zap.Error(err))
historyMessages = []database.Message{}
}
// 将数据库消息转换为Agent消息格式
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
}
// 保存用户消息
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 创建进度回调函数
progressCallback := func(eventType, message string, data interface{}) {
sendEvent(eventType, message, data)
}
// 执行Agent Loop,传入进度回调
sendEvent("progress", "正在分析您的请求...", nil)
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), req.Message, agentHistoryMessages, progressCallback)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
sendEvent("error", "执行失败: "+err.Error(), nil)
sendEvent("done", "", nil)
return
}
// 保存助手回复
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.Error(err))
}
// 发送最终响应
sendEvent("response", result.Response, map[string]interface{}{
"mcpExecutionIds": result.MCPExecutionIDs,
"conversationId": conversationID,
})
sendEvent("done", "", map[string]interface{}{
"conversationId": conversationID,
})
}
+12
View File
@@ -361,6 +361,18 @@ func (s *Server) GetStats() map[string]*ToolStats {
return stats
}
// GetAllTools 获取所有已注册的工具(用于Agent动态获取工具列表)
func (s *Server) GetAllTools() []Tool {
s.mu.RLock()
defer s.mu.RUnlock()
tools := make([]Tool, 0, len(s.toolDefs))
for _, tool := range s.toolDefs {
tools = append(tools, tool)
}
return tools
}
// CallTool 直接调用工具(用于内部调用)
func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
s.mu.RLock()
+4 -3
View File
@@ -36,9 +36,10 @@ type Error struct {
// Tool 表示MCP工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema map[string]interface{} `json:"inputSchema"`
Name string `json:"name"`
Description string `json:"description"` // 详细描述
ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗)
InputSchema map[string]interface{} `json:"inputSchema"`
}
// ToolCall 表示工具调用
+45 -4
View File
@@ -150,10 +150,28 @@ func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
toolName := toolConfig.Name
toolConfigCopy := toolConfig
// 使用简短描述(如果存在),否则使用详细描述的前100个字符
shortDesc := toolConfigCopy.ShortDescription
if shortDesc == "" {
// 如果没有简短描述,从详细描述中提取第一行或前100个字符
desc := toolConfigCopy.Description
if len(desc) > 100 {
// 尝试找到第一个换行符
if idx := strings.Index(desc, "\n"); idx > 0 && idx < 100 {
shortDesc = strings.TrimSpace(desc[:idx])
} else {
shortDesc = desc[:100] + "..."
}
} else {
shortDesc = desc
}
}
tool := mcp.Tool{
Name: toolConfigCopy.Name,
Description: toolConfigCopy.Description,
InputSchema: e.buildInputSchema(&toolConfigCopy),
Name: toolConfigCopy.Name,
Description: toolConfigCopy.Description,
ShortDescription: shortDesc,
InputSchema: e.buildInputSchema(&toolConfigCopy),
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
@@ -543,8 +561,11 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
required := []string{}
for _, param := range toolConfig.Parameters {
// 转换类型为OpenAI/JSON Schema标准类型
openAIType := e.convertToOpenAIType(param.Type)
prop := map[string]interface{}{
"type": param.Type,
"type": openAIType,
"description": param.Description,
}
@@ -622,6 +643,26 @@ func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]in
return schema
}
// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型
func (e *Executor) convertToOpenAIType(configType string) string {
switch configType {
case "bool":
return "boolean"
case "int", "integer":
return "number"
case "float", "double":
return "number"
case "string", "array", "object":
return configType
default:
// 默认返回原类型,但记录警告
e.logger.Warn("未知的参数类型,使用原类型",
zap.String("type", configType),
)
return configType
}
}
// Vulnerability 漏洞信息
type Vulnerability struct {
ID string `json:"id"`
+121
View File
@@ -0,0 +1,121 @@
# 工具配置文件说明
## 概述
每个工具现在都有独立的配置文件,存放在 `tools/` 目录下。这种方式使得工具配置更加清晰、易于维护和管理。
## 配置文件格式
每个工具配置文件是一个 YAML 文件,包含以下字段:
### 必需字段
- `name`: 工具名称(唯一标识符)
- `command`: 要执行的命令
- `enabled`: 是否启用(true/false
### 可选字段
- `args`: 固定参数列表(数组)
- `short_description`: 简短描述(一句话说明工具用途,用于工具列表,减少token消耗)
- `description`: 工具详细描述(支持多行文本,用于工具文档和详细说明)
- `parameters`: 参数定义列表
## 工具描述
### 简短描述 (`short_description`)
- **用途**:用于工具列表,减少发送给大模型的token消耗
- **要求**:一句话(20-50字)说明工具的核心用途
- **示例**`"网络扫描工具,用于发现网络主机、开放端口和服务"`
### 详细描述 (`description`)
支持多行文本,应该包含:
1. **工具功能说明**:工具的主要功能
2. **使用场景**:什么情况下使用这个工具
3. **注意事项**:使用时的注意事项和警告
4. **示例**:使用示例(可选)
**重要说明**
- 工具列表发送给大模型时,使用 `short_description`(如果存在)
- 如果没有 `short_description`,系统会自动从 `description` 中提取第一行或前100个字符
- 详细描述可以通过 MCP 的 `resources/read` 接口获取(URI: `tool://tool_name`
这样可以大幅减少token消耗,特别是当工具数量很多时(如100个工具)。
## 参数定义
每个参数可以包含以下字段:
- `name`: 参数名称
- `type`: 参数类型(string, int, bool, array
- `description`: 参数详细描述(支持多行)
- `required`: 是否必需(true/false
- `default`: 默认值
- `flag`: 命令行标志(如 "-u", "--url"
- `position`: 位置参数的位置(整数)
- `format`: 参数格式("flag", "positional", "combined", "template"
- `template`: 模板字符串(用于 format="template"
- `options`: 可选值列表(用于枚举类型)
### 参数描述要求
参数描述应该包含:
1. **参数用途**:这个参数是做什么的
2. **格式要求**:参数值的格式要求
3. **示例值**:具体的示例值
4. **注意事项**:使用时需要注意的事项
## 示例
参考 `tools/` 目录下的现有工具配置文件:
- `nmap.yaml`: 网络扫描工具
- `sqlmap.yaml`: SQL注入检测工具
- `nikto.yaml`: Web服务器扫描工具
- `dirb.yaml`: Web目录扫描工具
- `exec.yaml`: 系统命令执行工具
## 添加新工具
要添加新工具,只需在 `tools/` 目录下创建一个新的 YAML 文件,例如 `my_tool.yaml`
```yaml
name: "my_tool"
command: "my-command"
enabled: true
short_description: "一句话说明工具用途" # 简短描述(推荐)
description: |
工具详细描述...
**功能:**
- 功能1
- 功能2
**使用场景:**
- 场景1
- 场景2
parameters:
- name: "param1"
type: "string"
description: |
参数详细描述...
**示例值:**
- "value1"
- "value2"
required: true
flag: "-p"
format: "flag"
```
保存文件后,重启服务即可自动加载新工具。
## 禁用工具
要禁用某个工具,只需将配置文件中的 `enabled` 字段设置为 `false`,或者直接删除/重命名配置文件。
+86
View File
@@ -0,0 +1,86 @@
name: "dirb"
command: "dirb"
enabled: true
# 简短描述(用于工具列表,减少token消耗)
short_description: "Web目录和文件扫描工具,通过暴力破解方式发现Web服务器上的隐藏目录和文件"
# 工具详细描述
description: |
Web目录和文件扫描工具,通过暴力破解方式发现Web服务器上的隐藏目录和文件。
**主要功能:**
- 目录和文件发现
- 支持自定义字典文件
- 检测常见的Web目录结构
- 识别备份文件、配置文件等敏感文件
- 支持多种HTTP方法
**使用场景:**
- Web应用目录枚举
- 发现隐藏的管理界面
- 查找备份文件和敏感信息
- 渗透测试中的信息收集
**注意事项:**
- 扫描可能产生大量HTTP请求
- 某些请求可能被WAF拦截
- 建议使用合适的字典文件以提高效率
- 扫描结果需要人工验证
# 参数定义
parameters:
- name: "url"
type: "string"
description: |
目标URL,要扫描的Web服务器地址。
**格式要求:**
- 必须包含协议(http:// 或 https://
- 可以包含基础路径
- 末尾不要带斜杠(除非要扫描特定目录)
**示例值:**
- 基础URL: "http://example.com"
- HTTPS: "https://example.com"
- 带端口: "http://example.com:8080"
- 特定目录: "http://example.com/admin"
- 带路径: "http://example.com/app"
**注意事项:**
- URL必须可访问
- 确保URL格式正确,包含协议前缀
- 必需参数,不能为空
required: true
position: 0
format: "positional"
- name: "wordlist"
type: "string"
description: |
字典文件路径,包含要尝试的目录和文件名列表。
**格式要求:**
- 文件路径,可以是绝对路径或相对路径
- 文件每行一个目录或文件名
- 支持常见的字典文件格式
**示例值:**
- 默认字典: "/usr/share/dirb/wordlists/common.txt"
- 自定义字典: "/path/to/custom-wordlist.txt"
- 常用字典: "/usr/share/wordlists/dirb/common.txt"
**常用字典文件:**
- common.txt: 常见目录和文件
- big.txt: 大型字典
- small.txt: 小型快速字典
- extensions_common.txt: 常见文件扩展名
**注意事项:**
- 如果不指定,将使用默认字典
- 确保字典文件存在且可读
- 大型字典会显著增加扫描时间
required: false
flag: "-w"
format: "flag"
+102
View File
@@ -0,0 +1,102 @@
name: "exec"
command: "sh"
args: ["-c"]
enabled: true
# 简短描述(用于工具列表,减少token消耗)
short_description: "系统命令执行工具,用于执行Shell命令和系统操作(谨慎使用)"
# 工具详细描述
description: |
系统命令执行工具,用于执行Shell命令和系统操作。
**主要功能:**
- 执行任意Shell命令
- 支持bash、sh等shell
- 可以指定工作目录
- 返回命令执行结果
**使用场景:**
- 系统管理和维护
- 自动化脚本执行
- 文件操作和处理
- 系统信息收集
**安全警告:**
- ⚠️ 此工具可以执行任意系统命令,存在安全风险
- ⚠️ 仅应在受控环境中使用
- ⚠️ 所有命令执行都会被记录
- ⚠️ 建议限制可执行的命令范围
- ⚠️ 不要执行不可信的命令
# 参数定义
parameters:
- name: "command"
type: "string"
description: |
要执行的系统命令。可以是任何有效的Shell命令。
**格式要求:**
- 完整的Shell命令
- 可以包含管道、重定向等Shell特性
- 支持环境变量
**示例值:**
- 简单命令: "ls -la"
- 带管道: "ps aux | grep nginx"
- 文件操作: "cat /etc/passwd"
- 网络命令: "curl http://example.com"
- 系统信息: "uname -a"
- 查找文件: "find /var/log -name '*.log'"
**注意事项:**
- 命令会在指定的shell中执行
- 确保命令语法正确
- 注意命令的安全影响
- 必需参数,不能为空
required: true
position: 0
format: "positional"
- name: "shell"
type: "string"
description: |
使用的Shell类型,默认为sh。
**可选值:**
- sh: 标准Shell(默认)
- bash: Bash Shell
- zsh: Z Shell
- 其他系统可用的shell
**示例值:**
- "sh" (默认)
- "bash"
- "zsh"
**注意事项:**
- 确保指定的shell在系统中可用
- 不同shell的命令语法可能略有差异
required: false
default: "sh"
- name: "workdir"
type: "string"
description: |
命令执行的工作目录。如果不指定,使用当前工作目录。
**格式要求:**
- 绝对路径或相对路径
- 目录必须存在
**示例值:**
- "/tmp"
- "/var/log"
- "./data"
- "/home/user/project"
**注意事项:**
- 确保目录存在且有访问权限
- 相对路径相对于程序运行目录
required: false
+58
View File
@@ -0,0 +1,58 @@
name: "nikto"
command: "nikto"
enabled: true
# 简短描述(用于工具列表,减少token消耗)
short_description: "Web服务器扫描工具,用于检测Web服务器和应用程序中的已知漏洞和配置错误"
# 工具详细描述
description: |
Web服务器扫描工具,用于检测Web服务器和应用程序中的已知漏洞、配置错误和潜在安全问题。
**主要功能:**
- 检测Web服务器版本和配置问题
- 识别已知的Web漏洞和CVE
- 检测危险文件和目录
- 检查服务器配置错误
- 识别过时的软件版本
- 检测默认文件和脚本
**使用场景:**
- Web应用安全评估
- 服务器配置审计
- 漏洞扫描和发现
- 渗透测试前期信息收集
**注意事项:**
- 扫描可能产生大量日志,注意日志管理
- 某些扫描可能触发WAF或IDS告警
- 建议在授权范围内使用
- 扫描结果需要人工验证
# 参数定义
parameters:
- name: "target"
type: "string"
description: |
目标URL或IP地址。可以是完整的URL或IP地址。
**格式要求:**
- 可以包含协议(http:// 或 https://
- 可以只提供IP地址或域名
- 如果只提供IP,默认使用http协议
**示例值:**
- 完整URL: "http://example.com"
- HTTPS: "https://example.com"
- IP地址: "192.168.1.1"
- 带端口: "http://example.com:8080"
- 带路径: "http://example.com/admin"
**注意事项:**
- 如果只提供IP,工具会使用http协议
- 建议提供完整URL以确保正确扫描
- 必需参数,不能为空
required: true
flag: "-h"
format: "flag"
+75
View File
@@ -0,0 +1,75 @@
name: "nmap"
command: "nmap"
args: ["-sT", "-sV", "-sC"] # 固定参数:TCP连接扫描、版本检测、默认脚本
enabled: true
# 简短描述(用于工具列表,减少token消耗)- 一句话说明工具用途
short_description: "网络扫描工具,用于发现网络主机、开放端口和服务"
# 工具详细描述 - 帮助大模型理解工具用途和使用场景
description: |
网络映射和端口扫描工具,用于发现网络中的主机、服务和开放端口。
**主要功能:**
- 主机发现:检测网络中的活动主机
- 端口扫描:识别目标主机上开放的端口
- 服务识别:检测运行在端口上的服务类型和版本
- 操作系统检测:识别目标主机的操作系统类型
- 漏洞检测:使用NSE脚本检测常见漏洞
**使用场景:**
- 网络资产发现和枚举
- 安全评估和渗透测试
- 网络故障排查
- 端口和服务审计
**注意事项:**
- 使用 -sT (TCP连接扫描) 而不是 -sS (SYN扫描),因为 -sS 需要root权限
- 扫描速度取决于网络延迟和目标响应
- 某些扫描可能被防火墙或IDS检测到
- 请确保有权限扫描目标网络
# 参数定义
parameters:
- name: "target"
type: "string"
description: |
目标IP地址或域名。可以是单个IP、IP范围、CIDR格式或域名。
**示例值:**
- 单个IP: "192.168.1.1"
- IP范围: "192.168.1.1-100"
- CIDR: "192.168.1.0/24"
- 域名: "example.com"
- URL: "https://example.com" (会自动提取域名部分)
**注意事项:**
- 如果提供URL,会自动提取域名部分
- 确保目标地址格式正确
- 必需参数,不能为空
required: true
position: 0 # 位置参数,放在命令最后
format: "positional"
- name: "ports"
type: "string"
description: |
要扫描的端口范围。可以是单个端口、端口范围、逗号分隔的端口列表,或特殊值。
**示例值:**
- 单个端口: "80"
- 端口范围: "1-1000"
- 多个端口: "80,443,8080,8443"
- 组合: "80,443,8000-9000"
- 常用端口: "1-1024"
- 所有端口: "1-65535"
- 快速扫描: "80,443,22,21,25,53,110,143,993,995"
**注意事项:**
- 如果不指定,将扫描默认的1000个常用端口
- 扫描所有端口(1-65535)会非常耗时
- 建议先扫描常用端口,再根据结果决定是否扫描全部端口
required: false
flag: "-p"
format: "flag"
+100
View File
@@ -0,0 +1,100 @@
name: "sqlmap"
command: "sqlmap"
enabled: true
# 简短描述(用于工具列表,减少token消耗)
short_description: "自动化SQL注入检测和利用工具,用于发现和利用SQL注入漏洞"
# 工具详细描述
description: |
自动化SQL注入检测和利用工具,用于发现和利用SQL注入漏洞。
**主要功能:**
- 自动检测SQL注入漏洞
- 支持多种数据库类型(MySQL, PostgreSQL, Oracle, MSSQL等)
- 自动提取数据库信息(表、列、数据)
- 支持多种注入技术(布尔盲注、时间盲注、联合查询等)
- 支持文件上传/下载、命令执行等高级功能
**使用场景:**
- Web应用安全测试
- SQL注入漏洞检测
- 数据库信息收集
- 渗透测试和漏洞验证
**注意事项:**
- 仅用于授权的安全测试
- 某些操作可能对目标系统造成影响
- 建议在测试环境中先验证
- 使用 --batch 参数避免交互式提示
# 参数定义
parameters:
- name: "url"
type: "string"
description: |
目标URL,包含可能存在SQL注入的参数。
**格式要求:**
- 完整的URL,包含协议(http:// 或 https://
- 必须包含查询参数,参数值用 * 标记注入点
- 或者直接提供完整的URL,sqlmap会自动检测参数
**示例值:**
- 标记注入点: "http://example.com/page?id=1*"
- 完整URL: "http://example.com/page?id=1"
- POST数据: "http://example.com/login" (需要配合data参数)
- Cookie注入: "http://example.com/page" (需要配合cookie参数)
**注意事项:**
- URL必须可访问
- 确保URL格式正确,包含协议前缀
- 如果使用POST请求,需要配合data参数
- 必需参数,不能为空
required: true
flag: "-u"
format: "flag"
- name: "batch"
type: "bool"
description: |
非交互模式,自动选择默认选项,不需要用户输入。
**使用场景:**
- 自动化测试脚本
- 批量扫描
- 避免交互式提示
**注意事项:**
- 建议始终设置为true,避免工具等待用户输入
- 默认值为true
required: false
default: true
flag: "--batch"
format: "flag"
- name: "level"
type: "int"
description: |
测试级别,范围1-5,级别越高测试越全面但耗时越长。
**级别说明:**
- Level 1: 基本测试,快速但可能遗漏漏洞
- Level 2: 默认级别,平衡速度和覆盖率
- Level 3: 扩展测试,更全面的检测
- Level 4: 深度测试,包含更多payload
- Level 5: 最全面测试,包含所有已知技术
**建议:**
- 快速扫描使用1-2
- 常规测试使用3(默认)
- 深度测试使用4-5
**注意事项:**
- 级别越高,请求数量越多,可能被WAF拦截
- 默认值为3
required: false
default: 3
flag: "--level"
format: "combined" # --level=3
+125 -1
View File
@@ -304,7 +304,131 @@ header {
word-break: break-word;
line-height: 1.6;
box-shadow: var(--shadow-sm);
white-space: pre-wrap;
}
/* Markdown 样式 */
.message-bubble p {
margin: 0.5em 0;
}
.message-bubble p:first-child {
margin-top: 0;
}
.message-bubble p:last-child {
margin-bottom: 0;
}
.message-bubble strong,
.message-bubble b {
font-weight: 600;
color: inherit;
}
.message-bubble em,
.message-bubble i {
font-style: italic;
}
.message-bubble code {
background: rgba(0, 0, 0, 0.05);
padding: 2px 6px;
border-radius: 3px;
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
font-size: 0.9em;
}
.message.user .message-bubble code {
background: rgba(255, 255, 255, 0.2);
}
.message-bubble pre {
background: rgba(0, 0, 0, 0.05);
padding: 12px;
border-radius: 6px;
overflow-x: auto;
margin: 0.5em 0;
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace;
font-size: 0.875em;
line-height: 1.5;
}
.message.user .message-bubble pre {
background: rgba(255, 255, 255, 0.15);
}
.message-bubble pre code {
background: none;
padding: 0;
}
.message-bubble ul,
.message-bubble ol {
margin: 0.5em 0;
padding-left: 1.5em;
}
.message-bubble li {
margin: 0.25em 0;
}
.message-bubble blockquote {
border-left: 3px solid var(--border-color);
padding-left: 1em;
margin: 0.5em 0;
color: var(--text-secondary);
}
.message-bubble h1,
.message-bubble h2,
.message-bubble h3,
.message-bubble h4,
.message-bubble h5,
.message-bubble h6 {
margin: 0.8em 0 0.4em 0;
font-weight: 600;
line-height: 1.3;
}
.message-bubble h1:first-child,
.message-bubble h2:first-child,
.message-bubble h3:first-child,
.message-bubble h4:first-child,
.message-bubble h5:first-child,
.message-bubble h6:first-child {
margin-top: 0;
}
.message-bubble h1 {
font-size: 1.5em;
}
.message-bubble h2 {
font-size: 1.3em;
}
.message-bubble h3 {
font-size: 1.1em;
}
.message-bubble hr {
border: none;
border-top: 1px solid var(--border-color);
margin: 1em 0;
}
.message-bubble a {
color: var(--accent-color);
text-decoration: none;
}
.message-bubble a:hover {
text-decoration: underline;
}
.message.user .message-bubble a {
color: rgba(255, 255, 255, 0.9);
text-decoration: underline;
}
.message.user .message-bubble {
+139 -21
View File
@@ -15,11 +15,15 @@ async function sendMessage() {
addMessage('user', message);
input.value = '';
// 显示加载状态
const loadingId = addMessage('system', '正在处理中...');
// 创建进度消息容器
const progressId = addMessage('system', '正在处理中...');
const progressElement = document.getElementById(progressId);
const progressBubble = progressElement.querySelector('.message-bubble');
let assistantMessageId = null;
let mcpExecutionIds = [];
try {
const response = await fetch('/api/agent-loop', {
const response = await fetch('/api/agent-loop/stream', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -30,30 +34,124 @@ async function sendMessage() {
}),
});
const data = await response.json();
if (!response.ok) {
throw new Error('请求失败: ' + response.status);
}
// 移除加载消息
removeMessage(loadingId);
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
if (response.ok) {
// 更新当前对话ID
if (data.conversationId) {
currentConversationId = data.conversationId;
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split('\n');
buffer = lines.pop(); // 保留最后一个不完整的行
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const eventData = JSON.parse(line.slice(6));
handleStreamEvent(eventData, progressElement, progressBubble, progressId,
() => assistantMessageId, (id) => { assistantMessageId = id; },
() => mcpExecutionIds, (ids) => { mcpExecutionIds = ids; });
} catch (e) {
console.error('解析事件数据失败:', e, line);
}
}
}
}
// 处理剩余的buffer
if (buffer.trim()) {
const lines = buffer.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const eventData = JSON.parse(line.slice(6));
handleStreamEvent(eventData, progressElement, progressBubble, progressId,
() => assistantMessageId, (id) => { assistantMessageId = id; },
() => mcpExecutionIds, (ids) => { mcpExecutionIds = ids; });
} catch (e) {
console.error('解析事件数据失败:', e, line);
}
}
}
}
} catch (error) {
removeMessage(progressId);
addMessage('system', '错误: ' + error.message);
}
}
// 处理流式事件
function handleStreamEvent(event, progressElement, progressBubble, progressId,
getAssistantId, setAssistantId, getMcpIds, setMcpIds) {
switch (event.type) {
case 'progress':
// 更新进度消息
progressBubble.textContent = event.message;
break;
case 'tool_call':
// 显示工具调用信息
const toolInfo = event.data || {};
const toolName = toolInfo.toolName || '未知工具';
const index = toolInfo.index || 0;
const total = toolInfo.total || 0;
progressBubble.innerHTML = `🔧 正在调用工具: <strong>${escapeHtml(toolName)}</strong> (${index}/${total})`;
break;
case 'tool_result':
// 显示工具执行结果
const resultInfo = event.data || {};
const resultToolName = resultInfo.toolName || '未知工具';
const success = resultInfo.success !== false;
const statusIcon = success ? '✅' : '❌';
progressBubble.innerHTML = `${statusIcon} 工具 <strong>${escapeHtml(resultToolName)}</strong> 执行${success ? '完成' : '失败'}`;
break;
case 'response':
// 移除进度消息,显示最终回复
removeMessage(progressId);
const responseData = event.data || {};
const mcpIds = responseData.mcpExecutionIds || [];
setMcpIds(mcpIds);
// 更新对话ID
if (responseData.conversationId) {
currentConversationId = responseData.conversationId;
updateActiveConversation();
}
// 如果有MCP执行ID,显示所有调用
const mcpIds = data.mcpExecutionIds || [];
addMessage('assistant', data.response, mcpIds);
// 添加助手回复
const assistantId = addMessage('assistant', event.message, mcpIds);
setAssistantId(assistantId);
// 刷新对话列表
loadConversations();
} else {
addMessage('system', '错误: ' + (data.error || '未知错误'));
}
} catch (error) {
removeMessage(loadingId);
addMessage('system', '错误: ' + error.message);
break;
case 'error':
// 显示错误
removeMessage(progressId);
addMessage('system', '错误: ' + event.message);
break;
case 'done':
// 完成,确保进度消息已移除
if (progressElement && progressElement.parentNode) {
removeMessage(progressId);
}
// 更新对话ID
if (event.data && event.data.conversationId) {
currentConversationId = event.data.conversationId;
updateActiveConversation();
}
break;
}
}
@@ -88,8 +186,28 @@ function addMessage(role, content, mcpExecutionIds = null) {
// 创建消息气泡
const bubble = document.createElement('div');
bubble.className = 'message-bubble';
// 处理换行和格式化
const formattedContent = content.replace(/\n/g, '<br>');
// 解析 Markdown 格式
let formattedContent;
if (typeof marked !== 'undefined') {
// 使用 marked.js 解析 Markdown
try {
// 配置 marked 选项
marked.setOptions({
breaks: true, // 支持换行
gfm: true, // 支持 GitHub Flavored Markdown
});
formattedContent = marked.parse(content);
} catch (e) {
console.error('Markdown 解析失败:', e);
// 降级处理:转义 HTML 并保留换行
formattedContent = escapeHtml(content).replace(/\n/g, '<br>');
}
} else {
// 如果没有 marked.js,使用简单处理
formattedContent = escapeHtml(content).replace(/\n/g, '<br>');
}
bubble.innerHTML = formattedContent;
contentWrapper.appendChild(bubble);
+2
View File
@@ -86,6 +86,8 @@
</div>
</div>
<!-- Marked.js for Markdown parsing -->
<script src="https://cdn.jsdelivr.net/npm/marked@11.1.1/marked.min.js"></script>
<script src="/static/js/app.js"></script>
</body>
</html>