mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-03 12:58:08 +02:00
Add files via upload
This commit is contained in:
+21
-21
@@ -336,10 +336,10 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
|
|||||||
|
|
||||||
// AgentLoopResult Agent Loop执行结果
|
// AgentLoopResult Agent Loop执行结果
|
||||||
type AgentLoopResult struct {
|
type AgentLoopResult struct {
|
||||||
Response string
|
Response string
|
||||||
MCPExecutionIDs []string
|
MCPExecutionIDs []string
|
||||||
LastReActInput string // 最后一轮ReAct的输入(压缩后的messages,JSON格式)
|
LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messages,JSON;与 multiagent.RunResult 字段对齐)
|
||||||
LastReActOutput string // 最终大模型的输出
|
LastAgentTraceOutput string // 最终助手输出文本
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProgressCallback 进度回调函数类型
|
// ProgressCallback 进度回调函数类型
|
||||||
@@ -471,7 +471,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
|
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
|
||||||
var currentReActInput string
|
var currentAgentTraceInput string
|
||||||
|
|
||||||
maxIterations := a.maxIterations
|
maxIterations := a.maxIterations
|
||||||
thinkingStreamSeq := 0
|
thinkingStreamSeq := 0
|
||||||
@@ -490,9 +490,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
|
a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
|
||||||
} else {
|
} else {
|
||||||
currentReActInput = string(messagesJSON)
|
currentAgentTraceInput = string(messagesJSON)
|
||||||
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
|
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
|
||||||
result.LastReActInput = currentReActInput
|
result.LastAgentTraceInput = currentAgentTraceInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查上下文是否已取消
|
// 检查上下文是否已取消
|
||||||
@@ -500,13 +500,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// 上下文被取消(可能是用户主动暂停或其他原因)
|
// 上下文被取消(可能是用户主动暂停或其他原因)
|
||||||
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
|
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
|
||||||
result.LastReActInput = currentReActInput
|
result.LastAgentTraceInput = currentAgentTraceInput
|
||||||
if ctx.Err() == context.Canceled {
|
if ctx.Err() == context.Canceled {
|
||||||
result.Response = "任务已被取消。"
|
result.Response = "任务已被取消。"
|
||||||
} else {
|
} else {
|
||||||
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
|
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
|
||||||
}
|
}
|
||||||
result.LastReActOutput = result.Response
|
result.LastAgentTraceOutput = result.Response
|
||||||
return result, ctx.Err()
|
return result, ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -600,10 +600,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// API调用失败,保存当前的ReAct输入和错误信息作为输出
|
// API调用失败,保存当前的ReAct输入和错误信息作为输出
|
||||||
result.LastReActInput = currentReActInput
|
result.LastAgentTraceInput = currentAgentTraceInput
|
||||||
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
|
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
|
||||||
result.Response = errorMsg
|
result.Response = errorMsg
|
||||||
result.LastReActOutput = errorMsg
|
result.LastAgentTraceOutput = errorMsg
|
||||||
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
|
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
|
||||||
return result, fmt.Errorf("调用OpenAI失败: %w", err)
|
return result, fmt.Errorf("调用OpenAI失败: %w", err)
|
||||||
}
|
}
|
||||||
@@ -629,19 +629,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
|
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
|
||||||
result.LastReActInput = currentReActInput
|
result.LastAgentTraceInput = currentAgentTraceInput
|
||||||
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
|
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
|
||||||
result.Response = errorMsg
|
result.Response = errorMsg
|
||||||
result.LastReActOutput = errorMsg
|
result.LastAgentTraceOutput = errorMsg
|
||||||
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
|
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(response.Choices) == 0 {
|
if len(response.Choices) == 0 {
|
||||||
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出
|
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出
|
||||||
result.LastReActInput = currentReActInput
|
result.LastAgentTraceInput = currentAgentTraceInput
|
||||||
errorMsg := "没有收到响应"
|
errorMsg := "没有收到响应"
|
||||||
result.Response = errorMsg
|
result.Response = errorMsg
|
||||||
result.LastReActOutput = errorMsg
|
result.LastAgentTraceOutput = errorMsg
|
||||||
return result, fmt.Errorf("没有收到响应")
|
return result, fmt.Errorf("没有收到响应")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -816,7 +816,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
result.Response = streamText
|
result.Response = streamText
|
||||||
result.LastReActOutput = result.Response
|
result.LastAgentTraceOutput = result.Response
|
||||||
sendProgress("progress", "总结生成完成", nil)
|
sendProgress("progress", "总结生成完成", nil)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
@@ -863,14 +863,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
result.Response = streamText
|
result.Response = streamText
|
||||||
result.LastReActOutput = result.Response
|
result.LastAgentTraceOutput = result.Response
|
||||||
sendProgress("progress", "总结生成完成", nil)
|
sendProgress("progress", "总结生成完成", nil)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
// 如果获取总结失败,使用当前回复作为结果
|
// 如果获取总结失败,使用当前回复作为结果
|
||||||
if choice.Message.Content != "" {
|
if choice.Message.Content != "" {
|
||||||
result.Response = choice.Message.Content
|
result.Response = choice.Message.Content
|
||||||
result.LastReActOutput = result.Response
|
result.LastAgentTraceOutput = result.Response
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
// 如果都没有内容,跳出循环,让后续逻辑处理
|
// 如果都没有内容,跳出循环,让后续逻辑处理
|
||||||
@@ -881,7 +881,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
if choice.FinishReason == "stop" {
|
if choice.FinishReason == "stop" {
|
||||||
sendProgress("progress", "正在生成最终回复...", nil)
|
sendProgress("progress", "正在生成最终回复...", nil)
|
||||||
result.Response = choice.Message.Content
|
result.Response = choice.Message.Content
|
||||||
result.LastReActOutput = result.Response
|
result.LastAgentTraceOutput = result.Response
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -910,14 +910,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
result.Response = streamText
|
result.Response = streamText
|
||||||
result.LastReActOutput = result.Response
|
result.LastAgentTraceOutput = result.Response
|
||||||
sendProgress("progress", "总结生成完成", nil)
|
sendProgress("progress", "总结生成完成", nil)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果无法生成总结,返回友好的提示
|
// 如果无法生成总结,返回友好的提示
|
||||||
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
|
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
|
||||||
result.LastReActOutput = result.Response
|
result.LastAgentTraceOutput = result.Response
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -283,4 +283,3 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
|||||||
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
|
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -256,11 +256,11 @@ func orchestratorConfigFromOrchestrator(o *OrchestratorMarkdown) config.MultiAge
|
|||||||
return config.MultiAgentSubConfig{}
|
return config.MultiAgentSubConfig{}
|
||||||
}
|
}
|
||||||
return config.MultiAgentSubConfig{
|
return config.MultiAgentSubConfig{
|
||||||
ID: o.EinoName,
|
ID: o.EinoName,
|
||||||
Name: o.DisplayName,
|
Name: o.DisplayName,
|
||||||
Description: o.Description,
|
Description: o.Description,
|
||||||
Instruction: o.Instruction,
|
Instruction: o.Instruction,
|
||||||
Kind: "orchestrator",
|
Kind: "orchestrator",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+84
-92
@@ -497,10 +497,10 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
|||||||
defer h.hitlManager.DeactivateConversation(conversationID)
|
defer h.hitlManager.DeactivateConversation(conversationID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
// 优先尝试从保存的代理轨迹恢复历史上下文
|
||||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
|
||||||
// 回退到使用数据库消息表
|
// 回退到使用数据库消息表
|
||||||
historyMessages, err := h.db.GetMessages(conversationID)
|
historyMessages, err := h.db.GetMessages(conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -518,7 +518,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
|||||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 校验附件数量(非流式)
|
// 校验附件数量(非流式)
|
||||||
@@ -613,12 +613,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||||
|
|
||||||
// 即使执行失败,也尝试保存ReAct数据(如果result中有)
|
// 即使执行失败,也尝试保存代理轨迹(如果 result 中有)
|
||||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||||
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil {
|
if saveErr := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); saveErr != nil {
|
||||||
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr))
|
h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(saveErr))
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
|
h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -634,12 +634,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
|||||||
// 因为AI已经生成了回复,用户应该能看到
|
// 因为AI已经生成了回复,用户应该能看到
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存最后一轮ReAct的输入和输出
|
// 保存最后一轮代理轨迹与助手输出
|
||||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
|
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
|
h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -666,7 +666,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
historyMessages, getErr := h.db.GetMessages(conversationID)
|
historyMessages, getErr := h.db.GetMessages(conversationID)
|
||||||
if getErr != nil {
|
if getErr != nil {
|
||||||
@@ -722,6 +722,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
|||||||
"deep",
|
"deep",
|
||||||
)
|
)
|
||||||
if errMA != nil {
|
if errMA != nil {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||||
errMsg := "执行失败: " + errMA.Error()
|
errMsg := "执行失败: " + errMA.Error()
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
|
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
|
||||||
@@ -747,8 +748,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
|||||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" {
|
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||||
_ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput)
|
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||||
}
|
}
|
||||||
return resultMA.Response, conversationID, nil
|
return resultMA.Response, conversationID, nil
|
||||||
}
|
}
|
||||||
@@ -782,8 +783,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
|||||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
_ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput)
|
_ = h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||||
}
|
}
|
||||||
return result.Response, conversationID, nil
|
return result.Response, conversationID, nil
|
||||||
}
|
}
|
||||||
@@ -1359,10 +1360,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
ssePublishConversationID = conversationID
|
ssePublishConversationID = conversationID
|
||||||
|
|
||||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
// 优先尝试从保存的代理轨迹恢复历史上下文
|
||||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
|
||||||
// 回退到使用数据库消息表
|
// 回退到使用数据库消息表
|
||||||
historyMessages, err := h.db.GetMessages(conversationID)
|
historyMessages, err := h.db.GetMessages(conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1380,7 +1381,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 校验附件数量
|
// 校验附件数量
|
||||||
@@ -1579,12 +1580,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 即使任务被取消,也尝试保存ReAct数据(如果result中有)
|
// 即使任务被取消,也尝试保存代理轨迹(如果 result 中有)
|
||||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err))
|
h.logger.Warn("保存取消任务的代理轨迹失败", zap.Error(err))
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID))
|
h.logger.Info("已保存取消任务的代理轨迹", zap.String("conversationId", conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1614,12 +1615,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 即使任务超时,也尝试保存ReAct数据(如果result中有)
|
// 即使任务超时,也尝试保存代理轨迹(如果 result 中有)
|
||||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err))
|
h.logger.Warn("保存超时任务的代理轨迹失败", zap.Error(err))
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID))
|
h.logger.Info("已保存超时任务的代理轨迹", zap.String("conversationId", conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1649,12 +1650,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
|
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 即使任务失败,也尝试保存ReAct数据(如果result中有)
|
// 即使任务失败,也尝试保存代理轨迹(如果 result 中有)
|
||||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err))
|
h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(err))
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
|
h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1694,12 +1695,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存最后一轮ReAct的输入和输出
|
// 保存最后一轮代理轨迹与助手输出
|
||||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
|
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
|
h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2499,6 +2500,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
|
if useRunResult {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||||
|
}
|
||||||
// 检查是否是取消错误
|
// 检查是否是取消错误
|
||||||
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
|
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
|
||||||
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
|
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
|
||||||
@@ -2542,14 +2546,14 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 保存ReAct数据(如果存在)
|
// 保存代理轨迹(如果存在)
|
||||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
}
|
}
|
||||||
} else if useRunResult && resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") {
|
} else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") {
|
||||||
if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
||||||
@@ -2581,13 +2585,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
if useRunResult {
|
if useRunResult {
|
||||||
resText = resultMA.Response
|
resText = resultMA.Response
|
||||||
mcpIDs = resultMA.MCPExecutionIDs
|
mcpIDs = resultMA.MCPExecutionIDs
|
||||||
lastIn = resultMA.LastReActInput
|
lastIn = resultMA.LastAgentTraceInput
|
||||||
lastOut = resultMA.LastReActOutput
|
lastOut = resultMA.LastAgentTraceOutput
|
||||||
} else {
|
} else {
|
||||||
resText = result.Response
|
resText = result.Response
|
||||||
mcpIDs = result.MCPExecutionIDs
|
mcpIDs = result.MCPExecutionIDs
|
||||||
lastIn = result.LastReActInput
|
lastIn = result.LastAgentTraceInput
|
||||||
lastOut = result.LastReActOutput
|
lastOut = result.LastAgentTraceOutput
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新助手消息内容
|
// 更新助手消息内容
|
||||||
@@ -2618,12 +2622,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存ReAct数据
|
// 保存代理轨迹
|
||||||
if lastIn != "" || lastOut != "" {
|
if lastIn != "" || lastOut != "" {
|
||||||
if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
||||||
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
} else {
|
} else {
|
||||||
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2642,36 +2646,33 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
|
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
|
||||||
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表
|
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
|
||||||
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
|
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
|
||||||
// 获取保存的ReAct输入和输出
|
traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID)
|
||||||
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
|
return nil, fmt.Errorf("获取代理轨迹失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致)
|
if traceInputJSON == "" {
|
||||||
if reactInputJSON == "" {
|
return nil, fmt.Errorf("代理轨迹为空,将使用消息表")
|
||||||
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dataSource := "database_last_react_input"
|
dataSource := "database_last_agent_trace"
|
||||||
|
|
||||||
// 解析JSON格式的messages数组
|
|
||||||
var messagesArray []map[string]interface{}
|
var messagesArray []map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
|
if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil {
|
||||||
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
|
return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
messageCount := len(messagesArray)
|
messageCount := len(messagesArray)
|
||||||
|
|
||||||
h.logger.Info("使用保存的ReAct数据恢复历史上下文",
|
h.logger.Info("使用保存的代理轨迹恢复历史上下文",
|
||||||
zap.String("conversationId", conversationID),
|
zap.String("conversationId", conversationID),
|
||||||
zap.String("dataSource", dataSource),
|
zap.String("dataSource", dataSource),
|
||||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
zap.Int("traceInputSize", len(traceInputJSON)),
|
||||||
zap.Int("messageCount", messageCount),
|
zap.Int("messageCount", messageCount),
|
||||||
zap.Int("reactOutputSize", len(reactOutput)),
|
zap.Int("assistantOutSize", len(assistantOut)),
|
||||||
)
|
)
|
||||||
// fmt.Println("messagesArray:", messagesArray)//debug
|
// fmt.Println("messagesArray:", messagesArray)//debug
|
||||||
|
|
||||||
@@ -2755,53 +2756,44 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
|
|||||||
agentMessages = append(agentMessages, msg)
|
agentMessages = append(agentMessages, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果存在last_react_output,需要将其作为最后一条assistant消息
|
// 若存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致)
|
||||||
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出
|
if assistantOut != "" {
|
||||||
if reactOutput != "" {
|
|
||||||
// 检查最后一条消息是否是assistant消息且没有tool_calls
|
|
||||||
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
|
|
||||||
if len(agentMessages) > 0 {
|
if len(agentMessages) > 0 {
|
||||||
lastMsg := &agentMessages[len(agentMessages)-1]
|
lastMsg := &agentMessages[len(agentMessages)-1]
|
||||||
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
|
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
|
||||||
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content
|
lastMsg.Content = assistantOut
|
||||||
lastMsg.Content = reactOutput
|
|
||||||
} else {
|
} else {
|
||||||
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
|
|
||||||
agentMessages = append(agentMessages, agent.ChatMessage{
|
agentMessages = append(agentMessages, agent.ChatMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: reactOutput,
|
Content: assistantOut,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 如果没有消息,直接添加最终输出
|
|
||||||
agentMessages = append(agentMessages, agent.ChatMessage{
|
agentMessages = append(agentMessages, agent.ChatMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: reactOutput,
|
Content: assistantOut,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(agentMessages) == 0 {
|
if len(agentMessages) == 0 {
|
||||||
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
|
return nil, fmt.Errorf("从代理轨迹解析的消息为空")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 修复可能存在的失配tool消息,避免OpenAI报错
|
|
||||||
// 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误
|
|
||||||
if h.agent != nil {
|
if h.agent != nil {
|
||||||
if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed {
|
if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed {
|
||||||
h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息",
|
h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息",
|
||||||
zap.String("conversationId", conversationID),
|
zap.String("conversationId", conversationID),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("从ReAct数据恢复历史消息完成",
|
h.logger.Info("从代理轨迹恢复历史消息完成",
|
||||||
zap.String("conversationId", conversationID),
|
zap.String("conversationId", conversationID),
|
||||||
zap.String("dataSource", dataSource),
|
zap.String("dataSource", dataSource),
|
||||||
zap.Int("originalMessageCount", messageCount),
|
zap.Int("originalMessageCount", messageCount),
|
||||||
zap.Int("finalMessageCount", len(agentMessages)),
|
zap.Int("finalMessageCount", len(agentMessages)),
|
||||||
zap.Bool("hasReactOutput", reactOutput != ""),
|
zap.Bool("hasAssistantOut", assistantOut != ""),
|
||||||
)
|
)
|
||||||
fmt.Println("agentMessages:", agentMessages) //debug
|
|
||||||
return agentMessages, nil
|
return agentMessages, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
|
|||||||
|
|
||||||
c.JSON(http.StatusOK, chain)
|
c.JSON(http.StatusOK, chain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -230,4 +230,3 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
|
|||||||
"message": "ok",
|
"message": "ok",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, ErrTaskCancelled) {
|
if errors.Is(cause, ErrTaskCancelled) {
|
||||||
taskStatus = "cancelled"
|
taskStatus = "cancelled"
|
||||||
@@ -239,9 +240,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
|
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,6 +307,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
progressCallback,
|
progressCallback,
|
||||||
)
|
)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
|
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -323,8 +325,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
prep.AssistantMessageID,
|
prep.AssistantMessageID,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
_ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput)
|
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
|
|||||||
|
|
||||||
// 先添加一个配置
|
// 先添加一个配置
|
||||||
configObj := config.ExternalMCPServerConfig{
|
configObj := config.ExternalMCPServerConfig{
|
||||||
Command: "python3",
|
Command: "python3",
|
||||||
ExternalMCPEnable: true,
|
ExternalMCPEnable: true,
|
||||||
}
|
}
|
||||||
handler.manager.AddOrUpdateConfig("test-delete", configObj)
|
handler.manager.AddOrUpdateConfig("test-delete", configObj)
|
||||||
@@ -276,11 +276,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
|||||||
|
|
||||||
// 添加多个配置
|
// 添加多个配置
|
||||||
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
|
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
|
||||||
Command: "python3",
|
Command: "python3",
|
||||||
ExternalMCPEnable: true,
|
ExternalMCPEnable: true,
|
||||||
})
|
})
|
||||||
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
|
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
|
||||||
URL: "http://127.0.0.1:8081/mcp",
|
URL: "http://127.0.0.1:8081/mcp",
|
||||||
ExternalMCPEnable: false,
|
ExternalMCPEnable: false,
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -319,15 +319,15 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
|
|||||||
|
|
||||||
// 添加配置
|
// 添加配置
|
||||||
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||||
Command: "python3",
|
Command: "python3",
|
||||||
ExternalMCPEnable: true,
|
ExternalMCPEnable: true,
|
||||||
})
|
})
|
||||||
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||||
URL: "http://127.0.0.1:8081/mcp",
|
URL: "http://127.0.0.1:8081/mcp",
|
||||||
ExternalMCPEnable: true,
|
ExternalMCPEnable: true,
|
||||||
})
|
})
|
||||||
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||||
Command: "python3",
|
Command: "python3",
|
||||||
})
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
|
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
|
||||||
@@ -360,7 +360,7 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
|
|||||||
|
|
||||||
// 添加一个禁用的配置
|
// 添加一个禁用的配置
|
||||||
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
|
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
|
||||||
Command: "python3",
|
Command: "python3",
|
||||||
})
|
})
|
||||||
|
|
||||||
// 测试启动(可能会失败,因为没有真实的服务器)
|
// 测试启动(可能会失败,因为没有真实的服务器)
|
||||||
@@ -416,7 +416,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
|
|||||||
router, _, _ := setupTestRouter()
|
router, _, _ := setupTestRouter()
|
||||||
|
|
||||||
configObj := config.ExternalMCPServerConfig{
|
configObj := config.ExternalMCPServerConfig{
|
||||||
Command: "python3",
|
Command: "python3",
|
||||||
ExternalMCPEnable: true,
|
ExternalMCPEnable: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -459,14 +459,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
|
|||||||
|
|
||||||
// 先添加配置
|
// 先添加配置
|
||||||
config1 := config.ExternalMCPServerConfig{
|
config1 := config.ExternalMCPServerConfig{
|
||||||
Command: "python3",
|
Command: "python3",
|
||||||
ExternalMCPEnable: true,
|
ExternalMCPEnable: true,
|
||||||
}
|
}
|
||||||
handler.manager.AddOrUpdateConfig("test-update", config1)
|
handler.manager.AddOrUpdateConfig("test-update", config1)
|
||||||
|
|
||||||
// 更新配置
|
// 更新配置
|
||||||
config2 := config.ExternalMCPServerConfig{
|
config2 := config.ExternalMCPServerConfig{
|
||||||
URL: "http://127.0.0.1:8081/mcp",
|
URL: "http://127.0.0.1:8081/mcp",
|
||||||
ExternalMCPEnable: true,
|
ExternalMCPEnable: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -131,16 +131,16 @@ func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type markdownAgentBody struct {
|
type markdownAgentBody struct {
|
||||||
Filename string `json:"filename"`
|
Filename string `json:"filename"`
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Tools []string `json:"tools"`
|
Tools []string `json:"tools"`
|
||||||
Instruction string `json:"instruction"`
|
Instruction string `json:"instruction"`
|
||||||
BindRole string `json:"bind_role"`
|
BindRole string `json:"bind_role"`
|
||||||
MaxIterations int `json:"max_iterations"`
|
MaxIterations int `json:"max_iterations"`
|
||||||
Kind string `json:"kind"`
|
Kind string `json:"kind"`
|
||||||
Raw string `json:"raw"`
|
Raw string `json:"raw"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents
|
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents
|
||||||
|
|||||||
@@ -42,11 +42,11 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
|||||||
type MonitorResponse struct {
|
type MonitorResponse struct {
|
||||||
Executions []*mcp.ToolExecution `json:"executions"`
|
Executions []*mcp.ToolExecution `json:"executions"`
|
||||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
Total int `json:"total,omitempty"`
|
Total int `json:"total,omitempty"`
|
||||||
Page int `json:"page,omitempty"`
|
Page int `json:"page,omitempty"`
|
||||||
PageSize int `json:"page_size,omitempty"`
|
PageSize int `json:"page_size,omitempty"`
|
||||||
TotalPages int `json:"total_pages,omitempty"`
|
TotalPages int `json:"total_pages,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Monitor 获取监控信息
|
// Monitor 获取监控信息
|
||||||
@@ -213,7 +213,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
|
|||||||
return stats
|
return stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// GetExecution 获取特定执行记录
|
// GetExecution 获取特定执行记录
|
||||||
func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
@@ -416,5 +415,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
|||||||
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
|
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, ErrTaskCancelled) {
|
if errors.Is(cause, ErrTaskCancelled) {
|
||||||
taskStatus = "cancelled"
|
taskStatus = "cancelled"
|
||||||
@@ -249,9 +250,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
|
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,6 +319,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
strings.TrimSpace(req.Orchestration),
|
strings.TrimSpace(req.Orchestration),
|
||||||
)
|
)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
|
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||||
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
|
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
|
||||||
errMsg := "执行失败: " + runErr.Error()
|
errMsg := "执行失败: " + runErr.Error()
|
||||||
if prep.AssistantMessageID != "" {
|
if prep.AssistantMessageID != "" {
|
||||||
@@ -341,9 +343,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
|
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,6 +357,19 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
|
||||||
|
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
|
||||||
|
if h == nil || result == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||||
|
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
switch {
|
switch {
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
historyMessages, getErr := h.db.GetMessages(conversationID)
|
historyMessages, getErr := h.db.GetMessages(conversationID)
|
||||||
if getErr != nil {
|
if getErr != nil {
|
||||||
|
|||||||
+26
-26
@@ -4445,7 +4445,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"messageId"},
|
"required": []string{"messageId"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"messageId": map[string]interface{}{
|
"messageId": map[string]interface{}{
|
||||||
@@ -4689,7 +4689,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"scheduleEnabled"},
|
"required": []string{"scheduleEnabled"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"},
|
"scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"},
|
||||||
@@ -4761,7 +4761,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"query"},
|
"required": []string{"query"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""},
|
"query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""},
|
||||||
@@ -4810,7 +4810,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"text"},
|
"required": []string{"text"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"},
|
"text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"},
|
||||||
@@ -4853,7 +4853,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"api_key", "model"},
|
"required": []string{"api_key", "model"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude)", "example": "openai"},
|
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude)", "example": "openai"},
|
||||||
@@ -4900,7 +4900,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"command"},
|
"required": []string{"command"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
|
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
|
||||||
@@ -4943,7 +4943,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"command"},
|
"required": []string{"command"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
|
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
|
||||||
@@ -5027,7 +5027,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"url"},
|
"required": []string{"url"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
||||||
@@ -5231,7 +5231,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"url", "command"},
|
"required": []string{"url", "command"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
||||||
@@ -5277,7 +5277,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"url", "action", "path"},
|
"required": []string{"url", "action", "path"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
||||||
@@ -5339,14 +5339,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"items": map[string]interface{}{
|
"items": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"relativePath": map[string]interface{}{"type": "string"},
|
"relativePath": map[string]interface{}{"type": "string"},
|
||||||
"absolutePath": map[string]interface{}{"type": "string"},
|
"absolutePath": map[string]interface{}{"type": "string"},
|
||||||
"name": map[string]interface{}{"type": "string"},
|
"name": map[string]interface{}{"type": "string"},
|
||||||
"size": map[string]interface{}{"type": "integer"},
|
"size": map[string]interface{}{"type": "integer"},
|
||||||
"modifiedUnix": map[string]interface{}{"type": "integer"},
|
"modifiedUnix": map[string]interface{}{"type": "integer"},
|
||||||
"date": map[string]interface{}{"type": "string"},
|
"date": map[string]interface{}{"type": "string"},
|
||||||
"conversationId": map[string]interface{}{"type": "string"},
|
"conversationId": map[string]interface{}{"type": "string"},
|
||||||
"subPath": map[string]interface{}{"type": "string"},
|
"subPath": map[string]interface{}{"type": "string"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -5369,7 +5369,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"multipart/form-data": map[string]interface{}{
|
"multipart/form-data": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"file"},
|
"required": []string{"file"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"},
|
"file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"},
|
||||||
@@ -5410,7 +5410,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"path"},
|
"required": []string{"path"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
||||||
@@ -5485,7 +5485,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"path", "content"},
|
"required": []string{"path", "content"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
||||||
@@ -5512,7 +5512,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"name"},
|
"required": []string{"name"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"},
|
"parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"},
|
||||||
@@ -5552,7 +5552,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"path", "newName"},
|
"required": []string{"path", "newName"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"},
|
"path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"},
|
||||||
@@ -5646,7 +5646,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"platform", "text"},
|
"required": []string{"platform", "text"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}},
|
"platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}},
|
||||||
@@ -5712,7 +5712,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"name"},
|
"required": []string{"name"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"},
|
"filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"},
|
||||||
@@ -5932,7 +5932,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"path"},
|
"required": []string{"path"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
||||||
@@ -5974,7 +5974,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"content": map[string]interface{}{
|
"content": map[string]interface{}{
|
||||||
"application/json": map[string]interface{}{
|
"application/json": map[string]interface{}{
|
||||||
"schema": map[string]interface{}{
|
"schema": map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": []string{"ids"},
|
"required": []string{"ids"},
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"ids": map[string]interface{}{
|
"ids": map[string]interface{}{
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ var apiDocI18nSummaryToKey = map[string]string{
|
|||||||
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
|
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
|
||||||
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
|
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
|
||||||
"从分组移除对话": "removeConversationFromGroup",
|
"从分组移除对话": "removeConversationFromGroup",
|
||||||
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
|
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
|
||||||
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
|
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
|
||||||
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
|
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
|
||||||
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
|
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
|
||||||
@@ -52,9 +52,9 @@ var apiDocI18nSummaryToKey = map[string]string{
|
|||||||
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
|
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
|
||||||
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
|
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
|
||||||
"获取所有分组映射": "getAllGroupMappings",
|
"获取所有分组映射": "getAllGroupMappings",
|
||||||
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
|
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
|
||||||
"测试OpenAI API连接": "testOpenAI",
|
"测试OpenAI API连接": "testOpenAI",
|
||||||
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
|
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
|
||||||
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
|
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
|
||||||
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
|
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
|
||||||
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
|
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
|
||||||
@@ -69,7 +69,7 @@ var apiDocI18nSummaryToKey = map[string]string{
|
|||||||
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
|
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
|
||||||
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
|
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
|
||||||
"批量获取工具名称": "batchGetToolNames",
|
"批量获取工具名称": "batchGetToolNames",
|
||||||
"获取知识库统计": "getKnowledgeStats",
|
"获取知识库统计": "getKnowledgeStats",
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiDocI18nResponseDescToKey = map[string]string{
|
var apiDocI18nResponseDescToKey = map[string]string{
|
||||||
@@ -78,7 +78,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
|
|||||||
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
|
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
|
||||||
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
|
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
|
||||||
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
|
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
|
||||||
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
|
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
|
||||||
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
|
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
|
||||||
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
|
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
|
||||||
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
|
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
|
||||||
@@ -89,7 +89,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
|
|||||||
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events)": "streamResponse",
|
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events)": "streamResponse",
|
||||||
// 新增缺失端点响应
|
// 新增缺失端点响应
|
||||||
"参数错误或删除失败": "badRequestOrDeleteFailed",
|
"参数错误或删除失败": "badRequestOrDeleteFailed",
|
||||||
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
|
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
|
||||||
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
|
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
|
||||||
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
|
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
|
||||||
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
|
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
|
||||||
|
|||||||
+14
-14
@@ -28,20 +28,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
robotCmdHelp = "帮助"
|
robotCmdHelp = "帮助"
|
||||||
robotCmdList = "列表"
|
robotCmdList = "列表"
|
||||||
robotCmdListAlt = "对话列表"
|
robotCmdListAlt = "对话列表"
|
||||||
robotCmdSwitch = "切换"
|
robotCmdSwitch = "切换"
|
||||||
robotCmdContinue = "继续"
|
robotCmdContinue = "继续"
|
||||||
robotCmdNew = "新对话"
|
robotCmdNew = "新对话"
|
||||||
robotCmdClear = "清空"
|
robotCmdClear = "清空"
|
||||||
robotCmdCurrent = "当前"
|
robotCmdCurrent = "当前"
|
||||||
robotCmdStop = "停止"
|
robotCmdStop = "停止"
|
||||||
robotCmdRoles = "角色"
|
robotCmdRoles = "角色"
|
||||||
robotCmdRolesList = "角色列表"
|
robotCmdRolesList = "角色列表"
|
||||||
robotCmdSwitchRole = "切换角色"
|
robotCmdSwitchRole = "切换角色"
|
||||||
robotCmdDelete = "删除"
|
robotCmdDelete = "删除"
|
||||||
robotCmdVersion = "版本"
|
robotCmdVersion = "版本"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
|
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
|
||||||
|
|||||||
+13
-13
@@ -65,19 +65,19 @@ func (h *SkillsHandler) GetSkills(c *gin.Context) {
|
|||||||
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
|
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
|
||||||
for _, s := range allSummaries {
|
for _, s := range allSummaries {
|
||||||
skillInfo := map[string]interface{}{
|
skillInfo := map[string]interface{}{
|
||||||
"id": s.ID,
|
"id": s.ID,
|
||||||
"name": s.Name,
|
"name": s.Name,
|
||||||
"dir_name": s.DirName,
|
"dir_name": s.DirName,
|
||||||
"description": s.Description,
|
"description": s.Description,
|
||||||
"version": s.Version,
|
"version": s.Version,
|
||||||
"path": s.Path,
|
"path": s.Path,
|
||||||
"tags": s.Tags,
|
"tags": s.Tags,
|
||||||
"triggers": s.Triggers,
|
"triggers": s.Triggers,
|
||||||
"script_count": s.ScriptCount,
|
"script_count": s.ScriptCount,
|
||||||
"file_count": s.FileCount,
|
"file_count": s.FileCount,
|
||||||
"progressive": s.Progressive,
|
"progressive": s.Progressive,
|
||||||
"file_size": s.FileSize,
|
"file_size": s.FileSize,
|
||||||
"mod_time": s.ModTime,
|
"mod_time": s.ModTime,
|
||||||
}
|
}
|
||||||
allSkillsInfo = append(allSkillsInfo, skillInfo)
|
allSkillsInfo = append(allSkillsInfo, skillInfo)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,4 +109,3 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
|
|||||||
|
|
||||||
<-doneChan
|
<-doneChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
|
|||||||
|
|
||||||
// CreateVulnerabilityRequest 创建漏洞请求
|
// CreateVulnerabilityRequest 创建漏洞请求
|
||||||
type CreateVulnerabilityRequest struct {
|
type CreateVulnerabilityRequest struct {
|
||||||
ConversationID string `json:"conversation_id" binding:"required"`
|
ConversationID string `json:"conversation_id" binding:"required"`
|
||||||
ConversationTag string `json:"conversation_tag"`
|
ConversationTag string `json:"conversation_tag"`
|
||||||
TaskTag string `json:"task_tag"`
|
TaskTag string `json:"task_tag"`
|
||||||
Title string `json:"title" binding:"required"`
|
Title string `json:"title" binding:"required"`
|
||||||
@@ -51,18 +51,18 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vuln := &database.Vulnerability{
|
vuln := &database.Vulnerability{
|
||||||
ConversationID: req.ConversationID,
|
ConversationID: req.ConversationID,
|
||||||
ConversationTag: req.ConversationTag,
|
ConversationTag: req.ConversationTag,
|
||||||
TaskTag: req.TaskTag,
|
TaskTag: req.TaskTag,
|
||||||
Title: req.Title,
|
Title: req.Title,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Severity: req.Severity,
|
Severity: req.Severity,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
Type: req.Type,
|
Type: req.Type,
|
||||||
Target: req.Target,
|
Target: req.Target,
|
||||||
Proof: req.Proof,
|
Proof: req.Proof,
|
||||||
Impact: req.Impact,
|
Impact: req.Impact,
|
||||||
Recommendation: req.Recommendation,
|
Recommendation: req.Recommendation,
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := h.db.CreateVulnerability(vuln)
|
created, err := h.db.CreateVulnerability(vuln)
|
||||||
@@ -172,15 +172,15 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
|||||||
type UpdateVulnerabilityRequest struct {
|
type UpdateVulnerabilityRequest struct {
|
||||||
ConversationTag string `json:"conversation_tag"`
|
ConversationTag string `json:"conversation_tag"`
|
||||||
TaskTag string `json:"task_tag"`
|
TaskTag string `json:"task_tag"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Severity string `json:"severity"`
|
Severity string `json:"severity"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Target string `json:"target"`
|
Target string `json:"target"`
|
||||||
Proof string `json:"proof"`
|
Proof string `json:"proof"`
|
||||||
Impact string `json:"impact"`
|
Impact string `json:"impact"`
|
||||||
Recommendation string `json:"recommendation"`
|
Recommendation string `json:"recommendation"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateVulnerability 更新漏洞
|
// UpdateVulnerability 更新漏洞
|
||||||
@@ -460,4 +460,3 @@ func sanitizeExportName(raw string) string {
|
|||||||
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
|
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
|
||||||
return replacer.Replace(name)
|
return replacer.Replace(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ const (
|
|||||||
|
|
||||||
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
|
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
|
||||||
const (
|
const (
|
||||||
DSLRiskType = "risk_type"
|
DSLRiskType = "risk_type"
|
||||||
DSLSimilarityThreshold = "similarity_threshold"
|
DSLSimilarityThreshold = "similarity_threshold"
|
||||||
DSLSubIndexFilter = "sub_index_filter"
|
DSLSubIndexFilter = "sub_index_filter"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
|
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/compose"
|
|
||||||
"github.com/cloudwego/eino/components/document"
|
"github.com/cloudwego/eino/components/document"
|
||||||
|
"github.com/cloudwego/eino/compose"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
|
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
|
||||||
"github.com/cloudwego/eino/compose"
|
|
||||||
"github.com/cloudwego/eino/components/document"
|
"github.com/cloudwego/eino/components/document"
|
||||||
"github.com/cloudwego/eino/components/indexer"
|
"github.com/cloudwego/eino/components/indexer"
|
||||||
|
"github.com/cloudwego/eino/compose"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -35,14 +35,14 @@ type Indexer struct {
|
|||||||
lastErrorTime time.Time
|
lastErrorTime time.Time
|
||||||
errorCount int
|
errorCount int
|
||||||
|
|
||||||
rebuildMu sync.RWMutex
|
rebuildMu sync.RWMutex
|
||||||
isRebuilding bool
|
isRebuilding bool
|
||||||
rebuildTotalItems int
|
rebuildTotalItems int
|
||||||
rebuildCurrent int
|
rebuildCurrent int
|
||||||
rebuildFailed int
|
rebuildFailed int
|
||||||
rebuildStartTime time.Time
|
rebuildStartTime time.Time
|
||||||
rebuildLastItemID string
|
rebuildLastItemID string
|
||||||
rebuildLastChunks int
|
rebuildLastChunks int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
|
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
|
||||||
|
|||||||
@@ -108,9 +108,9 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
|
|||||||
|
|
||||||
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
|
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
|
||||||
type CategoryWithItems struct {
|
type CategoryWithItems struct {
|
||||||
Category string `json:"category"` // 分类名称
|
Category string `json:"category"` // 分类名称
|
||||||
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
||||||
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchRequest 搜索请求
|
// SearchRequest 搜索请求
|
||||||
|
|||||||
@@ -192,13 +192,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
|||||||
fnName, _ := fn["name"].(string)
|
fnName, _ := fn["name"].(string)
|
||||||
fnArgs, _ := fn["arguments"]
|
fnArgs, _ := fn["arguments"]
|
||||||
|
|
||||||
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
|
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
|
||||||
if strings.TrimSpace(fnName) == "" {
|
if strings.TrimSpace(fnName) == "" {
|
||||||
fnName = "unknown_function"
|
fnName = "unknown_function"
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(tcID) == "" {
|
if strings.TrimSpace(tcID) == "" {
|
||||||
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
|
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
var inputRaw json.RawMessage
|
var inputRaw json.RawMessage
|
||||||
switch v := fnArgs.(type) {
|
switch v := fnArgs.(type) {
|
||||||
|
|||||||
+14
-14
@@ -281,9 +281,9 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
|
|||||||
|
|
||||||
// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。
|
// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。
|
||||||
type StreamToolCall struct {
|
type StreamToolCall struct {
|
||||||
Index int
|
Index int
|
||||||
ID string
|
ID string
|
||||||
Type string
|
Type string
|
||||||
FunctionName string
|
FunctionName string
|
||||||
FunctionArgsStr string
|
FunctionArgsStr string
|
||||||
}
|
}
|
||||||
@@ -348,10 +348,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
|||||||
Arguments string `json:"arguments,omitempty"`
|
Arguments string `json:"arguments,omitempty"`
|
||||||
}
|
}
|
||||||
type toolCallDelta struct {
|
type toolCallDelta struct {
|
||||||
Index int `json:"index,omitempty"`
|
Index int `json:"index,omitempty"`
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
Function toolCallFunctionDelta `json:"function,omitempty"`
|
Function toolCallFunctionDelta `json:"function,omitempty"`
|
||||||
}
|
}
|
||||||
type streamDelta2 struct {
|
type streamDelta2 struct {
|
||||||
Content string `json:"content,omitempty"`
|
Content string `json:"content,omitempty"`
|
||||||
@@ -371,10 +371,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
|||||||
}
|
}
|
||||||
|
|
||||||
type toolCallAccum struct {
|
type toolCallAccum struct {
|
||||||
id string
|
id string
|
||||||
typ string
|
typ string
|
||||||
name string
|
name string
|
||||||
args strings.Builder
|
args strings.Builder
|
||||||
}
|
}
|
||||||
toolCallAccums := make(map[int]*toolCallAccum)
|
toolCallAccums := make(map[int]*toolCallAccum)
|
||||||
|
|
||||||
@@ -475,9 +475,9 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
|||||||
for _, idx := range indices {
|
for _, idx := range indices {
|
||||||
acc := toolCallAccums[idx]
|
acc := toolCallAccums[idx]
|
||||||
tc := StreamToolCall{
|
tc := StreamToolCall{
|
||||||
Index: idx,
|
Index: idx,
|
||||||
ID: acc.id,
|
ID: acc.id,
|
||||||
Type: acc.typ,
|
Type: acc.typ,
|
||||||
FunctionName: acc.name,
|
FunctionName: acc.name,
|
||||||
FunctionArgsStr: acc.args.String(),
|
FunctionArgsStr: acc.args.String(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -265,4 +265,3 @@ func TestPaginateLines(t *testing.T) {
|
|||||||
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
|
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,10 +16,10 @@ type rateLimitEntry struct {
|
|||||||
|
|
||||||
// RateLimiter 基于 IP 的滑动窗口速率限制器
|
// RateLimiter 基于 IP 的滑动窗口速率限制器
|
||||||
type RateLimiter struct {
|
type RateLimiter struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
entries map[string]*rateLimitEntry
|
entries map[string]*rateLimitEntry
|
||||||
limit int // 窗口内允许的最大请求数
|
limit int // 窗口内允许的最大请求数
|
||||||
window time.Duration // 窗口时长
|
window time.Duration // 窗口时长
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRateLimiter 创建速率限制器
|
// NewRateLimiter 创建速率限制器
|
||||||
|
|||||||
Reference in New Issue
Block a user