mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-03 03:05:57 +02:00
Add files via upload
This commit is contained in:
+142
-4
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
@@ -17,10 +18,13 @@ import (
|
||||
|
||||
// AgentHandler Agent处理器
|
||||
type AgentHandler struct {
|
||||
agent *agent.Agent
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
tasks *AgentTaskManager
|
||||
agent *agent.Agent
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
tasks *AgentTaskManager
|
||||
knowledgeManager interface { // 知识库管理器接口
|
||||
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
|
||||
}
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
@@ -33,6 +37,13 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A
|
||||
}
|
||||
}
|
||||
|
||||
// SetKnowledgeManager 设置知识库管理器(用于记录检索日志)
|
||||
func (h *AgentHandler) SetKnowledgeManager(manager interface {
|
||||
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
|
||||
}) {
|
||||
h.knowledgeManager = manager
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
@@ -271,9 +282,136 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
|
||||
// 用于保存tool_call事件中的参数,以便在tool_result时使用
|
||||
toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments
|
||||
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
sendEvent(eventType, message, data)
|
||||
|
||||
// 保存tool_call事件中的参数
|
||||
if eventType == "tool_call" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
if toolName == "search_knowledge_base" {
|
||||
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
|
||||
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
|
||||
toolCallCache[toolCallId] = argumentsObj
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理知识检索日志记录
|
||||
if eventType == "tool_result" && h.knowledgeManager != nil {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
if toolName == "search_knowledge_base" {
|
||||
// 提取检索信息
|
||||
query := ""
|
||||
riskType := ""
|
||||
var retrievedItems []string
|
||||
|
||||
// 首先尝试从tool_call缓存中获取参数
|
||||
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
|
||||
if cachedArgs, exists := toolCallCache[toolCallId]; exists {
|
||||
if q, ok := cachedArgs["query"].(string); ok && q != "" {
|
||||
query = q
|
||||
}
|
||||
if rt, ok := cachedArgs["risk_type"].(string); ok && rt != "" {
|
||||
riskType = rt
|
||||
}
|
||||
// 使用后清理缓存
|
||||
delete(toolCallCache, toolCallId)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果缓存中没有,尝试从argumentsObj中提取
|
||||
if query == "" {
|
||||
if arguments, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
|
||||
if q, ok := arguments["query"].(string); ok && q != "" {
|
||||
query = q
|
||||
}
|
||||
if rt, ok := arguments["risk_type"].(string); ok && rt != "" {
|
||||
riskType = rt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果query仍然为空,尝试从result中提取(从结果文本的第一行)
|
||||
if query == "" {
|
||||
if result, ok := dataMap["result"].(string); ok && result != "" {
|
||||
// 尝试从结果中提取查询内容(如果结果包含"未找到与查询 'xxx' 相关的知识")
|
||||
if strings.Contains(result, "未找到与查询 '") {
|
||||
start := strings.Index(result, "未找到与查询 '") + len("未找到与查询 '")
|
||||
end := strings.Index(result[start:], "'")
|
||||
if end > 0 {
|
||||
query = result[start : start+end]
|
||||
}
|
||||
}
|
||||
}
|
||||
// 如果还是为空,使用默认值
|
||||
if query == "" {
|
||||
query = "未知查询"
|
||||
}
|
||||
}
|
||||
|
||||
// 从工具结果中提取检索到的知识项ID
|
||||
// 结果格式:"找到 X 条相关知识:\n\n--- 结果 1 (相似度: XX.XX%) ---\n来源: [分类] 标题\n...\n<!-- METADATA: {...} -->"
|
||||
if result, ok := dataMap["result"].(string); ok && result != "" {
|
||||
// 尝试从元数据中提取知识项ID
|
||||
metadataMatch := strings.Index(result, "<!-- METADATA:")
|
||||
if metadataMatch > 0 {
|
||||
// 提取元数据JSON
|
||||
metadataStart := metadataMatch + len("<!-- METADATA: ")
|
||||
metadataEnd := strings.Index(result[metadataStart:], " -->")
|
||||
if metadataEnd > 0 {
|
||||
metadataJSON := result[metadataStart : metadataStart+metadataEnd]
|
||||
var metadata map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(metadataJSON), &metadata); err == nil {
|
||||
if meta, ok := metadata["_metadata"].(map[string]interface{}); ok {
|
||||
if ids, ok := meta["retrievedItemIDs"].([]interface{}); ok {
|
||||
retrievedItems = make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if idStr, ok := id.(string); ok {
|
||||
retrievedItems = append(retrievedItems, idStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有从元数据中提取到,但结果包含"找到 X 条",至少标记为有结果
|
||||
if len(retrievedItems) == 0 && strings.Contains(result, "找到") && !strings.Contains(result, "未找到") {
|
||||
// 有结果,但无法准确提取ID,使用特殊标记
|
||||
retrievedItems = []string{"_has_results"}
|
||||
}
|
||||
}
|
||||
|
||||
// 记录检索日志(异步,不阻塞)
|
||||
go func() {
|
||||
if err := h.knowledgeManager.LogRetrieval(conversationID, assistantMessageID, query, riskType, retrievedItems); err != nil {
|
||||
h.logger.Warn("记录知识检索日志失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// 添加知识检索事件到processDetails
|
||||
if assistantMessageID != "" {
|
||||
retrievalData := map[string]interface{}{
|
||||
"query": query,
|
||||
"riskType": riskType,
|
||||
"toolName": toolName,
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "knowledge_retrieval", fmt.Sprintf("检索知识: %s", query), retrievalData); err != nil {
|
||||
h.logger.Warn("保存知识检索详情失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存过程详情到数据库(排除response和done事件,它们会在后面单独处理)
|
||||
if assistantMessageID != "" && eventType != "response" && eventType != "done" {
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user