Add files via upload

This commit is contained in:
公明
2025-12-20 17:36:40 +08:00
committed by GitHub
parent b659fb7445
commit abc4085c8a
21 changed files with 5234 additions and 46 deletions
+142 -4
View File
@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/agent"
@@ -17,10 +18,13 @@ import (
// AgentHandler Agent处理器
type AgentHandler struct {
agent *agent.Agent
db *database.DB
logger *zap.Logger
tasks *AgentTaskManager
agent *agent.Agent
db *database.DB
logger *zap.Logger
tasks *AgentTaskManager
knowledgeManager interface { // 知识库管理器接口
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
}
}
// NewAgentHandler 创建新的Agent处理器
@@ -33,6 +37,13 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A
}
}
// SetKnowledgeManager 设置知识库管理器(用于记录检索日志)
func (h *AgentHandler) SetKnowledgeManager(manager interface {
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
}) {
h.knowledgeManager = manager
}
// ChatRequest 聊天请求
type ChatRequest struct {
Message string `json:"message" binding:"required"`
@@ -271,9 +282,136 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
assistantMessageID = assistantMsg.ID
}
// 用于保存tool_call事件中的参数,以便在tool_result时使用
toolCallCache := make(map[string]map[string]interface{}) // toolCallId -> arguments
progressCallback := func(eventType, message string, data interface{}) {
sendEvent(eventType, message, data)
// 保存tool_call事件中的参数
if eventType == "tool_call" {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" {
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
toolCallCache[toolCallId] = argumentsObj
}
}
}
}
}
// 处理知识检索日志记录
if eventType == "tool_result" && h.knowledgeManager != nil {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" {
// 提取检索信息
query := ""
riskType := ""
var retrievedItems []string
// 首先尝试从tool_call缓存中获取参数
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
if cachedArgs, exists := toolCallCache[toolCallId]; exists {
if q, ok := cachedArgs["query"].(string); ok && q != "" {
query = q
}
if rt, ok := cachedArgs["risk_type"].(string); ok && rt != "" {
riskType = rt
}
// 使用后清理缓存
delete(toolCallCache, toolCallId)
}
}
// 如果缓存中没有,尝试从argumentsObj中提取
if query == "" {
if arguments, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
if q, ok := arguments["query"].(string); ok && q != "" {
query = q
}
if rt, ok := arguments["risk_type"].(string); ok && rt != "" {
riskType = rt
}
}
}
// 如果query仍然为空,尝试从result中提取(从结果文本的第一行)
if query == "" {
if result, ok := dataMap["result"].(string); ok && result != "" {
// 尝试从结果中提取查询内容(如果结果包含"未找到与查询 'xxx' 相关的知识"
if strings.Contains(result, "未找到与查询 '") {
start := strings.Index(result, "未找到与查询 '") + len("未找到与查询 '")
end := strings.Index(result[start:], "'")
if end > 0 {
query = result[start : start+end]
}
}
}
// 如果还是为空,使用默认值
if query == "" {
query = "未知查询"
}
}
// 从工具结果中提取检索到的知识项ID
// 结果格式:"找到 X 条相关知识:\n\n--- 结果 1 (相似度: XX.XX%) ---\n来源: [分类] 标题\n...\n<!-- METADATA: {...} -->"
if result, ok := dataMap["result"].(string); ok && result != "" {
// 尝试从元数据中提取知识项ID
metadataMatch := strings.Index(result, "<!-- METADATA:")
if metadataMatch > 0 {
// 提取元数据JSON
metadataStart := metadataMatch + len("<!-- METADATA: ")
metadataEnd := strings.Index(result[metadataStart:], " -->")
if metadataEnd > 0 {
metadataJSON := result[metadataStart : metadataStart+metadataEnd]
var metadata map[string]interface{}
if err := json.Unmarshal([]byte(metadataJSON), &metadata); err == nil {
if meta, ok := metadata["_metadata"].(map[string]interface{}); ok {
if ids, ok := meta["retrievedItemIDs"].([]interface{}); ok {
retrievedItems = make([]string, 0, len(ids))
for _, id := range ids {
if idStr, ok := id.(string); ok {
retrievedItems = append(retrievedItems, idStr)
}
}
}
}
}
}
}
// 如果没有从元数据中提取到,但结果包含"找到 X 条",至少标记为有结果
if len(retrievedItems) == 0 && strings.Contains(result, "找到") && !strings.Contains(result, "未找到") {
// 有结果,但无法准确提取ID,使用特殊标记
retrievedItems = []string{"_has_results"}
}
}
// 记录检索日志(异步,不阻塞)
go func() {
if err := h.knowledgeManager.LogRetrieval(conversationID, assistantMessageID, query, riskType, retrievedItems); err != nil {
h.logger.Warn("记录知识检索日志失败", zap.Error(err))
}
}()
// 添加知识检索事件到processDetails
if assistantMessageID != "" {
retrievalData := map[string]interface{}{
"query": query,
"riskType": riskType,
"toolName": toolName,
}
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "knowledge_retrieval", fmt.Sprintf("检索知识: %s", query), retrievalData); err != nil {
h.logger.Warn("保存知识检索详情失败", zap.Error(err))
}
}
}
}
}
// 保存过程详情到数据库(排除response和done事件,它们会在后面单独处理)
if assistantMessageID != "" && eventType != "response" && eventType != "done" {
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil {