Files
CyberStrikeAI/internal/agent/agent.go
2025-11-08 22:58:51 +08:00

800 lines
24 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package agent
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// Agent AI代理
type Agent struct {
openAIClient *http.Client
config *config.OpenAIConfig
mcpServer *mcp.Server
logger *zap.Logger
}
// NewAgent 创建新的Agent
func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, logger *zap.Logger) *Agent {
return &Agent{
openAIClient: &http.Client{Timeout: 5 * time.Minute},
config: cfg,
mcpServer: mcpServer,
logger: logger,
}
}
// ChatMessage 聊天消息
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// MarshalJSON 自定义JSON序列化将tool_calls中的arguments转换为JSON字符串
func (cm ChatMessage) MarshalJSON() ([]byte, error) {
// 构建序列化结构
aux := map[string]interface{}{
"role": cm.Role,
}
// 添加content如果存在
if cm.Content != "" {
aux["content"] = cm.Content
}
// 添加tool_call_id如果存在
if cm.ToolCallID != "" {
aux["tool_call_id"] = cm.ToolCallID
}
// 转换tool_calls将arguments转换为JSON字符串
if len(cm.ToolCalls) > 0 {
toolCallsJSON := make([]map[string]interface{}, len(cm.ToolCalls))
for i, tc := range cm.ToolCalls {
// 将arguments转换为JSON字符串
argsJSON := ""
if tc.Function.Arguments != nil {
argsBytes, err := json.Marshal(tc.Function.Arguments)
if err != nil {
return nil, err
}
argsJSON = string(argsBytes)
}
toolCallsJSON[i] = map[string]interface{}{
"id": tc.ID,
"type": tc.Type,
"function": map[string]interface{}{
"name": tc.Function.Name,
"arguments": argsJSON,
},
}
}
aux["tool_calls"] = toolCallsJSON
}
return json.Marshal(aux)
}
// OpenAIRequest OpenAI API请求
type OpenAIRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Tools []Tool `json:"tools,omitempty"`
}
// OpenAIResponse OpenAI API响应
type OpenAIResponse struct {
ID string `json:"id"`
Choices []Choice `json:"choices"`
Error *Error `json:"error,omitempty"`
}
// Choice 选择
type Choice struct {
Message MessageWithTools `json:"message"`
FinishReason string `json:"finish_reason"`
}
// MessageWithTools 带工具调用的消息
type MessageWithTools struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
// Tool OpenAI工具定义
type Tool struct {
Type string `json:"type"`
Function FunctionDefinition `json:"function"`
}
// FunctionDefinition 函数定义
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
// Error OpenAI错误
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
}
// ToolCall 工具调用
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function FunctionCall `json:"function"`
}
// FunctionCall 函数调用
type FunctionCall struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// UnmarshalJSON 自定义JSON解析处理arguments可能是字符串或对象的情况
func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
type Alias FunctionCall
aux := &struct {
Name string `json:"name"`
Arguments interface{} `json:"arguments"`
*Alias
}{
Alias: (*Alias)(fc),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
fc.Name = aux.Name
// 处理arguments可能是字符串或对象的情况
switch v := aux.Arguments.(type) {
case map[string]interface{}:
fc.Arguments = v
case string:
// 如果是字符串尝试解析为JSON
if err := json.Unmarshal([]byte(v), &fc.Arguments); err != nil {
// 如果解析失败创建一个包含原始字符串的map
fc.Arguments = map[string]interface{}{
"raw": v,
}
}
case nil:
fc.Arguments = make(map[string]interface{})
default:
// 其他类型尝试转换为map
fc.Arguments = map[string]interface{}{
"value": v,
}
}
return nil
}
// AgentLoopResult Agent Loop执行结果
type AgentLoopResult struct {
Response string
MCPExecutionIDs []string
}
// ProgressCallback 进度回调函数类型
type ProgressCallback func(eventType, message string, data interface{})
// AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, nil)
}
// AgentLoopWithProgress 执行Agent循环带进度回调
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, callback ProgressCallback) (*AgentLoopResult, error) {
// 发送进度更新
sendProgress := func(eventType, message string, data interface{}) {
if callback != nil {
callback(eventType, message, data)
}
}
// 系统提示词指导AI如何处理工具错误
systemPrompt := `你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。
重要:当工具调用失败时,请遵循以下原则:
1. 仔细分析错误信息,理解失败的具体原因
2. 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标
3. 如果参数错误,根据错误提示修正参数后重试
4. 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析
5. 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作
6. 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。`
messages := []ChatMessage{
{
Role: "system",
Content: systemPrompt,
},
}
// 添加历史消息数据库只保存user和assistant消息
a.logger.Info("处理历史消息",
zap.Int("count", len(historyMessages)),
)
addedCount := 0
for i, msg := range historyMessages {
// 只添加有内容的消息
if msg.Content != "" {
messages = append(messages, ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
addedCount++
contentPreview := msg.Content
if len(contentPreview) > 50 {
contentPreview = contentPreview[:50] + "..."
}
a.logger.Info("添加历史消息到上下文",
zap.Int("index", i),
zap.String("role", msg.Role),
zap.String("content", contentPreview),
)
}
}
a.logger.Info("构建消息数组",
zap.Int("historyMessages", len(historyMessages)),
zap.Int("addedMessages", addedCount),
zap.Int("totalMessages", len(messages)),
)
// 添加当前用户消息
messages = append(messages, ChatMessage{
Role: "user",
Content: userInput,
})
result := &AgentLoopResult{
MCPExecutionIDs: make([]string, 0),
}
maxIterations := 30
for i := 0; i < maxIterations; i++ {
// 检查是否是最后一次迭代
isLastIteration := (i == maxIterations-1)
// 获取可用工具
tools := a.getAvailableTools()
// 发送迭代开始事件
if i == 0 {
sendProgress("iteration", "开始分析请求并制定测试策略", map[string]interface{}{
"iteration": i + 1,
"total": maxIterations,
})
} else if isLastIteration {
sendProgress("iteration", fmt.Sprintf("第 %d 轮迭代(最后一次)", i+1), map[string]interface{}{
"iteration": i + 1,
"total": maxIterations,
"isLast": true,
})
} else {
sendProgress("iteration", fmt.Sprintf("第 %d 轮迭代", i+1), map[string]interface{}{
"iteration": i + 1,
"total": maxIterations,
})
}
// 记录每次调用OpenAI
if i == 0 {
a.logger.Info("调用OpenAI",
zap.Int("iteration", i+1),
zap.Int("messagesCount", len(messages)),
)
// 记录前几条消息的内容(用于调试)
for j, msg := range messages {
if j >= 5 { // 只记录前5条
break
}
contentPreview := msg.Content
if len(contentPreview) > 100 {
contentPreview = contentPreview[:100] + "..."
}
a.logger.Debug("消息内容",
zap.Int("index", j),
zap.String("role", msg.Role),
zap.String("content", contentPreview),
)
}
} else {
a.logger.Info("调用OpenAI",
zap.Int("iteration", i+1),
zap.Int("messagesCount", len(messages)),
)
}
// 调用OpenAI
sendProgress("progress", "正在调用AI模型...", nil)
response, err := a.callOpenAI(ctx, messages, tools)
if err != nil {
result.Response = ""
return result, fmt.Errorf("调用OpenAI失败: %w", err)
}
if response.Error != nil {
result.Response = ""
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
}
if len(response.Choices) == 0 {
result.Response = ""
return result, fmt.Errorf("没有收到响应")
}
choice := response.Choices[0]
// 检查是否有工具调用
if len(choice.Message.ToolCalls) > 0 {
// 如果有思考内容,先发送思考事件
if choice.Message.Content != "" {
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
"iteration": i + 1,
})
}
// 添加assistant消息包含工具调用
messages = append(messages, ChatMessage{
Role: "assistant",
Content: choice.Message.Content,
ToolCalls: choice.Message.ToolCalls,
})
// 发送工具调用进度
sendProgress("tool_calls_detected", fmt.Sprintf("检测到 %d 个工具调用", len(choice.Message.ToolCalls)), map[string]interface{}{
"count": len(choice.Message.ToolCalls),
"iteration": i + 1,
})
// 执行所有工具调用
for idx, toolCall := range choice.Message.ToolCalls {
// 发送工具调用开始事件
toolArgsJSON, _ := json.Marshal(toolCall.Function.Arguments)
sendProgress("tool_call", fmt.Sprintf("正在调用工具: %s", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"arguments": string(toolArgsJSON),
"argumentsObj": toolCall.Function.Arguments,
"toolCallId": toolCall.ID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
})
// 执行工具
execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil {
// 构建详细的错误信息帮助AI理解问题并做出决策
errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err)
messages = append(messages, ChatMessage{
Role: "tool",
ToolCallID: toolCall.ID,
Content: errorMsg,
})
// 发送工具执行失败事件
sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"success": false,
"isError": true,
"error": err.Error(),
"toolCallId": toolCall.ID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
})
a.logger.Warn("工具执行失败,已返回详细错误信息",
zap.String("tool", toolCall.Function.Name),
zap.Error(err),
)
} else {
// 即使工具返回了错误结果IsError=true也继续处理让AI决定下一步
messages = append(messages, ChatMessage{
Role: "tool",
ToolCallID: toolCall.ID,
Content: execResult.Result,
})
// 收集执行ID
if execResult.ExecutionID != "" {
result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID)
}
// 发送工具执行成功事件
resultPreview := execResult.Result
if len(resultPreview) > 200 {
resultPreview = resultPreview[:200] + "..."
}
sendProgress("tool_result", fmt.Sprintf("工具 %s 执行完成", toolCall.Function.Name), map[string]interface{}{
"toolName": toolCall.Function.Name,
"success": !execResult.IsError,
"isError": execResult.IsError,
"result": execResult.Result, // 完整结果
"resultPreview": resultPreview, // 预览结果
"executionId": execResult.ExecutionID,
"toolCallId": toolCall.ID,
"index": idx + 1,
"total": len(choice.Message.ToolCalls),
"iteration": i + 1,
})
// 如果工具返回了错误,记录日志但不中断流程
if execResult.IsError {
a.logger.Warn("工具返回错误结果,但继续处理",
zap.String("tool", toolCall.Function.Name),
zap.String("result", execResult.Result),
)
}
}
}
// 如果是最后一次迭代执行完工具后要求AI进行总结
if isLastIteration {
sendProgress("progress", "最后一次迭代:正在生成总结和下一步计划...", nil)
// 添加用户消息要求AI进行总结
messages = append(messages, ChatMessage{
Role: "user",
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
})
// 立即调用OpenAI获取总结
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
summaryChoice := summaryResponse.Choices[0]
if summaryChoice.Message.Content != "" {
result.Response = summaryChoice.Message.Content
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
}
// 如果获取总结失败,跳出循环,让后续逻辑处理
break
}
continue
}
// 添加assistant响应
messages = append(messages, ChatMessage{
Role: "assistant",
Content: choice.Message.Content,
})
// 发送AI思考内容如果没有工具调用
if choice.Message.Content != "" {
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
"iteration": i + 1,
})
}
// 如果是最后一次迭代无论finish_reason是什么都要求AI进行总结
if isLastIteration {
sendProgress("progress", "最后一次迭代:正在生成总结和下一步计划...", nil)
// 添加用户消息要求AI进行总结
messages = append(messages, ChatMessage{
Role: "user",
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
})
// 立即调用OpenAI获取总结
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
summaryChoice := summaryResponse.Choices[0]
if summaryChoice.Message.Content != "" {
result.Response = summaryChoice.Message.Content
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
}
// 如果获取总结失败,使用当前回复作为结果
if choice.Message.Content != "" {
result.Response = choice.Message.Content
return result, nil
}
// 如果都没有内容,跳出循环,让后续逻辑处理
break
}
// 如果完成,返回结果
if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content
return result, nil
}
}
// 如果循环结束仍未返回,说明达到了最大迭代次数
// 尝试最后一次调用AI获取总结
sendProgress("progress", "达到最大迭代次数,正在生成总结...", nil)
finalSummaryPrompt := ChatMessage{
Role: "user",
Content: "已达到最大迭代次数30轮。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试请提供详细的下一步执行计划。请直接回复不要调用工具。",
}
messages = append(messages, finalSummaryPrompt)
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具强制AI直接回复
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
summaryChoice := summaryResponse.Choices[0]
if summaryChoice.Message.Content != "" {
result.Response = summaryChoice.Message.Content
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
}
// 如果无法生成总结,返回友好的提示
result.Response = "已达到最大迭代次数30轮。系统已执行了多轮测试但由于达到迭代上限无法继续自动执行。建议您查看已执行的工具结果或提出新的测试请求以继续测试。"
return result, nil
}
// getAvailableTools 获取可用工具
// 从MCP服务器动态获取工具列表使用简短描述以减少token消耗
func (a *Agent) getAvailableTools() []Tool {
// 从MCP服务器获取所有已注册的工具
mcpTools := a.mcpServer.GetAllTools()
// 转换为OpenAI格式的工具定义
tools := make([]Tool, 0, len(mcpTools))
for _, mcpTool := range mcpTools {
// 使用简短描述(如果存在),否则使用详细描述
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
}
// 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema)
tools = append(tools, Tool{
Type: "function",
Function: FunctionDefinition{
Name: mcpTool.Name,
Description: description, // 使用简短描述减少token消耗
Parameters: convertedSchema,
},
})
}
a.logger.Debug("获取可用工具列表",
zap.Int("count", len(tools)),
)
return tools
}
// convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型
func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} {
if schema == nil {
return schema
}
// 创建新的schema副本
converted := make(map[string]interface{})
for k, v := range schema {
converted[k] = v
}
// 转换properties中的类型
if properties, ok := converted["properties"].(map[string]interface{}); ok {
convertedProperties := make(map[string]interface{})
for propName, propValue := range properties {
if prop, ok := propValue.(map[string]interface{}); ok {
convertedProp := make(map[string]interface{})
for pk, pv := range prop {
if pk == "type" {
// 转换类型
if typeStr, ok := pv.(string); ok {
convertedProp[pk] = a.convertToOpenAIType(typeStr)
} else {
convertedProp[pk] = pv
}
} else {
convertedProp[pk] = pv
}
}
convertedProperties[propName] = convertedProp
} else {
convertedProperties[propName] = propValue
}
}
converted["properties"] = convertedProperties
}
return converted
}
// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型
func (a *Agent) convertToOpenAIType(configType string) string {
switch configType {
case "bool":
return "boolean"
case "int", "integer":
return "number"
case "float", "double":
return "number"
case "string", "array", "object":
return configType
default:
// 默认返回原类型
return configType
}
}
// callOpenAI 调用OpenAI API
func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) {
reqBody := OpenAIRequest{
Model: a.config.Model,
Messages: messages,
}
if len(tools) > 0 {
reqBody.Tools = tools
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+a.config.APIKey)
resp, err := a.openAIClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 记录响应内容(用于调试)
if resp.StatusCode != http.StatusOK {
a.logger.Warn("OpenAI API返回非200状态码",
zap.Int("status", resp.StatusCode),
zap.String("body", string(body)),
)
}
var response OpenAIResponse
if err := json.Unmarshal(body, &response); err != nil {
a.logger.Error("解析OpenAI响应失败",
zap.Error(err),
zap.String("body", string(body)),
)
return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body))
}
return &response, nil
}
// parseToolCall 解析工具调用
func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) {
// 简单解析,实际应该更复杂
// 格式: [TOOL_CALL]tool_name:arg1=value1,arg2=value2
if !strings.HasPrefix(content, "[TOOL_CALL]") {
return nil, fmt.Errorf("不是有效的工具调用格式")
}
parts := strings.Split(content[len("[TOOL_CALL]"):], ":")
if len(parts) < 2 {
return nil, fmt.Errorf("工具调用格式错误")
}
toolName := strings.TrimSpace(parts[0])
argsStr := strings.TrimSpace(parts[1])
args := make(map[string]interface{})
argPairs := strings.Split(argsStr, ",")
for _, pair := range argPairs {
kv := strings.Split(pair, "=")
if len(kv) == 2 {
args[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
}
}
args["_tool_name"] = toolName
return args, nil
}
// ToolExecutionResult 工具执行结果
type ToolExecutionResult struct {
Result string
ExecutionID string
IsError bool // 标记是否为错误结果
}
// executeToolViaMCP 通过MCP执行工具
// 即使工具执行失败也返回结果而不是错误让AI能够处理错误情况
func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map[string]interface{}) (*ToolExecutionResult, error) {
a.logger.Info("通过MCP执行工具",
zap.String("tool", toolName),
zap.Any("args", args),
)
// 通过MCP服务器调用工具
result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args)
// 如果调用失败(如工具不存在),返回友好的错误信息而不是抛出异常
if err != nil {
errorMsg := fmt.Sprintf(`工具调用失败
工具名称: %s
错误类型: 系统错误
错误详情: %v
可能的原因:
- 工具 "%s" 不存在或未启用
- 系统配置问题
- 网络或权限问题
建议:
- 检查工具名称是否正确
- 尝试使用其他替代工具
- 如果这是必需的工具,请向用户说明情况`, toolName, err, toolName)
return &ToolExecutionResult{
Result: errorMsg,
ExecutionID: executionID,
IsError: true,
}, nil // 返回 nil 错误,让调用者处理结果
}
// 格式化结果
var resultText strings.Builder
for _, content := range result.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
return &ToolExecutionResult{
Result: resultText.String(),
ExecutionID: executionID,
IsError: result != nil && result.IsError,
}, nil
}
// formatToolError 格式化工具错误信息,提供更友好的错误描述
func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string {
errorMsg := fmt.Sprintf(`工具执行失败
工具名称: %s
调用参数: %v
错误信息: %v
请分析错误原因并采取以下行动之一:
1. 如果参数错误,请修正参数后重试
2. 如果工具不可用,请尝试使用替代工具
3. 如果这是系统问题,请向用户说明情况并提供建议
4. 如果错误信息中包含有用信息,可以基于这些信息继续分析`, toolName, args, err)
return errorMsg
}