Add files via upload

This commit is contained in:
公明
2026-04-28 11:37:52 +08:00
committed by GitHub
parent 3b3d094dc4
commit b53cae3a02
26 changed files with 374 additions and 374 deletions
+21 -21
View File
@@ -336,10 +336,10 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
// AgentLoopResult Agent Loop执行结果 // AgentLoopResult Agent Loop执行结果
type AgentLoopResult struct { type AgentLoopResult struct {
Response string Response string
MCPExecutionIDs []string MCPExecutionIDs []string
LastReActInput string // 最后一轮ReAct的输入(压缩后的messagesJSON格式 LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messagesJSON;与 multiagent.RunResult 字段对齐
LastReActOutput string // 最终大模型的输出 LastAgentTraceOutput string // 最终助手输出文本
} }
// ProgressCallback 进度回调函数类型 // ProgressCallback 进度回调函数类型
@@ -471,7 +471,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
} }
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入 // 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
var currentReActInput string var currentAgentTraceInput string
maxIterations := a.maxIterations maxIterations := a.maxIterations
thinkingStreamSeq := 0 thinkingStreamSeq := 0
@@ -490,9 +490,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if err != nil { if err != nil {
a.logger.Warn("序列化ReAct输入失败", zap.Error(err)) a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
} else { } else {
currentReActInput = string(messagesJSON) currentAgentTraceInput = string(messagesJSON)
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的) // 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
} }
// 检查上下文是否已取消 // 检查上下文是否已取消
@@ -500,13 +500,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
case <-ctx.Done(): case <-ctx.Done():
// 上下文被取消(可能是用户主动暂停或其他原因) // 上下文被取消(可能是用户主动暂停或其他原因)
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err())) a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
if ctx.Err() == context.Canceled { if ctx.Err() == context.Canceled {
result.Response = "任务已被取消。" result.Response = "任务已被取消。"
} else { } else {
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err()) result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
} }
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, ctx.Err() return result, ctx.Err()
default: default:
} }
@@ -600,10 +600,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if err != nil { if err != nil {
// API调用失败,保存当前的ReAct输入和错误信息作为输出 // API调用失败,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err) errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err)) a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
return result, fmt.Errorf("调用OpenAI失败: %w", err) return result, fmt.Errorf("调用OpenAI失败: %w", err)
} }
@@ -629,19 +629,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
continue continue
} }
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出 // OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message) errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
} }
if len(response.Choices) == 0 { if len(response.Choices) == 0 {
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出 // 没有收到响应,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := "没有收到响应" errorMsg := "没有收到响应"
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("没有收到响应") return result, fmt.Errorf("没有收到响应")
} }
@@ -816,7 +816,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
@@ -863,14 +863,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
// 如果获取总结失败,使用当前回复作为结果 // 如果获取总结失败,使用当前回复作为结果
if choice.Message.Content != "" { if choice.Message.Content != "" {
result.Response = choice.Message.Content result.Response = choice.Message.Content
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
// 如果都没有内容,跳出循环,让后续逻辑处理 // 如果都没有内容,跳出循环,让后续逻辑处理
@@ -881,7 +881,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if choice.FinishReason == "stop" { if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil) sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content result.Response = choice.Message.Content
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
} }
@@ -910,14 +910,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
// 如果无法生成总结,返回友好的提示 // 如果无法生成总结,返回友好的提示
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations) result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
-1
View File
@@ -283,4 +283,3 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold) t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
} }
} }
+5 -5
View File
@@ -256,11 +256,11 @@ func orchestratorConfigFromOrchestrator(o *OrchestratorMarkdown) config.MultiAge
return config.MultiAgentSubConfig{} return config.MultiAgentSubConfig{}
} }
return config.MultiAgentSubConfig{ return config.MultiAgentSubConfig{
ID: o.EinoName, ID: o.EinoName,
Name: o.DisplayName, Name: o.DisplayName,
Description: o.Description, Description: o.Description,
Instruction: o.Instruction, Instruction: o.Instruction,
Kind: "orchestrator", Kind: "orchestrator",
} }
} }
+84 -92
View File
@@ -497,10 +497,10 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
defer h.hitlManager.DeactivateConversation(conversationID) defer h.hitlManager.DeactivateConversation(conversationID)
} }
// 优先尝试从保存的ReAct数据恢复历史上下文 // 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表 // 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID) historyMessages, err := h.db.GetMessages(conversationID)
if err != nil { if err != nil {
@@ -518,7 +518,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
} }
} else { } else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
} }
// 校验附件数量(非流式) // 校验附件数量(非流式)
@@ -613,12 +613,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
if err != nil { if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err)) h.logger.Error("Agent Loop执行失败", zap.Error(err))
// 即使执行失败,也尝试保存ReAct数据(如果result中有) // 即使执行失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil { if saveErr := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); saveErr != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr)) h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(saveErr))
} else { } else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -634,12 +634,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
// 因为AI已经生成了回复,用户应该能看到 // 因为AI已经生成了回复,用户应该能看到
} }
// 保存最后一轮ReAct的输入和输出 // 保存最后一轮代理轨迹与助手输出
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -666,7 +666,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
} }
} }
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID) historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil { if getErr != nil {
@@ -722,6 +722,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
"deep", "deep",
) )
if errMA != nil { if errMA != nil {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
errMsg := "执行失败: " + errMA.Error() errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
@@ -747,8 +748,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
} }
} }
if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" { if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput) _ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
} }
return resultMA.Response, conversationID, nil return resultMA.Response, conversationID, nil
} }
@@ -782,8 +783,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
} }
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput) _ = h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
} }
return result.Response, conversationID, nil return result.Response, conversationID, nil
} }
@@ -1359,10 +1360,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
} }
ssePublishConversationID = conversationID ssePublishConversationID = conversationID
// 优先尝试从保存的ReAct数据恢复历史上下文 // 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表 // 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID) historyMessages, err := h.db.GetMessages(conversationID)
if err != nil { if err != nil {
@@ -1380,7 +1381,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
} }
} else { } else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
} }
// 校验附件数量 // 校验附件数量
@@ -1579,12 +1580,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
} }
// 即使任务被取消,也尝试保存ReAct数据(如果result中有) // 即使任务被取消,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存取消任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1614,12 +1615,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
} }
// 即使任务超时,也尝试保存ReAct数据(如果result中有) // 即使任务超时,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存超时任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存超时任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1649,12 +1650,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
} }
// 即使任务失败,也尝试保存ReAct数据(如果result中有) // 即使任务失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1694,12 +1695,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
} }
} }
// 保存最后一轮ReAct的输入和输出 // 保存最后一轮代理轨迹与助手输出
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -2499,6 +2500,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
cancel() cancel()
if runErr != nil { if runErr != nil {
if useRunResult {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
// 检查是否是取消错误 // 检查是否是取消错误
// 1. 直接检查是否是 context.Canceled(包括包装后的错误) // 1. 直接检查是否是 context.Canceled(包括包装后的错误)
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字 // 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
@@ -2542,14 +2546,14 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg)) h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
} }
} }
// 保存ReAct数据(如果存在) // 保存代理轨迹(如果存在)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} }
} else if useRunResult && resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") { } else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} }
} }
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID) h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
@@ -2581,13 +2585,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
if useRunResult { if useRunResult {
resText = resultMA.Response resText = resultMA.Response
mcpIDs = resultMA.MCPExecutionIDs mcpIDs = resultMA.MCPExecutionIDs
lastIn = resultMA.LastReActInput lastIn = resultMA.LastAgentTraceInput
lastOut = resultMA.LastReActOutput lastOut = resultMA.LastAgentTraceOutput
} else { } else {
resText = result.Response resText = result.Response
mcpIDs = result.MCPExecutionIDs mcpIDs = result.MCPExecutionIDs
lastIn = result.LastReActInput lastIn = result.LastAgentTraceInput
lastOut = result.LastReActOutput lastOut = result.LastAgentTraceOutput
} }
// 更新助手消息内容 // 更新助手消息内容
@@ -2618,12 +2622,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
} }
} }
// 保存ReAct数据 // 保存代理轨迹
if lastIn != "" || lastOut != "" { if lastIn != "" || lastOut != "" {
if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil { if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
} }
} }
@@ -2642,36 +2646,33 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
} }
} }
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文 // loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退消息表 // 逻辑与攻击链一致:优先用保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) { func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
// 获取保存的ReAct输入和输出 traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID)
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
if err != nil { if err != nil {
return nil, fmt.Errorf("获取ReAct数据失败: %w", err) return nil, fmt.Errorf("获取代理轨迹失败: %w", err)
} }
// 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致) if traceInputJSON == "" {
if reactInputJSON == "" { return nil, fmt.Errorf("代理轨迹为空,将使用消息表")
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
} }
dataSource := "database_last_react_input" dataSource := "database_last_agent_trace"
// 解析JSON格式的messages数组
var messagesArray []map[string]interface{} var messagesArray []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil { if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err) return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err)
} }
messageCount := len(messagesArray) messageCount := len(messagesArray)
h.logger.Info("使用保存的ReAct数据恢复历史上下文", h.logger.Info("使用保存的代理轨迹恢复历史上下文",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)), zap.Int("traceInputSize", len(traceInputJSON)),
zap.Int("messageCount", messageCount), zap.Int("messageCount", messageCount),
zap.Int("reactOutputSize", len(reactOutput)), zap.Int("assistantOutSize", len(assistantOut)),
) )
// fmt.Println("messagesArray:", messagesArray)//debug // fmt.Println("messagesArray:", messagesArray)//debug
@@ -2755,53 +2756,44 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
agentMessages = append(agentMessages, msg) agentMessages = append(agentMessages, msg)
} }
// 如果存在last_react_output,需要将其作为最后一条assistant消息 // 存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致)
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出 if assistantOut != "" {
if reactOutput != "" {
// 检查最后一条消息是否是assistant消息且没有tool_calls
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
if len(agentMessages) > 0 { if len(agentMessages) > 0 {
lastMsg := &agentMessages[len(agentMessages)-1] lastMsg := &agentMessages[len(agentMessages)-1]
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 { if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content lastMsg.Content = assistantOut
lastMsg.Content = reactOutput
} else { } else {
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
agentMessages = append(agentMessages, agent.ChatMessage{ agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant", Role: "assistant",
Content: reactOutput, Content: assistantOut,
}) })
} }
} else { } else {
// 如果没有消息,直接添加最终输出
agentMessages = append(agentMessages, agent.ChatMessage{ agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant", Role: "assistant",
Content: reactOutput, Content: assistantOut,
}) })
} }
} }
if len(agentMessages) == 0 { if len(agentMessages) == 0 {
return nil, fmt.Errorf("从ReAct数据解析的消息为空") return nil, fmt.Errorf("从代理轨迹解析的消息为空")
} }
// 修复可能存在的失配tool消息,避免OpenAI报错
// 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误
if h.agent != nil { if h.agent != nil {
if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed { if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed {
h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息", h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
) )
} }
} }
h.logger.Info("从ReAct数据恢复历史消息完成", h.logger.Info("从代理轨迹恢复历史消息完成",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
zap.Int("originalMessageCount", messageCount), zap.Int("originalMessageCount", messageCount),
zap.Int("finalMessageCount", len(agentMessages)), zap.Int("finalMessageCount", len(agentMessages)),
zap.Bool("hasReactOutput", reactOutput != ""), zap.Bool("hasAssistantOut", assistantOut != ""),
) )
fmt.Println("agentMessages:", agentMessages) //debug
return agentMessages, nil return agentMessages, nil
} }
-1
View File
@@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
c.JSON(http.StatusOK, chain) c.JSON(http.StatusOK, chain)
} }
-1
View File
@@ -230,4 +230,3 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
"message": "ok", "message": "ok",
}) })
} }
+7 -5
View File
@@ -175,6 +175,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) { if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled" taskStatus = "cancelled"
@@ -239,9 +240,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -306,6 +307,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
progressCallback, progressCallback,
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
return return
} }
@@ -323,8 +325,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.AssistantMessageID, prep.AssistantMessageID,
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput) _ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
+10 -10
View File
@@ -247,7 +247,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
// 先添加一个配置 // 先添加一个配置
configObj := config.ExternalMCPServerConfig{ configObj := config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
} }
handler.manager.AddOrUpdateConfig("test-delete", configObj) handler.manager.AddOrUpdateConfig("test-delete", configObj)
@@ -276,11 +276,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
// 添加多个配置 // 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
}) })
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp", URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: false, ExternalMCPEnable: false,
}) })
@@ -319,15 +319,15 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
// 添加配置 // 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
}) })
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp", URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true, ExternalMCPEnable: true,
}) })
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
}) })
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil) req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
@@ -360,7 +360,7 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
// 添加一个禁用的配置 // 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
}) })
// 测试启动(可能会失败,因为没有真实的服务器) // 测试启动(可能会失败,因为没有真实的服务器)
@@ -416,7 +416,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter() router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{ configObj := config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
} }
@@ -459,14 +459,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
// 先添加配置 // 先添加配置
config1 := config.ExternalMCPServerConfig{ config1 := config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
} }
handler.manager.AddOrUpdateConfig("test-update", config1) handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置 // 更新配置
config2 := config.ExternalMCPServerConfig{ config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp", URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true, ExternalMCPEnable: true,
} }
+10 -10
View File
@@ -131,16 +131,16 @@ func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
} }
type markdownAgentBody struct { type markdownAgentBody struct {
Filename string `json:"filename"` Filename string `json:"filename"`
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Tools []string `json:"tools"` Tools []string `json:"tools"`
Instruction string `json:"instruction"` Instruction string `json:"instruction"`
BindRole string `json:"bind_role"` BindRole string `json:"bind_role"`
MaxIterations int `json:"max_iterations"` MaxIterations int `json:"max_iterations"`
Kind string `json:"kind"` Kind string `json:"kind"`
Raw string `json:"raw"` Raw string `json:"raw"`
} }
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents // CreateMarkdownAgent POST /api/multi-agent/markdown-agents
+5 -8
View File
@@ -42,11 +42,11 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
type MonitorResponse struct { type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"` Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"` Stats map[string]*mcp.ToolStats `json:"stats"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"` Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"` Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"` PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"` TotalPages int `json:"total_pages,omitempty"`
} }
// Monitor 获取监控信息 // Monitor 获取监控信息
@@ -213,7 +213,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
return stats return stats
} }
// GetExecution 获取特定执行记录 // GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) { func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
@@ -416,5 +415,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
} }
+21 -6
View File
@@ -185,6 +185,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) { if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled" taskStatus = "cancelled"
@@ -249,9 +250,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -318,6 +319,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
strings.TrimSpace(req.Orchestration), strings.TrimSpace(req.Orchestration),
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
errMsg := "执行失败: " + runErr.Error() errMsg := "执行失败: " + runErr.Error()
if prep.AssistantMessageID != "" { if prep.AssistantMessageID != "" {
@@ -341,9 +343,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -355,6 +357,19 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
}) })
} }
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
if h == nil || result == nil {
return
}
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
return
}
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
func multiAgentHTTPErrorStatus(err error) (int, string) { func multiAgentHTTPErrorStatus(err error) (int, string) {
msg := err.Error() msg := err.Error()
switch { switch {
+1 -1
View File
@@ -49,7 +49,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
} }
} }
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID) historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil { if getErr != nil {
+26 -26
View File
@@ -4445,7 +4445,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"messageId"}, "required": []string{"messageId"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"messageId": map[string]interface{}{ "messageId": map[string]interface{}{
@@ -4689,7 +4689,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"scheduleEnabled"}, "required": []string{"scheduleEnabled"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"}, "scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"},
@@ -4761,7 +4761,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"query"}, "required": []string{"query"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""}, "query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""},
@@ -4810,7 +4810,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"text"}, "required": []string{"text"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"}, "text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"},
@@ -4853,7 +4853,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"api_key", "model"}, "required": []string{"api_key", "model"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude", "example": "openai"}, "provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude", "example": "openai"},
@@ -4900,7 +4900,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"command"}, "required": []string{"command"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, "command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
@@ -4943,7 +4943,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"command"}, "required": []string{"command"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, "command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
@@ -5027,7 +5027,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"url"}, "required": []string{"url"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, "url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5231,7 +5231,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"url", "command"}, "required": []string{"url", "command"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, "url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5277,7 +5277,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"url", "action", "path"}, "required": []string{"url", "action", "path"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, "url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5339,14 +5339,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"items": map[string]interface{}{ "items": map[string]interface{}{
"type": "object", "type": "object",
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"relativePath": map[string]interface{}{"type": "string"}, "relativePath": map[string]interface{}{"type": "string"},
"absolutePath": map[string]interface{}{"type": "string"}, "absolutePath": map[string]interface{}{"type": "string"},
"name": map[string]interface{}{"type": "string"}, "name": map[string]interface{}{"type": "string"},
"size": map[string]interface{}{"type": "integer"}, "size": map[string]interface{}{"type": "integer"},
"modifiedUnix": map[string]interface{}{"type": "integer"}, "modifiedUnix": map[string]interface{}{"type": "integer"},
"date": map[string]interface{}{"type": "string"}, "date": map[string]interface{}{"type": "string"},
"conversationId": map[string]interface{}{"type": "string"}, "conversationId": map[string]interface{}{"type": "string"},
"subPath": map[string]interface{}{"type": "string"}, "subPath": map[string]interface{}{"type": "string"},
}, },
}, },
}, },
@@ -5369,7 +5369,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"multipart/form-data": map[string]interface{}{ "multipart/form-data": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"file"}, "required": []string{"file"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"}, "file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"},
@@ -5410,7 +5410,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"path"}, "required": []string{"path"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, "path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5485,7 +5485,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"path", "content"}, "required": []string{"path", "content"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, "path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5512,7 +5512,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"name"}, "required": []string{"name"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"}, "parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"},
@@ -5552,7 +5552,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"path", "newName"}, "required": []string{"path", "newName"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"}, "path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"},
@@ -5646,7 +5646,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"platform", "text"}, "required": []string{"platform", "text"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}}, "platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}},
@@ -5712,7 +5712,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"name"}, "required": []string{"name"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"}, "filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"},
@@ -5932,7 +5932,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"path"}, "required": []string{"path"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, "path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5974,7 +5974,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"ids"}, "required": []string{"ids"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"ids": map[string]interface{}{ "ids": map[string]interface{}{
+6 -6
View File
@@ -26,7 +26,7 @@ var apiDocI18nSummaryToKey = map[string]string{
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup", "创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup", "删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
"从分组移除对话": "removeConversationFromGroup", "从分组移除对话": "removeConversationFromGroup",
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats", "列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability", "获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole", "列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill", "获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
@@ -52,9 +52,9 @@ var apiDocI18nSummaryToKey = map[string]string{
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata", "重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled", "修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
"获取所有分组映射": "getAllGroupMappings", "获取所有分组映射": "getAllGroupMappings",
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse", "FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
"测试OpenAI API连接": "testOpenAI", "测试OpenAI API连接": "testOpenAI",
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS", "执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection", "列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection", "更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState", "获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
@@ -69,7 +69,7 @@ var apiDocI18nSummaryToKey = map[string]string{
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent", "获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile", "列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
"批量获取工具名称": "batchGetToolNames", "批量获取工具名称": "batchGetToolNames",
"获取知识库统计": "getKnowledgeStats", "获取知识库统计": "getKnowledgeStats",
} }
var apiDocI18nResponseDescToKey = map[string]string{ var apiDocI18nResponseDescToKey = map[string]string{
@@ -78,7 +78,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty", "对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound", "请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig", "请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed", "请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess", "登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid", "密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess", "对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
@@ -89,7 +89,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events": "streamResponse", "消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events": "streamResponse",
// 新增缺失端点响应 // 新增缺失端点响应
"参数错误或删除失败": "badRequestOrDeleteFailed", "参数错误或删除失败": "badRequestOrDeleteFailed",
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun", "参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess", "参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult", "搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished", "执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
+14 -14
View File
@@ -28,20 +28,20 @@ import (
) )
const ( const (
robotCmdHelp = "帮助" robotCmdHelp = "帮助"
robotCmdList = "列表" robotCmdList = "列表"
robotCmdListAlt = "对话列表" robotCmdListAlt = "对话列表"
robotCmdSwitch = "切换" robotCmdSwitch = "切换"
robotCmdContinue = "继续" robotCmdContinue = "继续"
robotCmdNew = "新对话" robotCmdNew = "新对话"
robotCmdClear = "清空" robotCmdClear = "清空"
robotCmdCurrent = "当前" robotCmdCurrent = "当前"
robotCmdStop = "停止" robotCmdStop = "停止"
robotCmdRoles = "角色" robotCmdRoles = "角色"
robotCmdRolesList = "角色列表" robotCmdRolesList = "角色列表"
robotCmdSwitchRole = "切换角色" robotCmdSwitchRole = "切换角色"
robotCmdDelete = "删除" robotCmdDelete = "删除"
robotCmdVersion = "版本" robotCmdVersion = "版本"
) )
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理 // RobotHandler 企业微信/钉钉/飞书等机器人回调处理
+13 -13
View File
@@ -65,19 +65,19 @@ func (h *SkillsHandler) GetSkills(c *gin.Context) {
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries)) allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
for _, s := range allSummaries { for _, s := range allSummaries {
skillInfo := map[string]interface{}{ skillInfo := map[string]interface{}{
"id": s.ID, "id": s.ID,
"name": s.Name, "name": s.Name,
"dir_name": s.DirName, "dir_name": s.DirName,
"description": s.Description, "description": s.Description,
"version": s.Version, "version": s.Version,
"path": s.Path, "path": s.Path,
"tags": s.Tags, "tags": s.Tags,
"triggers": s.Triggers, "triggers": s.Triggers,
"script_count": s.ScriptCount, "script_count": s.ScriptCount,
"file_count": s.FileCount, "file_count": s.FileCount,
"progressive": s.Progressive, "progressive": s.Progressive,
"file_size": s.FileSize, "file_size": s.FileSize,
"mod_time": s.ModTime, "mod_time": s.ModTime,
} }
allSkillsInfo = append(allSkillsInfo, skillInfo) allSkillsInfo = append(allSkillsInfo, skillInfo)
} }
-1
View File
@@ -109,4 +109,3 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
<-doneChan <-doneChan
} }
+20 -21
View File
@@ -28,7 +28,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
// CreateVulnerabilityRequest 创建漏洞请求 // CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct { type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"` ConversationID string `json:"conversation_id" binding:"required"`
ConversationTag string `json:"conversation_tag"` ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"` TaskTag string `json:"task_tag"`
Title string `json:"title" binding:"required"` Title string `json:"title" binding:"required"`
@@ -51,18 +51,18 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
} }
vuln := &database.Vulnerability{ vuln := &database.Vulnerability{
ConversationID: req.ConversationID, ConversationID: req.ConversationID,
ConversationTag: req.ConversationTag, ConversationTag: req.ConversationTag,
TaskTag: req.TaskTag, TaskTag: req.TaskTag,
Title: req.Title, Title: req.Title,
Description: req.Description, Description: req.Description,
Severity: req.Severity, Severity: req.Severity,
Status: req.Status, Status: req.Status,
Type: req.Type, Type: req.Type,
Target: req.Target, Target: req.Target,
Proof: req.Proof, Proof: req.Proof,
Impact: req.Impact, Impact: req.Impact,
Recommendation: req.Recommendation, Recommendation: req.Recommendation,
} }
created, err := h.db.CreateVulnerability(vuln) created, err := h.db.CreateVulnerability(vuln)
@@ -172,15 +172,15 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
type UpdateVulnerabilityRequest struct { type UpdateVulnerabilityRequest struct {
ConversationTag string `json:"conversation_tag"` ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"` TaskTag string `json:"task_tag"`
Title string `json:"title"` Title string `json:"title"`
Description string `json:"description"` Description string `json:"description"`
Severity string `json:"severity"` Severity string `json:"severity"`
Status string `json:"status"` Status string `json:"status"`
Type string `json:"type"` Type string `json:"type"`
Target string `json:"target"` Target string `json:"target"`
Proof string `json:"proof"` Proof string `json:"proof"`
Impact string `json:"impact"` Impact string `json:"impact"`
Recommendation string `json:"recommendation"` Recommendation string `json:"recommendation"`
} }
// UpdateVulnerability 更新漏洞 // UpdateVulnerability 更新漏洞
@@ -460,4 +460,3 @@ func sanitizeExportName(raw string) string {
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-") replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
return replacer.Replace(name) return replacer.Replace(name)
} }
+3 -3
View File
@@ -16,9 +16,9 @@ const (
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo]. // DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
const ( const (
DSLRiskType = "risk_type" DSLRiskType = "risk_type"
DSLSimilarityThreshold = "similarity_threshold" DSLSimilarityThreshold = "similarity_threshold"
DSLSubIndexFilter = "sub_index_filter" DSLSubIndexFilter = "sub_index_filter"
) )
// FormatEmbeddingInput matches the historical indexing format so existing embeddings // FormatEmbeddingInput matches the historical indexing format so existing embeddings
+1 -1
View File
@@ -8,8 +8,8 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
) )
+9 -9
View File
@@ -11,9 +11,9 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file" fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -35,14 +35,14 @@ type Indexer struct {
lastErrorTime time.Time lastErrorTime time.Time
errorCount int errorCount int
rebuildMu sync.RWMutex rebuildMu sync.RWMutex
isRebuilding bool isRebuilding bool
rebuildTotalItems int rebuildTotalItems int
rebuildCurrent int rebuildCurrent int
rebuildFailed int rebuildFailed int
rebuildStartTime time.Time rebuildStartTime time.Time
rebuildLastItemID string rebuildLastItemID string
rebuildLastChunks int rebuildLastChunks int
} }
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。 // NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
+3 -3
View File
@@ -108,9 +108,9 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
// CategoryWithItems 分类及其下的知识项(用于按分类分页) // CategoryWithItems 分类及其下的知识项(用于按分类分页)
type CategoryWithItems struct { type CategoryWithItems struct {
Category string `json:"category"` // 分类名称 Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数 ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表 Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
} }
// SearchRequest 搜索请求 // SearchRequest 搜索请求
+7 -7
View File
@@ -192,13 +192,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
fnName, _ := fn["name"].(string) fnName, _ := fn["name"].(string)
fnArgs, _ := fn["arguments"] fnArgs, _ := fn["arguments"]
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝 // 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
if strings.TrimSpace(fnName) == "" { if strings.TrimSpace(fnName) == "" {
fnName = "unknown_function" fnName = "unknown_function"
} }
if strings.TrimSpace(tcID) == "" { if strings.TrimSpace(tcID) == "" {
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano()) tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
} }
var inputRaw json.RawMessage var inputRaw json.RawMessage
switch v := fnArgs.(type) { switch v := fnArgs.(type) {
+14 -14
View File
@@ -281,9 +281,9 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。 // StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。
type StreamToolCall struct { type StreamToolCall struct {
Index int Index int
ID string ID string
Type string Type string
FunctionName string FunctionName string
FunctionArgsStr string FunctionArgsStr string
} }
@@ -348,10 +348,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
Arguments string `json:"arguments,omitempty"` Arguments string `json:"arguments,omitempty"`
} }
type toolCallDelta struct { type toolCallDelta struct {
Index int `json:"index,omitempty"` Index int `json:"index,omitempty"`
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
Function toolCallFunctionDelta `json:"function,omitempty"` Function toolCallFunctionDelta `json:"function,omitempty"`
} }
type streamDelta2 struct { type streamDelta2 struct {
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
@@ -371,10 +371,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
} }
type toolCallAccum struct { type toolCallAccum struct {
id string id string
typ string typ string
name string name string
args strings.Builder args strings.Builder
} }
toolCallAccums := make(map[int]*toolCallAccum) toolCallAccums := make(map[int]*toolCallAccum)
@@ -475,9 +475,9 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
for _, idx := range indices { for _, idx := range indices {
acc := toolCallAccums[idx] acc := toolCallAccums[idx]
tc := StreamToolCall{ tc := StreamToolCall{
Index: idx, Index: idx,
ID: acc.id, ID: acc.id,
Type: acc.typ, Type: acc.typ,
FunctionName: acc.name, FunctionName: acc.name,
FunctionArgsStr: acc.args.String(), FunctionArgsStr: acc.args.String(),
} }
-1
View File
@@ -265,4 +265,3 @@ func TestPaginateLines(t *testing.T) {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines)) t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
} }
} }
+4 -4
View File
@@ -16,10 +16,10 @@ type rateLimitEntry struct {
// RateLimiter 基于 IP 的滑动窗口速率限制器 // RateLimiter 基于 IP 的滑动窗口速率限制器
type RateLimiter struct { type RateLimiter struct {
mu sync.Mutex mu sync.Mutex
entries map[string]*rateLimitEntry entries map[string]*rateLimitEntry
limit int // 窗口内允许的最大请求数 limit int // 窗口内允许的最大请求数
window time.Duration // 窗口时长 window time.Duration // 窗口时长
} }
// NewRateLimiter 创建速率限制器 // NewRateLimiter 创建速率限制器