mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-19 22:38:56 +02:00
Add files via upload
This commit is contained in:
+295
-33
@@ -15,6 +15,7 @@ import (
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -196,6 +197,7 @@ type OpenAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIResponse OpenAI API响应
|
||||
@@ -529,6 +531,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
var currentReActInput string
|
||||
|
||||
maxIterations := a.maxIterations
|
||||
thinkingStreamSeq := 0
|
||||
for i := 0; i < maxIterations; i++ {
|
||||
// 先获取本轮可用工具并统计 tools token,再压缩,以便压缩时预留 tools 占用的空间
|
||||
tools := a.getAvailableTools(roleTools)
|
||||
@@ -630,7 +633,28 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 调用OpenAI
|
||||
sendProgress("progress", "正在调用AI模型...", nil)
|
||||
response, err := a.callOpenAI(ctx, messages, tools)
|
||||
thinkingStreamSeq++
|
||||
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
|
||||
thinkingStreamStarted := false
|
||||
|
||||
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
if !thinkingStreamStarted {
|
||||
thinkingStreamStarted = true
|
||||
sendProgress("thinking_stream_start", " ", map[string]interface{}{
|
||||
"streamId": thinkingStreamId,
|
||||
"iteration": i + 1,
|
||||
"toolStream": false,
|
||||
})
|
||||
}
|
||||
sendProgress("thinking_stream_delta", delta, map[string]interface{}{
|
||||
"streamId": thinkingStreamId,
|
||||
"iteration": i + 1,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// API调用失败,保存当前的ReAct输入和错误信息作为输出
|
||||
result.LastReActInput = currentReActInput
|
||||
@@ -682,10 +706,12 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 检查是否有工具调用
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// 如果有思考内容,先发送思考事件
|
||||
// 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重;
|
||||
// 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。
|
||||
if choice.Message.Content != "" {
|
||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||
"iteration": i + 1,
|
||||
"streamId": thinkingStreamId,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -717,7 +743,21 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
})
|
||||
|
||||
// 执行工具
|
||||
execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
toolCtx := context.WithValue(ctx, security.ToolOutputCallbackCtxKey, security.ToolOutputCallback(func(chunk string) {
|
||||
if strings.TrimSpace(chunk) == "" {
|
||||
return
|
||||
}
|
||||
sendProgress("tool_result_delta", chunk, map[string]interface{}{
|
||||
"toolName": toolCall.Function.Name,
|
||||
"toolCallId": toolCall.ID,
|
||||
"index": idx + 1,
|
||||
"total": len(choice.Message.ToolCalls),
|
||||
"iteration": i + 1,
|
||||
// success 在最终 tool_result 事件里会以 success/isError 标记为准
|
||||
})
|
||||
}))
|
||||
|
||||
execResult, err := a.executeToolViaMCP(toolCtx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
// 构建详细的错误信息,帮助AI理解问题并做出决策
|
||||
errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err)
|
||||
@@ -792,16 +832,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
|
||||
})
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
// 立即调用OpenAI获取总结
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 如果获取总结失败,跳出循环,让后续逻辑处理
|
||||
break
|
||||
@@ -817,7 +864,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
})
|
||||
|
||||
// 发送AI思考内容(如果没有工具调用)
|
||||
if choice.Message.Content != "" {
|
||||
if choice.Message.Content != "" && !thinkingStreamStarted {
|
||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||
"iteration": i + 1,
|
||||
})
|
||||
@@ -832,16 +879,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。",
|
||||
})
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
// 立即调用OpenAI获取总结
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 如果获取总结失败,使用当前回复作为结果
|
||||
if choice.Message.Content != "" {
|
||||
@@ -872,15 +926,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
messages = append(messages, finalSummaryPrompt)
|
||||
messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留
|
||||
|
||||
summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复
|
||||
if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 {
|
||||
summaryChoice := summaryResponse.Choices[0]
|
||||
if summaryChoice.Message.Content != "" {
|
||||
result.Response = summaryChoice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 流式调用OpenAI获取总结(不提供工具,强制AI直接回复)
|
||||
sendProgress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "max_iter_summary",
|
||||
})
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 如果无法生成总结,返回友好的提示
|
||||
@@ -1200,6 +1262,206 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// callOpenAISingleStreamText 单次调用OpenAI的流式模式,只用于“不会调用工具”的纯文本输出(tools 为空时最佳)。
|
||||
// onDelta 每收到一段 content delta,就回调一次;如果 callback 返回错误,会终止读取并返回错误。
|
||||
func (a *Agent) callOpenAISingleStreamText(ctx context.Context, messages []ChatMessage, tools []Tool, onDelta func(delta string) error) (string, error) {
|
||||
reqBody := OpenAIRequest{
|
||||
Model: a.config.Model,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
reqBody.Tools = tools
|
||||
}
|
||||
|
||||
if a.openAIClient == nil {
|
||||
return "", fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
|
||||
return a.openAIClient.ChatCompletionStream(ctx, reqBody, onDelta)
|
||||
}
|
||||
|
||||
// callOpenAIStreamText 调用OpenAI流式模式(带重试),仅在“未输出任何 delta”时才允许重试,避免重复发送已下发的内容。
|
||||
func (a *Agent) callOpenAIStreamText(ctx context.Context, messages []ChatMessage, tools []Tool, onDelta func(delta string) error) (string, error) {
|
||||
maxRetries := 3
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
var deltasSent bool
|
||||
full, err := a.callOpenAISingleStreamText(ctx, messages, tools, func(delta string) error {
|
||||
deltasSent = true
|
||||
return onDelta(delta)
|
||||
})
|
||||
if err == nil {
|
||||
if attempt > 0 {
|
||||
a.logger.Info("OpenAI stream 调用重试成功",
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
)
|
||||
}
|
||||
return full, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
// 已经开始输出了 delta,避免重复内容:直接失败让上层处理。
|
||||
if deltasSent {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !a.isRetryableError(err) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
a.logger.Warn("OpenAI stream 调用失败,准备重试",
|
||||
zap.Error(err),
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
zap.Duration("backoff", backoff),
|
||||
)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("上下文已取消: %w", ctx.Err())
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// callOpenAISingleStreamWithToolCalls 单次调用OpenAI流式模式(带工具调用解析),不包含重试逻辑。
|
||||
func (a *Agent) callOpenAISingleStreamWithToolCalls(
|
||||
ctx context.Context,
|
||||
messages []ChatMessage,
|
||||
tools []Tool,
|
||||
onContentDelta func(delta string) error,
|
||||
) (*OpenAIResponse, error) {
|
||||
reqBody := OpenAIRequest{
|
||||
Model: a.config.Model,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
reqBody.Tools = tools
|
||||
}
|
||||
if a.openAIClient == nil {
|
||||
return nil, fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
|
||||
content, streamToolCalls, finishReason, err := a.openAIClient.ChatCompletionStreamWithToolCalls(ctx, reqBody, onContentDelta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolCalls := make([]ToolCall, 0, len(streamToolCalls))
|
||||
for _, stc := range streamToolCalls {
|
||||
fnArgsStr := stc.FunctionArgsStr
|
||||
args := make(map[string]interface{})
|
||||
if strings.TrimSpace(fnArgsStr) != "" {
|
||||
if err := json.Unmarshal([]byte(fnArgsStr), &args); err != nil {
|
||||
// 兼容:arguments 不一定是严格 JSON
|
||||
args = map[string]interface{}{"raw": fnArgsStr}
|
||||
}
|
||||
}
|
||||
|
||||
typ := stc.Type
|
||||
if strings.TrimSpace(typ) == "" {
|
||||
typ = "function"
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: stc.ID,
|
||||
Type: typ,
|
||||
Function: FunctionCall{
|
||||
Name: stc.FunctionName,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
response := &OpenAIResponse{
|
||||
ID: "",
|
||||
Choices: []Choice{
|
||||
{
|
||||
Message: MessageWithTools{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// callOpenAIStreamWithToolCalls 调用OpenAI流式模式(带重试),仅当还没有输出任何 content delta 时才允许重试。
|
||||
func (a *Agent) callOpenAIStreamWithToolCalls(
|
||||
ctx context.Context,
|
||||
messages []ChatMessage,
|
||||
tools []Tool,
|
||||
onContentDelta func(delta string) error,
|
||||
) (*OpenAIResponse, error) {
|
||||
maxRetries := 3
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
deltasSent := false
|
||||
resp, err := a.callOpenAISingleStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
||||
deltasSent = true
|
||||
if onContentDelta != nil {
|
||||
return onContentDelta(delta)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
if attempt > 0 {
|
||||
a.logger.Info("OpenAI stream 调用重试成功",
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
if deltasSent {
|
||||
// 已经开始输出了 delta:避免重复发送
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !a.isRetryableError(err) {
|
||||
return nil, err
|
||||
}
|
||||
if attempt < maxRetries-1 {
|
||||
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
a.logger.Warn("OpenAI stream 调用失败,准备重试",
|
||||
zap.Error(err),
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("maxRetries", maxRetries),
|
||||
zap.Duration("backoff", backoff),
|
||||
)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("上下文已取消: %w", ctx.Err())
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// ToolExecutionResult 工具执行结果
|
||||
type ToolExecutionResult struct {
|
||||
Result string
|
||||
|
||||
@@ -662,8 +662,16 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
|
||||
}
|
||||
}
|
||||
|
||||
// 保存过程详情到数据库(排除response和done事件,它们会在后面单独处理)
|
||||
if assistantMessageID != "" && eventType != "response" && eventType != "done" {
|
||||
// 保存过程详情到数据库(排除response/done事件,它们会在后面单独处理)
|
||||
// 另外:response_start/response_delta 是模型流式增量,保存会导致过程详情膨胀,因此不落库。
|
||||
if assistantMessageID != "" &&
|
||||
eventType != "response" &&
|
||||
eventType != "done" &&
|
||||
eventType != "response_start" &&
|
||||
eventType != "response_delta" &&
|
||||
eventType != "tool_result_delta" &&
|
||||
eventType != "thinking_stream_start" &&
|
||||
eventType != "thinking_stream_delta" {
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", eventType))
|
||||
}
|
||||
@@ -703,8 +711,53 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
// 发送初始事件
|
||||
// 用于跟踪客户端是否已断开连接
|
||||
clientDisconnected := false
|
||||
// 用于快速确认模型是否真的产生了流式 delta
|
||||
var responseDeltaCount int
|
||||
var responseStartLogged bool
|
||||
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if eventType == "response_start" {
|
||||
responseDeltaCount = 0
|
||||
responseStartLogged = true
|
||||
h.logger.Info("SSE: response_start",
|
||||
zap.Int("conversationIdPresent", func() int {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if v, ok2 := m["conversationId"]; ok2 && v != nil && fmt.Sprint(v) != "" {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}()),
|
||||
zap.String("messageGeneratedBy", func() string {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if v, ok2 := m["messageGeneratedBy"]; ok2 {
|
||||
if s, ok3 := v.(string); ok3 {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprint(v)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}()),
|
||||
)
|
||||
} else if eventType == "response_delta" {
|
||||
responseDeltaCount++
|
||||
// 只打前几条,避免刷屏
|
||||
if responseStartLogged && responseDeltaCount <= 3 {
|
||||
h.logger.Info("SSE: response_delta",
|
||||
zap.Int("index", responseDeltaCount),
|
||||
zap.Int("deltaLen", len(message)),
|
||||
zap.String("deltaPreview", func() string {
|
||||
p := strings.ReplaceAll(message, "\n", "\\n")
|
||||
if len(p) > 80 {
|
||||
return p[:80] + "..."
|
||||
}
|
||||
return p
|
||||
}()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果客户端已断开,不再发送事件
|
||||
if clientDisconnected {
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
@@ -142,3 +143,342 @@ func (c *Client) ChatCompletion(ctx context.Context, payload interface{}, out in
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChatCompletionStream 调用 /chat/completions 的流式模式(stream=true),并在每个 delta 到达时回调 onDelta。
|
||||
// 返回最终拼接的 content(只拼 content delta;工具调用 delta 未做处理)。
|
||||
func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{}, onDelta func(delta string) error) (string, error) {
|
||||
if c == nil {
|
||||
return "", fmt.Errorf("openai client is not initialized")
|
||||
}
|
||||
if c.config == nil {
|
||||
return "", fmt.Errorf("openai config is nil")
|
||||
}
|
||||
if strings.TrimSpace(c.config.APIKey) == "" {
|
||||
return "", fmt.Errorf("openai api key is empty")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal openai payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("build openai request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
||||
|
||||
requestStart := time.Now()
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("call openai api: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 非200:读完 body 返回
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return "", &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
}
|
||||
}
|
||||
|
||||
type streamDelta struct {
|
||||
// OpenAI 兼容流式通常使用 content;但部分兼容实现可能用 text。
|
||||
Content string `json:"content,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
type streamChoice struct {
|
||||
Delta streamDelta `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
type streamResponse struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Choices []streamChoice `json:"choices"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var full strings.Builder
|
||||
|
||||
// 典型 SSE 结构:
|
||||
// data: {...}\n\n
|
||||
// data: [DONE]\n\n
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
if readErr == io.EOF {
|
||||
break
|
||||
}
|
||||
return full.String(), fmt.Errorf("read openai stream: %w", readErr)
|
||||
}
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "data:") {
|
||||
continue
|
||||
}
|
||||
dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||
if dataStr == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chunk streamResponse
|
||||
if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
|
||||
// 解析失败跳过(兼容各种兼容层的差异)
|
||||
continue
|
||||
}
|
||||
if chunk.Error != nil && strings.TrimSpace(chunk.Error.Message) != "" {
|
||||
return full.String(), fmt.Errorf("openai stream error: %s", chunk.Error.Message)
|
||||
}
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
delta := chunk.Choices[0].Delta.Content
|
||||
if delta == "" {
|
||||
delta = chunk.Choices[0].Delta.Text
|
||||
}
|
||||
if delta == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
full.WriteString(delta)
|
||||
if onDelta != nil {
|
||||
if err := onDelta(delta); err != nil {
|
||||
return full.String(), err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.logger.Debug("received OpenAI stream completion",
|
||||
zap.Duration("duration", time.Since(requestStart)),
|
||||
zap.Int("contentLen", full.Len()),
|
||||
)
|
||||
|
||||
return full.String(), nil
|
||||
}
|
||||
|
||||
// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。
|
||||
type StreamToolCall struct {
|
||||
Index int
|
||||
ID string
|
||||
Type string
|
||||
FunctionName string
|
||||
FunctionArgsStr string
|
||||
}
|
||||
|
||||
// ChatCompletionStreamWithToolCalls 流式模式:同时把 content delta 实时回调,并在结束后返回 tool_calls 和 finish_reason。
|
||||
func (c *Client) ChatCompletionStreamWithToolCalls(
|
||||
ctx context.Context,
|
||||
payload interface{},
|
||||
onContentDelta func(delta string) error,
|
||||
) (string, []StreamToolCall, string, error) {
|
||||
if c == nil {
|
||||
return "", nil, "", fmt.Errorf("openai client is not initialized")
|
||||
}
|
||||
if c.config == nil {
|
||||
return "", nil, "", fmt.Errorf("openai config is nil")
|
||||
}
|
||||
if strings.TrimSpace(c.config.APIKey) == "" {
|
||||
return "", nil, "", fmt.Errorf("openai api key is empty")
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", nil, "", fmt.Errorf("marshal openai payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", nil, "", fmt.Errorf("build openai request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
||||
|
||||
requestStart := time.Now()
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", nil, "", fmt.Errorf("call openai api: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return "", nil, "", &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
}
|
||||
}
|
||||
|
||||
// delta tool_calls 的增量结构
|
||||
type toolCallFunctionDelta struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
type toolCallDelta struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function toolCallFunctionDelta `json:"function,omitempty"`
|
||||
}
|
||||
type streamDelta2 struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ToolCalls []toolCallDelta `json:"tool_calls,omitempty"`
|
||||
}
|
||||
type streamChoice2 struct {
|
||||
Delta streamDelta2 `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
type streamResponse2 struct {
|
||||
Choices []streamChoice2 `json:"choices"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type toolCallAccum struct {
|
||||
id string
|
||||
typ string
|
||||
name string
|
||||
args strings.Builder
|
||||
}
|
||||
toolCallAccums := make(map[int]*toolCallAccum)
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var full strings.Builder
|
||||
finishReason := ""
|
||||
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
if readErr == io.EOF {
|
||||
break
|
||||
}
|
||||
return full.String(), nil, finishReason, fmt.Errorf("read openai stream: %w", readErr)
|
||||
}
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "data:") {
|
||||
continue
|
||||
}
|
||||
dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
|
||||
if dataStr == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chunk streamResponse2
|
||||
if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
|
||||
// 兼容:解析失败跳过
|
||||
continue
|
||||
}
|
||||
if chunk.Error != nil && strings.TrimSpace(chunk.Error.Message) != "" {
|
||||
return full.String(), nil, finishReason, fmt.Errorf("openai stream error: %s", chunk.Error.Message)
|
||||
}
|
||||
if len(chunk.Choices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
choice := chunk.Choices[0]
|
||||
if choice.FinishReason != nil && strings.TrimSpace(*choice.FinishReason) != "" {
|
||||
finishReason = strings.TrimSpace(*choice.FinishReason)
|
||||
}
|
||||
|
||||
delta := choice.Delta
|
||||
|
||||
content := delta.Content
|
||||
if content == "" {
|
||||
content = delta.Text
|
||||
}
|
||||
if content != "" {
|
||||
full.WriteString(content)
|
||||
if onContentDelta != nil {
|
||||
if err := onContentDelta(content); err != nil {
|
||||
return full.String(), nil, finishReason, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
for _, tc := range delta.ToolCalls {
|
||||
acc, ok := toolCallAccums[tc.Index]
|
||||
if !ok {
|
||||
acc = &toolCallAccum{}
|
||||
toolCallAccums[tc.Index] = acc
|
||||
}
|
||||
if tc.ID != "" {
|
||||
acc.id = tc.ID
|
||||
}
|
||||
if tc.Type != "" {
|
||||
acc.typ = tc.Type
|
||||
}
|
||||
if tc.Function.Name != "" {
|
||||
acc.name = tc.Function.Name
|
||||
}
|
||||
if tc.Function.Arguments != "" {
|
||||
acc.args.WriteString(tc.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 组装 tool calls
|
||||
indices := make([]int, 0, len(toolCallAccums))
|
||||
for idx := range toolCallAccums {
|
||||
indices = append(indices, idx)
|
||||
}
|
||||
// 手写简单排序(避免额外 import)
|
||||
for i := 0; i < len(indices); i++ {
|
||||
for j := i + 1; j < len(indices); j++ {
|
||||
if indices[j] < indices[i] {
|
||||
indices[i], indices[j] = indices[j], indices[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls := make([]StreamToolCall, 0, len(indices))
|
||||
for _, idx := range indices {
|
||||
acc := toolCallAccums[idx]
|
||||
tc := StreamToolCall{
|
||||
Index: idx,
|
||||
ID: acc.id,
|
||||
Type: acc.typ,
|
||||
FunctionName: acc.name,
|
||||
FunctionArgsStr: acc.args.String(),
|
||||
}
|
||||
toolCalls = append(toolCalls, tc)
|
||||
}
|
||||
|
||||
c.logger.Debug("received OpenAI stream completion (tool_calls)",
|
||||
zap.Duration("duration", time.Since(requestStart)),
|
||||
zap.Int("contentLen", full.Len()),
|
||||
zap.Int("toolCalls", len(toolCalls)),
|
||||
zap.String("finishReason", finishReason),
|
||||
)
|
||||
|
||||
if strings.TrimSpace(finishReason) == "" {
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
return full.String(), toolCalls, finishReason, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
@@ -17,6 +19,15 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ToolOutputCallback 用于在工具执行过程中把 stdout/stderr 增量推给上层(SSE)。
|
||||
// 通过 context 传递,避免修改 MCP ToolHandler 签名导致的“写死工具”问题。
|
||||
type ToolOutputCallback func(chunk string)
|
||||
|
||||
type toolOutputCallbackCtxKey struct{}
|
||||
|
||||
// ToolOutputCallbackCtxKey 是 context 中的 key,供 Agent 写入回调,Executor 读取并流式回调。
|
||||
var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{}
|
||||
|
||||
// Executor 安全工具执行器
|
||||
type Executor struct {
|
||||
config *config.SecurityConfig
|
||||
@@ -144,7 +155,16 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
zap.Strings("args", cmdArgs),
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
var output string
|
||||
var err error
|
||||
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(cmd, cb)
|
||||
} else {
|
||||
outputBytes, err2 := cmd.CombinedOutput()
|
||||
output = string(outputBytes)
|
||||
err = err2
|
||||
}
|
||||
if err != nil {
|
||||
// 检查退出码是否在允许列表中
|
||||
exitCode := getExitCode(err)
|
||||
@@ -931,7 +951,16 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
}
|
||||
|
||||
// 非后台命令:等待输出
|
||||
output, err := cmd.CombinedOutput()
|
||||
var output string
|
||||
var err error
|
||||
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(cmd, cb)
|
||||
} else {
|
||||
outputBytes, err2 := cmd.CombinedOutput()
|
||||
output = string(outputBytes)
|
||||
err = err2
|
||||
}
|
||||
if err != nil {
|
||||
e.logger.Error("系统命令执行失败",
|
||||
zap.String("command", command),
|
||||
@@ -965,6 +994,78 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
}, nil
|
||||
}
|
||||
|
||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||
// 保持输出内容完整拼接返回,并用 cb(chunk) 向上层持续推送。
|
||||
func streamCommandOutput(cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
_ = stdoutPipe.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
_ = stdoutPipe.Close()
|
||||
_ = stderrPipe.Close()
|
||||
return "", err
|
||||
}
|
||||
|
||||
chunks := make(chan string, 64)
|
||||
var wg sync.WaitGroup
|
||||
readFn := func(r io.Reader) {
|
||||
defer wg.Done()
|
||||
br := bufio.NewReader(r)
|
||||
for {
|
||||
s, readErr := br.ReadString('\n')
|
||||
if s != "" {
|
||||
chunks <- s
|
||||
}
|
||||
if readErr != nil {
|
||||
// EOF 正常结束
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(2)
|
||||
go readFn(stdoutPipe)
|
||||
go readFn(stderrPipe)
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(chunks)
|
||||
}()
|
||||
|
||||
var outBuilder strings.Builder
|
||||
var deltaBuilder strings.Builder
|
||||
lastFlush := time.Now()
|
||||
|
||||
flush := func() {
|
||||
if deltaBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
cb(deltaBuilder.String())
|
||||
deltaBuilder.Reset()
|
||||
lastFlush = time.Now()
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
outBuilder.WriteString(chunk)
|
||||
deltaBuilder.WriteString(chunk)
|
||||
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次
|
||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||
flush()
|
||||
}
|
||||
}
|
||||
flush()
|
||||
|
||||
// 等待命令结束,返回最终退出状态
|
||||
waitErr := cmd.Wait()
|
||||
return outBuilder.String(), waitErr
|
||||
}
|
||||
|
||||
// executeInternalTool 执行内部工具(不执行外部命令)
|
||||
func (e *Executor) executeInternalTool(ctx context.Context, toolName string, command string, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
// 提取内部工具类型(去掉 "internal:" 前缀)
|
||||
|
||||
Reference in New Issue
Block a user