mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-22 23:49:48 +02:00
Add files via upload
This commit is contained in:
+295
-33
@@ -15,6 +15,7 @@ import (
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -196,6 +197,7 @@ type OpenAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIResponse OpenAI API响应
|
||||
@@ -529,6 +531,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
var currentReActInput string
|
||||
|
||||
maxIterations := a.maxIterations
|
||||
thinkingStreamSeq := 0
|
||||
for i := 0; i < maxIterations; i++ {
|
||||
// 先获取本轮可用工具并统计 tools token,再压缩,以便压缩时预留 tools 占用的空间
|
||||
tools := a.getAvailableTools(roleTools)
|
||||
@@ -630,7 +633,28 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 调用OpenAI
|
||||
sendProgress("progress", "正在调用AI模型...", nil)
|
||||
response, err := a.callOpenAI(ctx, messages, tools)
|
||||
thinkingStreamSeq++
|
||||
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
|
||||
thinkingStreamStarted := false
|
||||
|
||||
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
if !thinkingStreamStarted {
|
||||
thinkingStreamStarted = true
|
||||
sendProgress("thinking_stream_start", " ", map[string]interface{}{
|
||||
"streamId": thinkingStreamId,
|
||||
"iteration": i + 1,
|
||||
"toolStream": false,
|
||||
})
|
||||
}
|
||||
sendProgress("thinking_stream_delta", delta, map[string]interface{}{
|
||||
"streamId": thinkingStreamId,
|
||||
"iteration": i + 1,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// API调用失败,保存当前的ReAct输入和错误信息作为输出
|
||||
result.LastReActInput = currentReActInput
|
||||
@@ -682,10 +706,12 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 检查是否有工具调用
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// 如果有思考内容,先发送思考事件
|
||||
// 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重;
|
||||
// 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。
|
||||
if choice.Message.Content != "" {
|
||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||
"iteration": i + 1,
|
||||
"streamId": thinkingStreamId,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -717,7 +743,21 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
})
|
||||
|
||||
// 执行工具
|
||||
execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
toolCtx := context.WithValue(ctx, security.ToolOutputCallbackCtxKey, security.ToolOutputCallback(func(chunk string) {
|
||||
if strings.TrimSpace(chunk) == "" {
|
||||
return
|
||||
}
|
||||
sendProgress("tool_result_delta", chunk, map[string]interface{}{
|
||||
"toolName": toolCall.Function.Name,
|
||||
"toolCallId": toolCall.ID,
|
||||
"index": idx + 1,
|
||||
"total": len(choice.Message.ToolCalls),
|
||||
"iteration": i + 1,
|
||||
// success 在最终 tool_result 事件里会以 success/isError 标记为准
|
||||
})
|
||||
}))
|
||||
|
||||
execResult, err := a.executeToolViaMCP(toolCtx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
// 构建详细的错误信息,帮助AI理解问题并做出决策
|
||||
errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err)
|
||||
@@ -792,16 +832,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
|
||||
})
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
// 立即调用OpenAI获取总结
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 如果获取总结失败,跳出循环,让后续逻辑处理
|
||||
break
|
||||
@@ -817,7 +864,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
})
|
||||
|
||||
// 发送AI思考内容(如果没有工具调用)
|
||||
if choice.Message.Content != "" {
|
||||
if choice.Message.Content != "" && !thinkingStreamStarted {
|
||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||
"iteration": i + 1,
|
||||
})
|
||||
@@ -832,16 +879,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
|
||||
})
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
// 立即调用OpenAI获取总结
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 如果获取总结失败,使用当前回复作为结果
|
||||
if choice.Message.Content != "" {
|
||||
@@ -872,15 +926,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
messages = append(messages, finalSummaryPrompt)
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "max_iter_summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 如果无法生成总结,返回友好的提示
|
||||
@@ -1200,6 +1262,206 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// callOpenAISingleStreamText 单次调用OpenAI的流式模式,只用于“不会调用工具”的纯文本输出(tools 为空时最佳)。
|
||||
// onDelta 每收到一段 content delta,就回调一次;如果 callback 返回错误,会终止读取并返回错误。
|
||||
func (a *Agent) callOpenAISingleStreamText(ctx context.Context, messages []ChatMessage, tools []Tool, onDelta func(delta string) error) (string, error) {
|
||||
reqBody := OpenAIRequest{
|
||||
Model: a.config.Model,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
reqBody.Tools = tools
|
||||
}
|
||||
|
||||
if a.openAIClient == nil {
|
||||
return "", fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
|
||||
return a.openAIClient.ChatCompletionStream(ctx, reqBody, onDelta)
|
||||
}
|
||||
|
||||
// callOpenAIStreamText 调用OpenAI流式模式(带重试),仅在“未输出任何 delta”时才允许重试,避免重复发送已下发的内容。
|
||||
func (a *Agent) callOpenAIStreamText(ctx context.Context, messages []ChatMessage, tools []Tool, onDelta func(delta string) error) (string, error) {
|
||||
maxRetries := 3
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
var deltasSent bool
|
||||
full, err := a.callOpenAISingleStreamText(ctx, messages, tools, func(delta string) error {
|
||||
deltasSent = true
|
||||
return onDelta(delta)
|
||||
})
|
||||
if err == nil {
|
||||
if attempt > 0 {
|
||||
a.logger.Info("OpenAI stream 调用重试成功",
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
)
|
||||
}
|
||||
return full, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
// 已经开始输出了 delta,避免重复内容:直接失败让上层处理。
|
||||
if deltasSent {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !a.isRetryableError(err) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
a.logger.Warn("OpenAI stream 调用失败,准备重试",
|
||||
zap.Error(err),
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
zap.Duration("backoff", backoff),
|
||||
)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("上下文已取消: %w", ctx.Err())
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// callOpenAISingleStreamWithToolCalls 单次调用OpenAI流式模式(带工具调用解析),不包含重试逻辑。
|
||||
func (a *Agent) callOpenAISingleStreamWithToolCalls(
|
||||
ctx context.Context,
|
||||
messages []ChatMessage,
|
||||
tools []Tool,
|
||||
onContentDelta func(delta string) error,
|
||||
) (*OpenAIResponse, error) {
|
||||
reqBody := OpenAIRequest{
|
||||
Model: a.config.Model,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
reqBody.Tools = tools
|
||||
}
|
||||
if a.openAIClient == nil {
|
||||
return nil, fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
|
||||
content, streamToolCalls, finishReason, err := a.openAIClient.ChatCompletionStreamWithToolCalls(ctx, reqBody, onContentDelta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolCalls := make([]ToolCall, 0, len(streamToolCalls))
|
||||
for _, stc := range streamToolCalls {
|
||||
fnArgsStr := stc.FunctionArgsStr
|
||||
args := make(map[string]interface{})
|
||||
if strings.TrimSpace(fnArgsStr) != "" {
|
||||
if err := json.Unmarshal([]byte(fnArgsStr), &args); err != nil {
|
||||
// 兼容:arguments 不一定是严格 JSON
|
||||
args = map[string]interface{}{"raw": fnArgsStr}
|
||||
}
|
||||
}
|
||||
|
||||
typ := stc.Type
|
||||
if strings.TrimSpace(typ) == "" {
|
||||
typ = "function"
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: stc.ID,
|
||||
Type: typ,
|
||||
Function: FunctionCall{
|
||||
Name: stc.FunctionName,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
response := &OpenAIResponse{
|
||||
ID: "",
|
||||
Choices: []Choice{
|
||||
{
|
||||
Message: MessageWithTools{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// callOpenAIStreamWithToolCalls 调用OpenAI流式模式(带重试),仅当还没有输出任何 content delta 时才允许重试。
|
||||
func (a *Agent) callOpenAIStreamWithToolCalls(
|
||||
ctx context.Context,
|
||||
messages []ChatMessage,
|
||||
tools []Tool,
|
||||
onContentDelta func(delta string) error,
|
||||
) (*OpenAIResponse, error) {
|
||||
maxRetries := 3
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
deltasSent := false
|
||||
resp, err := a.callOpenAISingleStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
||||
deltasSent = true
|
||||
if onContentDelta != nil {
|
||||
return onContentDelta(delta)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
if attempt > 0 {
|
||||
a.logger.Info("OpenAI stream 调用重试成功",
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
if deltasSent {
|
||||
// 已经开始输出了 delta:避免重复发送
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !a.isRetryableError(err) {
|
||||
return nil, err
|
||||
}
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
a.logger.Warn("OpenAI stream 调用失败,准备重试",
|
||||
zap.Error(err),
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
zap.Duration("backoff", backoff),
|
||||
)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("上下文已取消: %w", ctx.Err())
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// ToolExecutionResult 工具执行结果
|
||||
type ToolExecutionResult struct {
|
||||
Result string
|
||||
|
||||
Reference in New Issue
Block a user