From dd3b1ae219dc0c94b2fee8e981c61f0dba25a9ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Tue, 26 May 2026 14:34:21 +0800 Subject: [PATCH] Add files via upload --- internal/handler/agent.go | 37 ++-- internal/handler/batch_task_manager.go | 11 +- internal/handler/batch_task_mcp.go | 7 +- internal/handler/conversation.go | 32 ++- internal/handler/eino_single_agent.go | 2 + internal/handler/multi_agent.go | 2 + internal/handler/multi_agent_prepare.go | 8 + internal/handler/openapi.go | 138 +++++++++++++ internal/handler/project.go | 262 ++++++++++++++++++++++++ internal/handler/project_context.go | 32 +++ internal/handler/project_resolve.go | 18 ++ internal/handler/robot.go | 8 +- internal/handler/vulnerability.go | 31 +-- internal/handler/webshell_context.go | 2 +- internal/project/blackboard.go | 77 +++++++ 15 files changed, 636 insertions(+), 31 deletions(-) create mode 100644 internal/handler/project.go create mode 100644 internal/handler/project_context.go create mode 100644 internal/handler/project_resolve.go create mode 100644 internal/project/blackboard.go diff --git a/internal/handler/agent.go b/internal/handler/agent.go index be616ae9..651b73e5 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -222,6 +222,7 @@ type ChatReasoningRequest struct { type ChatRequest struct { Message string `json:"message" binding:"required"` ConversationID string `json:"conversationId,omitempty"` + ProjectID string `json:"projectId,omitempty"` // 新对话绑定的项目(可选;未指定时可用 config.project.default_project_id) Role string `json:"role,omitempty"` // 角色名称 Attachments []ChatAttachment `json:"attachments,omitempty"` WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具 @@ -560,7 +561,9 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { conversationID := req.ConversationID if conversationID == "" { title := safeTruncateString(req.Message, 50) - conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "agent_loop")) + meta := audit.ConversationCreateMetaFromGin(c, "agent_loop") + meta.ProjectID = effectiveProjectID(h.config, req.ProjectID) + conv, err := h.db.CreateConversation(title, meta) if err != nil { h.logger.Error("创建对话失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -635,6 +638,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { builtin.ToolWebshellFileRead, builtin.ToolWebshellFileWrite, builtin.ToolRecordVulnerability, + builtin.ToolListVulnerabilities, + builtin.ToolGetVulnerability, builtin.ToolListKnowledgeRiskTypes, builtin.ToolSearchKnowledgeBase, } @@ -682,7 +687,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, "", nil) // 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表) - result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools) + result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID)) if err != nil { h.logger.Error("Agent Loop执行失败", zap.Error(err)) @@ -760,7 +765,9 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con if strings.TrimSpace(platform) != "" { src = "robot:" + strings.TrimSpace(platform) } - conv, createErr := h.db.CreateConversation(title, audit.ConversationCreateMeta(src)) + meta := audit.ConversationCreateMeta(src) + meta.ProjectID = effectiveProjectID(h.config, "") + conv, createErr := h.db.CreateConversation(title, meta) if createErr != nil { return "", "", fmt.Errorf("创建对话失败: %w", createErr) } @@ -839,7 +846,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con for { resultMA, errMA = multiagent.RunEinoSingleChatModelAgent( taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, - conversationID, curMsg, curHist, roleTools, progressCallback, nil, + conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID), ) if errMA == nil { // 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。 @@ -872,7 +879,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con resultMA, errMA = multiagent.RunDeepAgent( taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, curMsg, curHist, roleTools, progressCallback, - h.agentsMarkdownDir, robotMode, nil, + h.agentsMarkdownDir, robotMode, nil, h.projectBlackboardBlock(conversationID), ) if errMA == nil { // 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。 @@ -891,7 +898,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA) } - result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools) + result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID)) if err != nil { taskStatus = "failed" errMsg := "执行失败: " + err.Error() @@ -1518,6 +1525,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { var conv *database.Conversation var err error meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream") + meta.ProjectID = effectiveProjectID(h.config, req.ProjectID) if req.WebShellConnectionID != "" { meta.Source = "webshell_chat" conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta) @@ -1595,6 +1603,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { builtin.ToolWebshellFileRead, builtin.ToolWebshellFileWrite, builtin.ToolRecordVulnerability, + builtin.ToolListVulnerabilities, + builtin.ToolGetVulnerability, builtin.ToolListKnowledgeRiskTypes, builtin.ToolSearchKnowledgeBase, } @@ -1725,7 +1735,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { go sseKeepalive(c, stopKeepalive, &sseWriteMu) defer close(stopKeepalive) - result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools) + result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID)) if err != nil { h.logger.Error("Agent Loop执行失败", zap.Error(err)) cause := context.Cause(baseCtx) @@ -2037,6 +2047,7 @@ type BatchTaskRequest struct { ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false) + ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选) } func normalizeBatchQueueAgentMode(mode string) string { @@ -2117,7 +2128,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) { nextRunAt = &next } - queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks) + queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks) if createErr != nil { c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()}) return @@ -2651,7 +2662,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { // 创建新对话 title := safeTruncateString(task.Message, 50) - conv, err := h.db.CreateConversation(title, audit.ConversationCreateMeta("batch_task")) + batchMeta := audit.ConversationCreateMeta("batch_task") + batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID) + conv, err := h.db.CreateConversation(title, batchMeta) var conversationID string if err != nil { h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) @@ -2801,15 +2814,15 @@ func (h *AgentHandler) executeBatchQueue(queueID string) { var runErr error switch { case useBatchMulti: - resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil) + resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID)) case useEinoSingle: if h.config == nil { runErr = fmt.Errorf("服务器配置未加载") } else { - resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil) + resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID)) } default: - result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools) + result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID)) } if runErr != nil { diff --git a/internal/handler/batch_task_manager.go b/internal/handler/batch_task_manager.go index 572588b1..b33f9dd9 100644 --- a/internal/handler/batch_task_manager.go +++ b/internal/handler/batch_task_manager.go @@ -65,6 +65,7 @@ type BatchTaskQueue struct { LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"` LastScheduleError string `json:"lastScheduleError,omitempty"` LastRunError string `json:"lastRunError,omitempty"` + ProjectID string `json:"projectId,omitempty"` Tasks []*BatchTask `json:"tasks"` Status string `json:"status"` // pending, running, paused, completed, cancelled CreatedAt time.Time `json:"createdAt"` @@ -103,7 +104,7 @@ func (m *BatchTaskManager) SetDB(db *database.DB) { // CreateBatchQueue 创建批量任务队列 func (m *BatchTaskManager) CreateBatchQueue( - title, role, agentMode, scheduleMode, cronExpr string, + title, role, agentMode, scheduleMode, cronExpr, projectID string, nextRunAt *time.Time, tasks []string, ) (*BatchTaskQueue, error) { @@ -126,6 +127,7 @@ func (m *BatchTaskManager) CreateBatchQueue( ID: queueID, Title: title, Role: role, + ProjectID: strings.TrimSpace(projectID), AgentMode: normalizeBatchQueueAgentMode(agentMode), ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode), CronExpr: strings.TrimSpace(cronExpr), @@ -171,6 +173,7 @@ func (m *BatchTaskManager) CreateBatchQueue( queue.ScheduleMode, queue.CronExpr, queue.NextRunAt, + queue.ProjectID, dbTasks, ); err != nil { m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err)) @@ -263,6 +266,9 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue { if queueRow.LastRunError.Valid { queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) } + if queueRow.ProjectID.Valid { + queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) + } if queueRow.StartedAt.Valid { queue.StartedAt = &queueRow.StartedAt.Time } @@ -499,6 +505,9 @@ func (m *BatchTaskManager) LoadFromDB() error { if queueRow.LastRunError.Valid { queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String) } + if queueRow.ProjectID.Valid { + queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) + } if queueRow.StartedAt.Valid { queue.StartedAt = &queueRow.StartedAt.Time } diff --git a/internal/handler/batch_task_mcp.go b/internal/handler/batch_task_mcp.go index 5512e1f2..27886b6c 100644 --- a/internal/handler/batch_task_mcp.go +++ b/internal/handler/batch_task_mcp.go @@ -176,6 +176,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z "type": "boolean", "description": "创建后是否立即开始执行队列,默认 false(pending,需 batch_task_start)", }, + "project_id": map[string]interface{}{ + "type": "string", + "description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)", + }, }, }, }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { @@ -204,7 +208,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z if !ok { executeNow = false } - queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks) + projectID := strings.TrimSpace(mcpArgString(args, "project_id")) + queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks) if createErr != nil { return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil } diff --git a/internal/handler/conversation.go b/internal/handler/conversation.go index e3e62c98..4d5849e3 100644 --- a/internal/handler/conversation.go +++ b/internal/handler/conversation.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "strconv" + "strings" "cyberstrike-ai/internal/audit" "cyberstrike-ai/internal/database" @@ -33,7 +34,13 @@ func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHa // CreateConversationRequest 创建对话请求 type CreateConversationRequest struct { - Title string `json:"title"` + Title string `json:"title"` + ProjectID string `json:"projectId,omitempty"` +} + +// SetConversationProjectRequest 设置对话所属项目 +type SetConversationProjectRequest struct { + ProjectID string `json:"projectId"` // 空字符串表示解除绑定 } // CreateConversation 创建新对话 @@ -49,7 +56,9 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) { title = "新对话" } - conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api")) + meta := audit.ConversationCreateMetaFromGin(c, "api") + meta.ProjectID = strings.TrimSpace(req.ProjectID) + conv, err := h.db.CreateConversation(title, meta) if err != nil { h.logger.Error("创建对话失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -59,6 +68,25 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) { c.JSON(http.StatusOK, conv) } +// SetConversationProject 设置或清除对话绑定的项目 +func (h *ConversationHandler) SetConversationProject(c *gin.Context) { + id := c.Param("id") + var req SetConversationProjectRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if _, err := h.db.GetConversation(id); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)}) +} + // ListConversations 列出对话 func (h *ConversationHandler) ListConversations(c *gin.Context) { limitStr := c.DefaultQuery("limit", "50") diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go index 2a3e644c..7fd02f03 100644 --- a/internal/handler/eino_single_agent.go +++ b/internal/handler/eino_single_agent.go @@ -230,6 +230,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { roleTools, progressCallback, chatReasoningToClientIntent(req.Reasoning), + h.projectBlackboardBlock(conversationID), ) if result != nil && len(result.MCPExecutionIDs) > 0 { @@ -429,6 +430,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) { prep.RoleTools, progressCallback, chatReasoningToClientIntent(req.Reasoning), + h.projectBlackboardBlock(prep.ConversationID), ) if runErr != nil { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index b83a7a6b..9b97fd21 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -242,6 +242,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { h.agentsMarkdownDir, orch, chatReasoningToClientIntent(req.Reasoning), + h.projectBlackboardBlock(conversationID), ) if result != nil && len(result.MCPExecutionIDs) > 0 { @@ -443,6 +444,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) { h.agentsMarkdownDir, strings.TrimSpace(req.Orchestration), chatReasoningToClientIntent(req.Reasoning), + h.projectBlackboardBlock(prep.ConversationID), ) if runErr != nil { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { diff --git a/internal/handler/multi_agent_prepare.go b/internal/handler/multi_agent_prepare.go index 3ce2e042..5111319a 100644 --- a/internal/handler/multi_agent_prepare.go +++ b/internal/handler/multi_agent_prepare.go @@ -36,6 +36,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context var conv *database.Conversation var err error meta := audit.ConversationCreateMetaFromGin(c, source) + meta.ProjectID = effectiveProjectID(h.config, req.ProjectID) if strings.TrimSpace(req.WebShellConnectionID) != "" { meta.Source = source + "_webshell" meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID) @@ -90,6 +91,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context builtin.ToolWebshellFileRead, builtin.ToolWebshellFileWrite, builtin.ToolRecordVulnerability, + builtin.ToolListVulnerabilities, + builtin.ToolGetVulnerability, + builtin.ToolUpsertProjectFact, + builtin.ToolGetProjectFact, + builtin.ToolListProjectFacts, + builtin.ToolSearchProjectFacts, + builtin.ToolDeprecateProjectFact, builtin.ToolListKnowledgeRiskTypes, builtin.ToolSearchKnowledgeBase, } diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go index 15de9ab1..6b7855ae 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -73,8 +73,22 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "description": "对话标题", "example": "Web应用安全测试", }, + "projectId": map[string]interface{}{ + "type": "string", + "description": "绑定的项目 ID(可选,共享事实黑板)", + }, }, }, + "SetConversationProjectRequest": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "projectId": map[string]interface{}{ + "type": "string", + "description": "项目 ID;空字符串表示解除绑定", + }, + }, + "required": []string{"projectId"}, + }, "Conversation": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ @@ -98,6 +112,10 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "format": "date-time", "description": "更新时间", }, + "projectId": map[string]interface{}{ + "type": "string", + "description": "绑定的项目 ID(可选)", + }, }, }, "ConversationDetail": map[string]interface{}{ @@ -1326,6 +1344,37 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { }, }, }, + "/api/conversations/{id}/project": map[string]interface{}{ + "put": map[string]interface{}{ + "tags": []string{"对话管理"}, + "summary": "设置对话所属项目", + "description": "绑定或解除对话与项目的关联,用于共享事实黑板", + "operationId": "setConversationProject", + "parameters": []map[string]interface{}{ + { + "name": "id", "in": "path", "required": true, + "description": "对话ID", + "schema": map[string]interface{}{"type": "string"}, + }, + }, + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "$ref": "#/components/schemas/SetConversationProjectRequest", + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "设置成功"}, + "400": map[string]interface{}{"description": "项目不存在或参数错误"}, + "404": map[string]interface{}{"description": "对话不存在"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, "/api/conversations/{id}/results": map[string]interface{}{ "get": map[string]interface{}{ "tags": []string{"对话管理"}, @@ -2444,6 +2493,86 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { }, }, }, + "/api/projects": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"项目管理"}, + "summary": "列出项目", + "operationId": "listProjects", + "parameters": []map[string]interface{}{ + {"name": "status", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"active", "archived"}}}, + {"name": "limit", "in": "query", "schema": map[string]interface{}{"type": "integer", "default": 200}}, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "项目列表"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + "post": map[string]interface{}{ + "tags": []string{"项目管理"}, + "summary": "创建项目", + "operationId": "createProject", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{"type": "string"}, + "description": map[string]interface{}{"type": "string"}, + "scope_json": map[string]interface{}{"type": "string"}, + }, + "required": []string{"name"}, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{"description": "创建成功"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, + "/api/projects/{id}": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "获取项目", "operationId": "getProject", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "项目详情"}}, + }, + "put": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "更新项目", "operationId": "updateProject", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "更新成功"}}, + }, + "delete": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "删除项目", "operationId": "deleteProject", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}}, + }, + }, + "/api/projects/{id}/facts": map[string]interface{}{ + "get": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "列出或按 key 获取事实", "operationId": "listProjectFacts", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + {"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}}, + }, + "post": map[string]interface{}{ + "tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST", + "parameters": []map[string]interface{}{ + {"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}}, + }, + "responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}}, + }, + }, "/api/vulnerabilities": map[string]interface{}{ "get": map[string]interface{}{ "tags": []string{"漏洞管理"}, @@ -2502,6 +2631,15 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { "type": "string", }, }, + { + "name": "project_id", + "in": "query", + "required": false, + "description": "项目ID", + "schema": map[string]interface{}{ + "type": "string", + }, + }, { "name": "severity", "in": "query", diff --git a/internal/handler/project.go b/internal/handler/project.go new file mode 100644 index 00000000..ddf9bc15 --- /dev/null +++ b/internal/handler/project.go @@ -0,0 +1,262 @@ +package handler + +import ( + "net/http" + "strconv" + "strings" + + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ProjectHandler 项目管理处理器。 +type ProjectHandler struct { + db *database.DB + logger *zap.Logger +} + +// NewProjectHandler 创建项目管理处理器。 +func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler { + return &ProjectHandler{db: db, logger: logger} +} + +type createProjectRequest struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + ScopeJSON string `json:"scope_json"` + Status string `json:"status"` +} + +type updateProjectRequest struct { + Name string `json:"name"` + Description string `json:"description"` + ScopeJSON string `json:"scope_json"` + Status string `json:"status"` + Pinned *bool `json:"pinned"` +} + +// CreateProject POST /api/projects +func (h *ProjectHandler) CreateProject(c *gin.Context) { + var req createProjectRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + p := &database.Project{ + Name: strings.TrimSpace(req.Name), + Description: req.Description, + ScopeJSON: req.ScopeJSON, + Status: strings.TrimSpace(req.Status), + } + created, err := h.db.CreateProject(p) + if err != nil { + h.logger.Error("创建项目失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, created) +} + +// ListProjects GET /api/projects +func (h *ProjectHandler) ListProjects(c *gin.Context) { + status := c.Query("status") + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200")) + offset, _ := strconv.Atoi(c.Query("offset")) + list, err := h.db.ListProjects(status, limit, offset) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if list == nil { + list = []*database.Project{} + } + c.JSON(http.StatusOK, list) +} + +// GetProject GET /api/projects/:id +func (h *ProjectHandler) GetProject(c *gin.Context) { + p, err := h.db.GetProject(c.Param("id")) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"}) + return + } + c.JSON(http.StatusOK, p) +} + +// UpdateProject PUT /api/projects/:id +func (h *ProjectHandler) UpdateProject(c *gin.Context) { + id := c.Param("id") + p, err := h.db.GetProject(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"}) + return + } + var req updateProjectRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if s := strings.TrimSpace(req.Name); s != "" { + p.Name = s + } + if req.Description != "" || c.Request.ContentLength > 0 { + p.Description = req.Description + } + if req.ScopeJSON != "" || c.GetHeader("Content-Type") != "" { + p.ScopeJSON = req.ScopeJSON + } + if s := strings.TrimSpace(req.Status); s != "" { + p.Status = s + } + if req.Pinned != nil { + p.Pinned = *req.Pinned + } + if err := h.db.UpdateProject(p); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, p) +} + +// DeleteProject DELETE /api/projects/:id +func (h *ProjectHandler) DeleteProject(c *gin.Context) { + if err := h.db.DeleteProject(c.Param("id")); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +type upsertFactRequest struct { + FactKey string `json:"fact_key" binding:"required"` + Category string `json:"category"` + Summary string `json:"summary" binding:"required"` + Body string `json:"body"` + Confidence string `json:"confidence"` + Pinned bool `json:"pinned"` + RelatedVulnerabilityID string `json:"related_vulnerability_id"` +} + +// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情) +func (h *ProjectHandler) ListFacts(c *gin.Context) { + projectID := c.Param("id") + if key := strings.TrimSpace(c.Query("fact_key")); key != "" { + f, err := h.db.GetProjectFactByKey(projectID, key) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, f) + return + } + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) + offset, _ := strconv.Atoi(c.Query("offset")) + filter := database.ProjectFactListFilter{ + Category: c.Query("category"), + Confidence: c.Query("confidence"), + Search: c.Query("search"), + } + list, err := h.db.ListProjectFacts(projectID, filter, limit, offset) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if list == nil { + list = []*database.ProjectFact{} + } + c.JSON(http.StatusOK, list) +} + +// CreateFact POST /api/projects/:id/facts +func (h *ProjectHandler) CreateFact(c *gin.Context) { + var req upsertFactRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + f := &database.ProjectFact{ + ProjectID: c.Param("id"), + FactKey: req.FactKey, + Category: req.Category, + Summary: req.Summary, + Body: req.Body, + Confidence: req.Confidence, + Pinned: req.Pinned, + RelatedVulnerabilityID: req.RelatedVulnerabilityID, + } + created, err := h.db.UpsertProjectFact(f) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, created) +} + +// UpdateFact PUT /api/projects/:id/facts/:factId +func (h *ProjectHandler) UpdateFact(c *gin.Context) { + existing, err := h.db.GetProjectFact(c.Param("factId")) + if err != nil || existing.ProjectID != c.Param("id") { + c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"}) + return + } + var req upsertFactRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if k := strings.TrimSpace(req.FactKey); k != "" { + existing.FactKey = k + } + if req.Category != "" { + existing.Category = req.Category + } + if req.Summary != "" { + existing.Summary = req.Summary + } + existing.Body = req.Body + if req.Confidence != "" { + existing.Confidence = req.Confidence + } + existing.Pinned = req.Pinned + existing.RelatedVulnerabilityID = req.RelatedVulnerabilityID + updated, err := h.db.UpsertProjectFact(existing) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, updated) +} + +// DeleteFact DELETE /api/projects/:id/facts/:factId +func (h *ProjectHandler) DeleteFact(c *gin.Context) { + existing, err := h.db.GetProjectFact(c.Param("factId")) + if err != nil || existing.ProjectID != c.Param("id") { + c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"}) + return + } + if err := h.db.DeleteProjectFact(existing.ID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true}) +} + +type deprecateFactRequest struct { + FactKey string `json:"fact_key" binding:"required"` +} + +// DeprecateFact POST /api/projects/:id/facts/deprecate +func (h *ProjectHandler) DeprecateFact(c *gin.Context) { + var req deprecateFactRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"success": true}) +} diff --git a/internal/handler/project_context.go b/internal/handler/project_context.go new file mode 100644 index 00000000..66e0f943 --- /dev/null +++ b/internal/handler/project_context.go @@ -0,0 +1,32 @@ +package handler + +import ( + "strings" + + "cyberstrike-ai/internal/project" + "go.uber.org/zap" +) + +// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。 +func (h *AgentHandler) projectBlackboardBlock(conversationID string) string { + if h == nil || h.db == nil || h.config == nil { + return "" + } + if !h.config.Project.Enabled { + return "" + } + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return "" + } + projectID, err := h.db.GetConversationProjectID(conversationID) + if err != nil || projectID == "" { + return "" + } + block, err := project.BuildFactIndexBlock(h.db, projectID, h.config.Project) + if err != nil { + h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err)) + return "" + } + return strings.TrimSpace(block) +} diff --git a/internal/handler/project_resolve.go b/internal/handler/project_resolve.go new file mode 100644 index 00000000..88885838 --- /dev/null +++ b/internal/handler/project_resolve.go @@ -0,0 +1,18 @@ +package handler + +import ( + "strings" + + "cyberstrike-ai/internal/config" +) + +// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。 +func effectiveProjectID(cfg *config.Config, explicit string) string { + if pid := strings.TrimSpace(explicit); pid != "" { + return pid + } + if cfg != nil { + return strings.TrimSpace(cfg.Project.DefaultProjectID) + } + return "" +} diff --git a/internal/handler/robot.go b/internal/handler/robot.go index 2f4aa8de..9797e95a 100644 --- a/internal/handler/robot.go +++ b/internal/handler/robot.go @@ -133,7 +133,9 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) ( } else { t = safeTruncateString(t, 50) } - conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform}) + meta := database.ConversationCreateMeta{Source: "robot:" + platform} + meta.ProjectID = effectiveProjectID(h.config, "") + conv, err := h.db.CreateConversation(t, meta) if err != nil { h.logger.Warn("创建机器人会话失败", zap.Error(err)) return "", false @@ -188,7 +190,9 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) { // clearConversation 清空当前会话(切换到新对话) func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) { title := "新对话 " + time.Now().Format("01-02 15:04") - conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"}) + meta := database.ConversationCreateMeta{Source: "robot:" + platform + ":new"} + meta.ProjectID = effectiveProjectID(h.config, "") + conv, err := h.db.CreateConversation(title, meta) if err != nil { h.logger.Warn("创建新对话失败", zap.Error(err)) return "" diff --git a/internal/handler/vulnerability.go b/internal/handler/vulnerability.go index f9e83395..237f91ad 100644 --- a/internal/handler/vulnerability.go +++ b/internal/handler/vulnerability.go @@ -36,6 +36,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability // CreateVulnerabilityRequest 创建漏洞请求 type CreateVulnerabilityRequest struct { ConversationID string `json:"conversation_id" binding:"required"` + ProjectID string `json:"project_id"` ConversationTag string `json:"conversation_tag"` TaskTag string `json:"task_tag"` Title string `json:"title" binding:"required"` @@ -59,6 +60,7 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) { vuln := &database.Vulnerability{ ConversationID: req.ConversationID, + ProjectID: strings.TrimSpace(req.ProjectID), ConversationTag: req.ConversationTag, TaskTag: req.TaskTag, Title: req.Title, @@ -116,6 +118,7 @@ func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilt q = strings.TrimSpace(c.Query("search")) } return database.VulnerabilityListFilter{ + ProjectID: c.Query("project_id"), ID: c.Query("id"), Search: q, ConversationID: c.Query("conversation_id"), @@ -193,17 +196,18 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { // UpdateVulnerabilityRequest 更新漏洞请求 type UpdateVulnerabilityRequest struct { - ConversationTag string `json:"conversation_tag"` - TaskTag string `json:"task_tag"` - Title string `json:"title"` - Description string `json:"description"` - Severity string `json:"severity"` - Status string `json:"status"` - Type string `json:"type"` - Target string `json:"target"` - Proof string `json:"proof"` - Impact string `json:"impact"` - Recommendation string `json:"recommendation"` + ProjectID *string `json:"project_id"` + ConversationTag string `json:"conversation_tag"` + TaskTag string `json:"task_tag"` + Title string `json:"title"` + Description string `json:"description"` + Severity string `json:"severity"` + Status string `json:"status"` + Type string `json:"type"` + Target string `json:"target"` + Proof string `json:"proof"` + Impact string `json:"impact"` + Recommendation string `json:"recommendation"` } // UpdateVulnerability 更新漏洞 @@ -224,6 +228,9 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) { } // 更新字段 + if req.ProjectID != nil { + existing.ProjectID = strings.TrimSpace(*req.ProjectID) + } if req.ConversationTag != "" { existing.ConversationTag = req.ConversationTag } @@ -274,7 +281,7 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) { if h.audit != nil { h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{ - "severity": updated.Severity, "status": updated.Status, + "severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID, }) } c.JSON(http.StatusOK, updated) diff --git a/internal/handler/webshell_context.go b/internal/handler/webshell_context.go index 17541f5a..15cdc0f7 100644 --- a/internal/handler/webshell_context.go +++ b/internal/handler/webshell_context.go @@ -15,7 +15,7 @@ const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `s // webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。 // 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。 -const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base" +const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、list_knowledge_risk_types、search_knowledge_base" // BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。 // 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、 diff --git a/internal/project/blackboard.go b/internal/project/blackboard.go new file mode 100644 index 00000000..1b7c3f6d --- /dev/null +++ b/internal/project/blackboard.go @@ -0,0 +1,77 @@ +package project + +import ( + "fmt" + "sort" + "strings" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" +) + +// AppendSystemPromptBlock 将附加块追加到 system prompt。 +func AppendSystemPromptBlock(base, block string) string { + base = strings.TrimSpace(base) + block = strings.TrimSpace(block) + if block == "" { + return base + } + if base == "" { + return block + } + return base + "\n\n" + block +} + +// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。 +func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) { + if db == nil || !cfg.Enabled { + return "", nil + } + projectID = strings.TrimSpace(projectID) + if projectID == "" { + return "", nil + } + + proj, err := db.GetProject(projectID) + if err != nil { + return "", err + } + + facts, err := db.ListProjectFactsForIndex(projectID, cfg.DefaultInjectDeprecated) + if err != nil { + return "", err + } + if len(facts) == 0 { + return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil + } + + sort.SliceStable(facts, func(i, j int) bool { + if facts[i].Pinned != facts[j].Pinned { + return facts[i].Pinned + } + return facts[i].UpdatedAt.After(facts[j].UpdatedAt) + }) + + maxRunes := cfg.FactIndexMaxRunesEffective() + var b strings.Builder + b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID)) + used := len([]rune(b.String())) + omitted := 0 + + for _, f := range facts { + line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence) + lineRunes := len([]rune(line)) + if used+lineRunes > maxRunes { + omitted++ + continue + } + b.WriteString(line) + used += lineRunes + } + + if omitted > 0 { + b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted)) + } + b.WriteString("需要完整内容(POC、长文本等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n") + return b.String(), nil +}