Add files via upload

This commit is contained in:
公明
2026-05-26 14:34:21 +08:00
committed by GitHub
parent f42209682a
commit dd3b1ae219
15 changed files with 636 additions and 31 deletions
+25 -12
View File
@@ -222,6 +222,7 @@ type ChatReasoningRequest struct {
type ChatRequest struct {
Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"`
ProjectID string `json:"projectId,omitempty"` // 新对话绑定的项目(可选;未指定时可用 config.project.default_project_id
Role string `json:"role,omitempty"` // 角色名称
Attachments []ChatAttachment `json:"attachments,omitempty"`
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
@@ -560,7 +561,9 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
conversationID := req.ConversationID
if conversationID == "" {
title := safeTruncateString(req.Message, 50)
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "agent_loop"))
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop")
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
conv, err := h.db.CreateConversation(title, meta)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -635,6 +638,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
builtin.ToolWebshellFileRead,
builtin.ToolWebshellFileWrite,
builtin.ToolRecordVulnerability,
builtin.ToolListVulnerabilities,
builtin.ToolGetVulnerability,
builtin.ToolListKnowledgeRiskTypes,
builtin.ToolSearchKnowledgeBase,
}
@@ -682,7 +687,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, "", nil)
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
@@ -760,7 +765,9 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con
if strings.TrimSpace(platform) != "" {
src = "robot:" + strings.TrimSpace(platform)
}
conv, createErr := h.db.CreateConversation(title, audit.ConversationCreateMeta(src))
meta := audit.ConversationCreateMeta(src)
meta.ProjectID = effectiveProjectID(h.config, "")
conv, createErr := h.db.CreateConversation(title, meta)
if createErr != nil {
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
}
@@ -839,7 +846,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con
for {
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback, nil,
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
)
if errMA == nil {
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
@@ -872,7 +879,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con
resultMA, errMA = multiagent.RunDeepAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback,
h.agentsMarkdownDir, robotMode, nil,
h.agentsMarkdownDir, robotMode, nil, h.projectBlackboardBlock(conversationID),
)
if errMA == nil {
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
@@ -891,7 +898,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
}
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
if err != nil {
taskStatus = "failed"
errMsg := "执行失败: " + err.Error()
@@ -1518,6 +1525,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
var conv *database.Conversation
var err error
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream")
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
if req.WebShellConnectionID != "" {
meta.Source = "webshell_chat"
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta)
@@ -1595,6 +1603,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
builtin.ToolWebshellFileRead,
builtin.ToolWebshellFileWrite,
builtin.ToolRecordVulnerability,
builtin.ToolListVulnerabilities,
builtin.ToolGetVulnerability,
builtin.ToolListKnowledgeRiskTypes,
builtin.ToolSearchKnowledgeBase,
}
@@ -1725,7 +1735,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
defer close(stopKeepalive)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
cause := context.Cause(baseCtx)
@@ -2037,6 +2047,7 @@ type BatchTaskRequest struct {
ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
}
func normalizeBatchQueueAgentMode(mode string) string {
@@ -2117,7 +2128,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
nextRunAt = &next
}
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks)
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
if createErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
return
@@ -2651,7 +2662,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
// 创建新对话
title := safeTruncateString(task.Message, 50)
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMeta("batch_task"))
batchMeta := audit.ConversationCreateMeta("batch_task")
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
conv, err := h.db.CreateConversation(title, batchMeta)
var conversationID string
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
@@ -2801,15 +2814,15 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
var runErr error
switch {
case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil)
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
case useEinoSingle:
if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载")
} else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil)
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
}
default:
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
}
if runErr != nil {
+10 -1
View File
@@ -65,6 +65,7 @@ type BatchTaskQueue struct {
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
LastScheduleError string `json:"lastScheduleError,omitempty"`
LastRunError string `json:"lastRunError,omitempty"`
ProjectID string `json:"projectId,omitempty"`
Tasks []*BatchTask `json:"tasks"`
Status string `json:"status"` // pending, running, paused, completed, cancelled
CreatedAt time.Time `json:"createdAt"`
@@ -103,7 +104,7 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
// CreateBatchQueue 创建批量任务队列
func (m *BatchTaskManager) CreateBatchQueue(
title, role, agentMode, scheduleMode, cronExpr string,
title, role, agentMode, scheduleMode, cronExpr, projectID string,
nextRunAt *time.Time,
tasks []string,
) (*BatchTaskQueue, error) {
@@ -126,6 +127,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
ID: queueID,
Title: title,
Role: role,
ProjectID: strings.TrimSpace(projectID),
AgentMode: normalizeBatchQueueAgentMode(agentMode),
ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode),
CronExpr: strings.TrimSpace(cronExpr),
@@ -171,6 +173,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
queue.ScheduleMode,
queue.CronExpr,
queue.NextRunAt,
queue.ProjectID,
dbTasks,
); err != nil {
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
@@ -263,6 +266,9 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
if queueRow.LastRunError.Valid {
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
}
if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
}
if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time
}
@@ -499,6 +505,9 @@ func (m *BatchTaskManager) LoadFromDB() error {
if queueRow.LastRunError.Valid {
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
}
if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
}
if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time
}
+6 -1
View File
@@ -176,6 +176,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
"type": "boolean",
"description": "创建后是否立即开始执行队列,默认 falsepending,需 batch_task_start",
},
"project_id": map[string]interface{}{
"type": "string",
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id",
},
},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
@@ -204,7 +208,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
if !ok {
executeNow = false
}
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
if createErr != nil {
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
}
+30 -2
View File
@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"strconv"
"strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
@@ -33,7 +34,13 @@ func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHa
// CreateConversationRequest 创建对话请求
type CreateConversationRequest struct {
Title string `json:"title"`
Title string `json:"title"`
ProjectID string `json:"projectId,omitempty"`
}
// SetConversationProjectRequest 设置对话所属项目
type SetConversationProjectRequest struct {
ProjectID string `json:"projectId"` // 空字符串表示解除绑定
}
// CreateConversation 创建新对话
@@ -49,7 +56,9 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
title = "新对话"
}
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api"))
meta := audit.ConversationCreateMetaFromGin(c, "api")
meta.ProjectID = strings.TrimSpace(req.ProjectID)
conv, err := h.db.CreateConversation(title, meta)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -59,6 +68,25 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
c.JSON(http.StatusOK, conv)
}
// SetConversationProject 设置或清除对话绑定的项目
func (h *ConversationHandler) SetConversationProject(c *gin.Context) {
id := c.Param("id")
var req SetConversationProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := h.db.GetConversation(id); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)})
}
// ListConversations 列出对话
func (h *ConversationHandler) ListConversations(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50")
+2
View File
@@ -230,6 +230,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
roleTools,
progressCallback,
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(conversationID),
)
if result != nil && len(result.MCPExecutionIDs) > 0 {
@@ -429,6 +430,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.RoleTools,
progressCallback,
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
)
if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
+2
View File
@@ -242,6 +242,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
h.agentsMarkdownDir,
orch,
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(conversationID),
)
if result != nil && len(result.MCPExecutionIDs) > 0 {
@@ -443,6 +444,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
h.agentsMarkdownDir,
strings.TrimSpace(req.Orchestration),
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
)
if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
+8
View File
@@ -36,6 +36,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context
var conv *database.Conversation
var err error
meta := audit.ConversationCreateMetaFromGin(c, source)
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
if strings.TrimSpace(req.WebShellConnectionID) != "" {
meta.Source = source + "_webshell"
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
@@ -90,6 +91,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context
builtin.ToolWebshellFileRead,
builtin.ToolWebshellFileWrite,
builtin.ToolRecordVulnerability,
builtin.ToolListVulnerabilities,
builtin.ToolGetVulnerability,
builtin.ToolUpsertProjectFact,
builtin.ToolGetProjectFact,
builtin.ToolListProjectFacts,
builtin.ToolSearchProjectFacts,
builtin.ToolDeprecateProjectFact,
builtin.ToolListKnowledgeRiskTypes,
builtin.ToolSearchKnowledgeBase,
}
+138
View File
@@ -73,8 +73,22 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"description": "对话标题",
"example": "Web应用安全测试",
},
"projectId": map[string]interface{}{
"type": "string",
"description": "绑定的项目 ID(可选,共享事实黑板)",
},
},
},
"SetConversationProjectRequest": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"projectId": map[string]interface{}{
"type": "string",
"description": "项目 ID;空字符串表示解除绑定",
},
},
"required": []string{"projectId"},
},
"Conversation": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
@@ -98,6 +112,10 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"format": "date-time",
"description": "更新时间",
},
"projectId": map[string]interface{}{
"type": "string",
"description": "绑定的项目 ID(可选)",
},
},
},
"ConversationDetail": map[string]interface{}{
@@ -1326,6 +1344,37 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
},
},
},
"/api/conversations/{id}/project": map[string]interface{}{
"put": map[string]interface{}{
"tags": []string{"对话管理"},
"summary": "设置对话所属项目",
"description": "绑定或解除对话与项目的关联,用于共享事实黑板",
"operationId": "setConversationProject",
"parameters": []map[string]interface{}{
{
"name": "id", "in": "path", "required": true,
"description": "对话ID",
"schema": map[string]interface{}{"type": "string"},
},
},
"requestBody": map[string]interface{}{
"required": true,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"$ref": "#/components/schemas/SetConversationProjectRequest",
},
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{"description": "设置成功"},
"400": map[string]interface{}{"description": "项目不存在或参数错误"},
"404": map[string]interface{}{"description": "对话不存在"},
"401": map[string]interface{}{"description": "未授权"},
},
},
},
"/api/conversations/{id}/results": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"对话管理"},
@@ -2444,6 +2493,86 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
},
},
},
"/api/projects": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"项目管理"},
"summary": "列出项目",
"operationId": "listProjects",
"parameters": []map[string]interface{}{
{"name": "status", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"active", "archived"}}},
{"name": "limit", "in": "query", "schema": map[string]interface{}{"type": "integer", "default": 200}},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{"description": "项目列表"},
"401": map[string]interface{}{"description": "未授权"},
},
},
"post": map[string]interface{}{
"tags": []string{"项目管理"},
"summary": "创建项目",
"operationId": "createProject",
"requestBody": map[string]interface{}{
"required": true,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{"type": "string"},
"description": map[string]interface{}{"type": "string"},
"scope_json": map[string]interface{}{"type": "string"},
},
"required": []string{"name"},
},
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{"description": "创建成功"},
"401": map[string]interface{}{"description": "未授权"},
},
},
},
"/api/projects/{id}": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "获取项目", "operationId": "getProject",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "项目详情"}},
},
"put": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "更新项目", "operationId": "updateProject",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "更新成功"}},
},
"delete": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "删除项目", "operationId": "deleteProject",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
},
},
"/api/projects/{id}/facts": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "列出或按 key 获取事实", "operationId": "listProjectFacts",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
},
"post": map[string]interface{}{
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
"parameters": []map[string]interface{}{
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
},
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
},
},
"/api/vulnerabilities": map[string]interface{}{
"get": map[string]interface{}{
"tags": []string{"漏洞管理"},
@@ -2502,6 +2631,15 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"type": "string",
},
},
{
"name": "project_id",
"in": "query",
"required": false,
"description": "项目ID",
"schema": map[string]interface{}{
"type": "string",
},
},
{
"name": "severity",
"in": "query",
+262
View File
@@ -0,0 +1,262 @@
package handler
import (
"net/http"
"strconv"
"strings"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ProjectHandler 项目管理处理器。
type ProjectHandler struct {
db *database.DB
logger *zap.Logger
}
// NewProjectHandler 创建项目管理处理器。
func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler {
return &ProjectHandler{db: db, logger: logger}
}
type createProjectRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
ScopeJSON string `json:"scope_json"`
Status string `json:"status"`
}
type updateProjectRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ScopeJSON string `json:"scope_json"`
Status string `json:"status"`
Pinned *bool `json:"pinned"`
}
// CreateProject POST /api/projects
func (h *ProjectHandler) CreateProject(c *gin.Context) {
var req createProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
p := &database.Project{
Name: strings.TrimSpace(req.Name),
Description: req.Description,
ScopeJSON: req.ScopeJSON,
Status: strings.TrimSpace(req.Status),
}
created, err := h.db.CreateProject(p)
if err != nil {
h.logger.Error("创建项目失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, created)
}
// ListProjects GET /api/projects
func (h *ProjectHandler) ListProjects(c *gin.Context) {
status := c.Query("status")
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200"))
offset, _ := strconv.Atoi(c.Query("offset"))
list, err := h.db.ListProjects(status, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.Project{}
}
c.JSON(http.StatusOK, list)
}
// GetProject GET /api/projects/:id
func (h *ProjectHandler) GetProject(c *gin.Context) {
p, err := h.db.GetProject(c.Param("id"))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
c.JSON(http.StatusOK, p)
}
// UpdateProject PUT /api/projects/:id
func (h *ProjectHandler) UpdateProject(c *gin.Context) {
id := c.Param("id")
p, err := h.db.GetProject(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
var req updateProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if s := strings.TrimSpace(req.Name); s != "" {
p.Name = s
}
if req.Description != "" || c.Request.ContentLength > 0 {
p.Description = req.Description
}
if req.ScopeJSON != "" || c.GetHeader("Content-Type") != "" {
p.ScopeJSON = req.ScopeJSON
}
if s := strings.TrimSpace(req.Status); s != "" {
p.Status = s
}
if req.Pinned != nil {
p.Pinned = *req.Pinned
}
if err := h.db.UpdateProject(p); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, p)
}
// DeleteProject DELETE /api/projects/:id
func (h *ProjectHandler) DeleteProject(c *gin.Context) {
if err := h.db.DeleteProject(c.Param("id")); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
type upsertFactRequest struct {
FactKey string `json:"fact_key" binding:"required"`
Category string `json:"category"`
Summary string `json:"summary" binding:"required"`
Body string `json:"body"`
Confidence string `json:"confidence"`
Pinned bool `json:"pinned"`
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
}
// ListFacts GET /api/projects/:id/facts fact_key 查询参数可获取单条详情)
func (h *ProjectHandler) ListFacts(c *gin.Context) {
projectID := c.Param("id")
if key := strings.TrimSpace(c.Query("fact_key")); key != "" {
f, err := h.db.GetProjectFactByKey(projectID, key)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, f)
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
offset, _ := strconv.Atoi(c.Query("offset"))
filter := database.ProjectFactListFilter{
Category: c.Query("category"),
Confidence: c.Query("confidence"),
Search: c.Query("search"),
}
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.ProjectFact{}
}
c.JSON(http.StatusOK, list)
}
// CreateFact POST /api/projects/:id/facts
func (h *ProjectHandler) CreateFact(c *gin.Context) {
var req upsertFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
f := &database.ProjectFact{
ProjectID: c.Param("id"),
FactKey: req.FactKey,
Category: req.Category,
Summary: req.Summary,
Body: req.Body,
Confidence: req.Confidence,
Pinned: req.Pinned,
RelatedVulnerabilityID: req.RelatedVulnerabilityID,
}
created, err := h.db.UpsertProjectFact(f)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, created)
}
// UpdateFact PUT /api/projects/:id/facts/:factId
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
var req upsertFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if k := strings.TrimSpace(req.FactKey); k != "" {
existing.FactKey = k
}
if req.Category != "" {
existing.Category = req.Category
}
if req.Summary != "" {
existing.Summary = req.Summary
}
existing.Body = req.Body
if req.Confidence != "" {
existing.Confidence = req.Confidence
}
existing.Pinned = req.Pinned
existing.RelatedVulnerabilityID = req.RelatedVulnerabilityID
updated, err := h.db.UpsertProjectFact(existing)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, updated)
}
// DeleteFact DELETE /api/projects/:id/facts/:factId
func (h *ProjectHandler) DeleteFact(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
if err := h.db.DeleteProjectFact(existing.ID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
type deprecateFactRequest struct {
FactKey string `json:"fact_key" binding:"required"`
}
// DeprecateFact POST /api/projects/:id/facts/deprecate
func (h *ProjectHandler) DeprecateFact(c *gin.Context) {
var req deprecateFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
+32
View File
@@ -0,0 +1,32 @@
package handler
import (
"strings"
"cyberstrike-ai/internal/project"
"go.uber.org/zap"
)
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
if h == nil || h.db == nil || h.config == nil {
return ""
}
if !h.config.Project.Enabled {
return ""
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
projectID, err := h.db.GetConversationProjectID(conversationID)
if err != nil || projectID == "" {
return ""
}
block, err := project.BuildFactIndexBlock(h.db, projectID, h.config.Project)
if err != nil {
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
return ""
}
return strings.TrimSpace(block)
}
+18
View File
@@ -0,0 +1,18 @@
package handler
import (
"strings"
"cyberstrike-ai/internal/config"
)
// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。
func effectiveProjectID(cfg *config.Config, explicit string) string {
if pid := strings.TrimSpace(explicit); pid != "" {
return pid
}
if cfg != nil {
return strings.TrimSpace(cfg.Project.DefaultProjectID)
}
return ""
}
+6 -2
View File
@@ -133,7 +133,9 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
} else {
t = safeTruncateString(t, 50)
}
conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform})
meta := database.ConversationCreateMeta{Source: "robot:" + platform}
meta.ProjectID = effectiveProjectID(h.config, "")
conv, err := h.db.CreateConversation(t, meta)
if err != nil {
h.logger.Warn("创建机器人会话失败", zap.Error(err))
return "", false
@@ -188,7 +190,9 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
// clearConversation 清空当前会话(切换到新对话)
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
title := "新对话 " + time.Now().Format("01-02 15:04")
conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"})
meta := database.ConversationCreateMeta{Source: "robot:" + platform + ":new"}
meta.ProjectID = effectiveProjectID(h.config, "")
conv, err := h.db.CreateConversation(title, meta)
if err != nil {
h.logger.Warn("创建新对话失败", zap.Error(err))
return ""
+19 -12
View File
@@ -36,6 +36,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
// CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
ProjectID string `json:"project_id"`
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title" binding:"required"`
@@ -59,6 +60,7 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
vuln := &database.Vulnerability{
ConversationID: req.ConversationID,
ProjectID: strings.TrimSpace(req.ProjectID),
ConversationTag: req.ConversationTag,
TaskTag: req.TaskTag,
Title: req.Title,
@@ -116,6 +118,7 @@ func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilt
q = strings.TrimSpace(c.Query("search"))
}
return database.VulnerabilityListFilter{
ProjectID: c.Query("project_id"),
ID: c.Query("id"),
Search: q,
ConversationID: c.Query("conversation_id"),
@@ -193,17 +196,18 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
// UpdateVulnerabilityRequest 更新漏洞请求
type UpdateVulnerabilityRequest struct {
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
ProjectID *string `json:"project_id"`
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// UpdateVulnerability 更新漏洞
@@ -224,6 +228,9 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
}
// 更新字段
if req.ProjectID != nil {
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
}
if req.ConversationTag != "" {
existing.ConversationTag = req.ConversationTag
}
@@ -274,7 +281,7 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
if h.audit != nil {
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
"severity": updated.Severity, "status": updated.Status,
"severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID,
})
}
c.JSON(http.StatusOK, updated)
+1 -1
View File
@@ -15,7 +15,7 @@ const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `s
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base"
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、list_knowledge_risk_types、search_knowledge_base"
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
+77
View File
@@ -0,0 +1,77 @@
package project
import (
"fmt"
"sort"
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
)
// AppendSystemPromptBlock 将附加块追加到 system prompt。
func AppendSystemPromptBlock(base, block string) string {
base = strings.TrimSpace(base)
block = strings.TrimSpace(block)
if block == "" {
return base
}
if base == "" {
return block
}
return base + "\n\n" + block
}
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
if db == nil || !cfg.Enabled {
return "", nil
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return "", nil
}
proj, err := db.GetProject(projectID)
if err != nil {
return "", err
}
facts, err := db.ListProjectFactsForIndex(projectID, cfg.DefaultInjectDeprecated)
if err != nil {
return "", err
}
if len(facts) == 0 {
return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil
}
sort.SliceStable(facts, func(i, j int) bool {
if facts[i].Pinned != facts[j].Pinned {
return facts[i].Pinned
}
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
})
maxRunes := cfg.FactIndexMaxRunesEffective()
var b strings.Builder
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s\n", proj.Name, proj.ID))
used := len([]rune(b.String()))
omitted := 0
for _, f := range facts {
line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
lineRunes := len([]rune(line))
if used+lineRunes > maxRunes {
omitted++
continue
}
b.WriteString(line)
used += lineRunes
}
if omitted > 0 {
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
}
b.WriteString("需要完整内容(POC、长文本等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n")
return b.String(), nil
}