mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-24 00:14:08 +02:00
Add files via upload
This commit is contained in:
+236
-44
@@ -87,41 +87,30 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
conversationID = conv.ID
|
||||
}
|
||||
|
||||
// 获取历史消息(排除当前消息,因为还没保存)
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
historyMessages = []database.Message{}
|
||||
}
|
||||
|
||||
h.logger.Info("获取历史消息",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("count", len(historyMessages)),
|
||||
)
|
||||
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for i, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
contentPreview := msg.Content
|
||||
if len(contentPreview) > 50 {
|
||||
contentPreview = contentPreview[:50] + "..."
|
||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
||||
// 回退到使用数据库消息表
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
h.logger.Info("添加历史消息",
|
||||
zap.Int("index", i),
|
||||
zap.String("role", msg.Role),
|
||||
zap.String("content", contentPreview),
|
||||
)
|
||||
} else {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
h.logger.Info("历史消息转换完成",
|
||||
zap.Int("originalCount", len(historyMessages)),
|
||||
zap.Int("convertedCount", len(agentHistoryMessages)),
|
||||
)
|
||||
|
||||
// 保存用户消息
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
@@ -132,6 +121,16 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
// 即使执行失败,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil {
|
||||
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr))
|
||||
} else {
|
||||
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -142,6 +141,15 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
h.logger.Error("保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 保存最后一轮ReAct的输入和输出
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ChatResponse{
|
||||
Response: result.Response,
|
||||
MCPExecutionIDs: result.MCPExecutionIDs,
|
||||
@@ -246,20 +254,28 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
|
||||
// 获取历史消息
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
historyMessages = []database.Message{}
|
||||
}
|
||||
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
||||
// 回退到使用数据库消息表
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
@@ -499,6 +515,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务被取消,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
@@ -524,6 +550,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务超时,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
@@ -549,6 +585,16 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务失败,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent("error", errorMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
@@ -585,6 +631,15 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 保存最后一轮ReAct的输入和输出
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
// 发送最终响应
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
@@ -632,3 +687,140 @@ func (h *AgentHandler) ListAgentTasks(c *gin.Context) {
|
||||
"tasks": h.tasks.GetActiveTasks(),
|
||||
})
|
||||
}
|
||||
|
||||
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
|
||||
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
|
||||
// 获取保存的ReAct输入和输出
|
||||
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
|
||||
}
|
||||
|
||||
if reactInputJSON == "" {
|
||||
return nil, fmt.Errorf("ReAct数据为空")
|
||||
}
|
||||
|
||||
// 解析JSON格式的messages数组
|
||||
var messagesArray []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
|
||||
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为Agent消息格式
|
||||
agentMessages := make([]agent.ChatMessage, 0, len(messagesArray))
|
||||
for _, msgMap := range messagesArray {
|
||||
msg := agent.ChatMessage{}
|
||||
|
||||
// 解析role
|
||||
if role, ok := msgMap["role"].(string); ok {
|
||||
msg.Role = role
|
||||
} else {
|
||||
continue // 跳过无效消息
|
||||
}
|
||||
|
||||
// 跳过system消息(AgentLoop会重新添加)
|
||||
if msg.Role == "system" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析content
|
||||
if content, ok := msgMap["content"].(string); ok {
|
||||
msg.Content = content
|
||||
}
|
||||
|
||||
// 解析tool_calls(如果存在)
|
||||
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
|
||||
msg.ToolCalls = make([]agent.ToolCall, 0, len(toolCallsArray))
|
||||
for _, tcRaw := range toolCallsArray {
|
||||
if tcMap, ok := tcRaw.(map[string]interface{}); ok {
|
||||
toolCall := agent.ToolCall{}
|
||||
|
||||
// 解析ID
|
||||
if id, ok := tcMap["id"].(string); ok {
|
||||
toolCall.ID = id
|
||||
}
|
||||
|
||||
// 解析Type
|
||||
if toolType, ok := tcMap["type"].(string); ok {
|
||||
toolCall.Type = toolType
|
||||
}
|
||||
|
||||
// 解析Function
|
||||
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||
toolCall.Function = agent.FunctionCall{}
|
||||
|
||||
// 解析函数名
|
||||
if name, ok := funcMap["name"].(string); ok {
|
||||
toolCall.Function.Name = name
|
||||
}
|
||||
|
||||
// 解析arguments(可能是字符串或对象)
|
||||
if argsRaw, ok := funcMap["arguments"]; ok {
|
||||
if argsStr, ok := argsRaw.(string); ok {
|
||||
// 如果是字符串,解析为JSON
|
||||
var argsMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
|
||||
toolCall.Function.Arguments = argsMap
|
||||
}
|
||||
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
|
||||
// 如果已经是对象,直接使用
|
||||
toolCall.Function.Arguments = argsMap
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if toolCall.ID != "" {
|
||||
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析tool_call_id(tool角色消息)
|
||||
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||
msg.ToolCallID = toolCallID
|
||||
}
|
||||
|
||||
agentMessages = append(agentMessages, msg)
|
||||
}
|
||||
|
||||
// 如果存在last_react_output,需要将其作为最后一条assistant消息
|
||||
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出
|
||||
if reactOutput != "" {
|
||||
// 检查最后一条消息是否是assistant消息且没有tool_calls
|
||||
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
|
||||
if len(agentMessages) > 0 {
|
||||
lastMsg := &agentMessages[len(agentMessages)-1]
|
||||
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
|
||||
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content
|
||||
lastMsg.Content = reactOutput
|
||||
} else {
|
||||
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
|
||||
agentMessages = append(agentMessages, agent.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: reactOutput,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// 如果没有消息,直接添加最终输出
|
||||
agentMessages = append(agentMessages, agent.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: reactOutput,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(agentMessages) == 0 {
|
||||
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
|
||||
}
|
||||
|
||||
h.logger.Info("从ReAct数据恢复历史消息",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("messageCount", len(agentMessages)),
|
||||
zap.Bool("hasReactOutput", reactOutput != ""),
|
||||
)
|
||||
|
||||
return agentMessages, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user