mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-26 17:27:54 +02:00
Add files via upload
This commit is contained in:
+25
-12
@@ -222,6 +222,7 @@ type ChatReasoningRequest struct {
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty"` // 新对话绑定的项目(可选;未指定时可用 config.project.default_project_id)
|
||||
Role string `json:"role,omitempty"` // 角色名称
|
||||
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
||||
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
|
||||
@@ -560,7 +561,9 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "agent_loop"))
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop")
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -635,6 +638,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
@@ -682,7 +687,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, "", nil)
|
||||
|
||||
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
@@ -760,7 +765,9 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con
|
||||
if strings.TrimSpace(platform) != "" {
|
||||
src = "robot:" + strings.TrimSpace(platform)
|
||||
}
|
||||
conv, createErr := h.db.CreateConversation(title, audit.ConversationCreateMeta(src))
|
||||
meta := audit.ConversationCreateMeta(src)
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, createErr := h.db.CreateConversation(title, meta)
|
||||
if createErr != nil {
|
||||
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
|
||||
}
|
||||
@@ -839,7 +846,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback, nil,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA == nil {
|
||||
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||
@@ -872,7 +879,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con
|
||||
resultMA, errMA = multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, robotMode, nil,
|
||||
h.agentsMarkdownDir, robotMode, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA == nil {
|
||||
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||
@@ -891,7 +898,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con
|
||||
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||
}
|
||||
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
taskStatus = "failed"
|
||||
errMsg := "执行失败: " + err.Error()
|
||||
@@ -1518,6 +1525,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream")
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
if req.WebShellConnectionID != "" {
|
||||
meta.Source = "webshell_chat"
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta)
|
||||
@@ -1595,6 +1603,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
@@ -1725,7 +1735,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
|
||||
defer close(stopKeepalive)
|
||||
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
cause := context.Cause(baseCtx)
|
||||
@@ -2037,6 +2047,7 @@ type BatchTaskRequest struct {
|
||||
ScheduleMode string `json:"scheduleMode,omitempty"` // manual | cron
|
||||
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
||||
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
||||
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
||||
}
|
||||
|
||||
func normalizeBatchQueueAgentMode(mode string) string {
|
||||
@@ -2117,7 +2128,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
nextRunAt = &next
|
||||
}
|
||||
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, nextRunAt, validTasks)
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
|
||||
if createErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
||||
return
|
||||
@@ -2651,7 +2662,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
|
||||
// 创建新对话
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMeta("batch_task"))
|
||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||
var conversationID string
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
@@ -2801,15 +2814,15 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil)
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||
case useEinoSingle:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil)
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
default:
|
||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
|
||||
@@ -65,6 +65,7 @@ type BatchTaskQueue struct {
|
||||
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
|
||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||
LastRunError string `json:"lastRunError,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
@@ -103,7 +104,7 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (m *BatchTaskManager) CreateBatchQueue(
|
||||
title, role, agentMode, scheduleMode, cronExpr string,
|
||||
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
||||
nextRunAt *time.Time,
|
||||
tasks []string,
|
||||
) (*BatchTaskQueue, error) {
|
||||
@@ -126,6 +127,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
ID: queueID,
|
||||
Title: title,
|
||||
Role: role,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
AgentMode: normalizeBatchQueueAgentMode(agentMode),
|
||||
ScheduleMode: normalizeBatchQueueScheduleMode(scheduleMode),
|
||||
CronExpr: strings.TrimSpace(cronExpr),
|
||||
@@ -171,6 +173,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
queue.ScheduleMode,
|
||||
queue.CronExpr,
|
||||
queue.NextRunAt,
|
||||
queue.ProjectID,
|
||||
dbTasks,
|
||||
); err != nil {
|
||||
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
@@ -263,6 +266,9 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
if queueRow.LastRunError.Valid {
|
||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||
}
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -499,6 +505,9 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if queueRow.LastRunError.Valid {
|
||||
queue.LastRunError = strings.TrimSpace(queueRow.LastRunError.String)
|
||||
}
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
|
||||
@@ -176,6 +176,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
"type": "boolean",
|
||||
"description": "创建后是否立即开始执行队列,默认 false(pending,需 batch_task_start)",
|
||||
},
|
||||
"project_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
@@ -204,7 +208,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
if !ok {
|
||||
executeNow = false
|
||||
}
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
|
||||
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
||||
if createErr != nil {
|
||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
@@ -33,7 +34,13 @@ func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHa
|
||||
|
||||
// CreateConversationRequest 创建对话请求
|
||||
type CreateConversationRequest struct {
|
||||
Title string `json:"title"`
|
||||
Title string `json:"title"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
}
|
||||
|
||||
// SetConversationProjectRequest 设置对话所属项目
|
||||
type SetConversationProjectRequest struct {
|
||||
ProjectID string `json:"projectId"` // 空字符串表示解除绑定
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
@@ -49,7 +56,9 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
title = "新对话"
|
||||
}
|
||||
|
||||
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api"))
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "api")
|
||||
meta.ProjectID = strings.TrimSpace(req.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -59,6 +68,25 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// SetConversationProject 设置或清除对话绑定的项目
|
||||
func (h *ConversationHandler) SetConversationProject(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req SetConversationProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if _, err := h.db.GetConversation(id); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)})
|
||||
}
|
||||
|
||||
// ListConversations 列出对话
|
||||
func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
|
||||
@@ -230,6 +230,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
roleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
@@ -429,6 +430,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
|
||||
@@ -242,6 +242,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
@@ -443,6 +444,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
|
||||
@@ -36,6 +36,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, source)
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
||||
meta.Source = source + "_webshell"
|
||||
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
|
||||
@@ -90,6 +91,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolUpsertProjectFact,
|
||||
builtin.ToolGetProjectFact,
|
||||
builtin.ToolListProjectFacts,
|
||||
builtin.ToolSearchProjectFacts,
|
||||
builtin.ToolDeprecateProjectFact,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
|
||||
@@ -73,8 +73,22 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"description": "对话标题",
|
||||
"example": "Web应用安全测试",
|
||||
},
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "绑定的项目 ID(可选,共享事实黑板)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"SetConversationProjectRequest": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "项目 ID;空字符串表示解除绑定",
|
||||
},
|
||||
},
|
||||
"required": []string{"projectId"},
|
||||
},
|
||||
"Conversation": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
@@ -98,6 +112,10 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"format": "date-time",
|
||||
"description": "更新时间",
|
||||
},
|
||||
"projectId": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "绑定的项目 ID(可选)",
|
||||
},
|
||||
},
|
||||
},
|
||||
"ConversationDetail": map[string]interface{}{
|
||||
@@ -1326,6 +1344,37 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/conversations/{id}/project": map[string]interface{}{
|
||||
"put": map[string]interface{}{
|
||||
"tags": []string{"对话管理"},
|
||||
"summary": "设置对话所属项目",
|
||||
"description": "绑定或解除对话与项目的关联,用于共享事实黑板",
|
||||
"operationId": "setConversationProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "id", "in": "path", "required": true,
|
||||
"description": "对话ID",
|
||||
"schema": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": true,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"$ref": "#/components/schemas/SetConversationProjectRequest",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "设置成功"},
|
||||
"400": map[string]interface{}{"description": "项目不存在或参数错误"},
|
||||
"404": map[string]interface{}{"description": "对话不存在"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/conversations/{id}/results": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"对话管理"},
|
||||
@@ -2444,6 +2493,86 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/projects": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"},
|
||||
"summary": "列出项目",
|
||||
"operationId": "listProjects",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "status", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"active", "archived"}}},
|
||||
{"name": "limit", "in": "query", "schema": map[string]interface{}{"type": "integer", "default": 200}},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "项目列表"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"},
|
||||
"summary": "创建项目",
|
||||
"operationId": "createProject",
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": true,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"name": map[string]interface{}{"type": "string"},
|
||||
"description": map[string]interface{}{"type": "string"},
|
||||
"scope_json": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{"description": "创建成功"},
|
||||
"401": map[string]interface{}{"description": "未授权"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "获取项目", "operationId": "getProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "项目详情"}},
|
||||
},
|
||||
"put": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "更新项目", "operationId": "updateProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "更新成功"}},
|
||||
},
|
||||
"delete": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "删除项目", "operationId": "deleteProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}/facts": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "列出或按 key 获取事实", "operationId": "listProjectFacts",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
||||
},
|
||||
},
|
||||
"/api/vulnerabilities": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"漏洞管理"},
|
||||
@@ -2502,6 +2631,15 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "project_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"description": "项目ID",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "severity",
|
||||
"in": "query",
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProjectHandler 项目管理处理器。
|
||||
type ProjectHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProjectHandler 创建项目管理处理器。
|
||||
func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler {
|
||||
return &ProjectHandler{db: db, logger: logger}
|
||||
}
|
||||
|
||||
type createProjectRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
ScopeJSON string `json:"scope_json"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type updateProjectRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ScopeJSON string `json:"scope_json"`
|
||||
Status string `json:"status"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
}
|
||||
|
||||
// CreateProject POST /api/projects
|
||||
func (h *ProjectHandler) CreateProject(c *gin.Context) {
|
||||
var req createProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
p := &database.Project{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Description: req.Description,
|
||||
ScopeJSON: req.ScopeJSON,
|
||||
Status: strings.TrimSpace(req.Status),
|
||||
}
|
||||
created, err := h.db.CreateProject(p)
|
||||
if err != nil {
|
||||
h.logger.Error("创建项目失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// ListProjects GET /api/projects
|
||||
func (h *ProjectHandler) ListProjects(c *gin.Context) {
|
||||
status := c.Query("status")
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "200"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
list, err := h.db.ListProjects(status, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.Project{}
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// GetProject GET /api/projects/:id
|
||||
func (h *ProjectHandler) GetProject(c *gin.Context) {
|
||||
p, err := h.db.GetProject(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
// UpdateProject PUT /api/projects/:id
|
||||
func (h *ProjectHandler) UpdateProject(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
p, err := h.db.GetProject(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
var req updateProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if s := strings.TrimSpace(req.Name); s != "" {
|
||||
p.Name = s
|
||||
}
|
||||
if req.Description != "" || c.Request.ContentLength > 0 {
|
||||
p.Description = req.Description
|
||||
}
|
||||
if req.ScopeJSON != "" || c.GetHeader("Content-Type") != "" {
|
||||
p.ScopeJSON = req.ScopeJSON
|
||||
}
|
||||
if s := strings.TrimSpace(req.Status); s != "" {
|
||||
p.Status = s
|
||||
}
|
||||
if req.Pinned != nil {
|
||||
p.Pinned = *req.Pinned
|
||||
}
|
||||
if err := h.db.UpdateProject(p); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
// DeleteProject DELETE /api/projects/:id
|
||||
func (h *ProjectHandler) DeleteProject(c *gin.Context) {
|
||||
if err := h.db.DeleteProject(c.Param("id")); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type upsertFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"`
|
||||
Pinned bool `json:"pinned"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||
}
|
||||
|
||||
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
||||
func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
if key := strings.TrimSpace(c.Query("fact_key")); key != "" {
|
||||
f, err := h.db.GetProjectFactByKey(projectID, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, f)
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
filter := database.ProjectFactListFilter{
|
||||
Category: c.Query("category"),
|
||||
Confidence: c.Query("confidence"),
|
||||
Search: c.Query("search"),
|
||||
}
|
||||
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.ProjectFact{}
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// CreateFact POST /api/projects/:id/facts
|
||||
func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
||||
var req upsertFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
f := &database.ProjectFact{
|
||||
ProjectID: c.Param("id"),
|
||||
FactKey: req.FactKey,
|
||||
Category: req.Category,
|
||||
Summary: req.Summary,
|
||||
Body: req.Body,
|
||||
Confidence: req.Confidence,
|
||||
Pinned: req.Pinned,
|
||||
RelatedVulnerabilityID: req.RelatedVulnerabilityID,
|
||||
}
|
||||
created, err := h.db.UpsertProjectFact(f)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// UpdateFact PUT /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
var req upsertFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if k := strings.TrimSpace(req.FactKey); k != "" {
|
||||
existing.FactKey = k
|
||||
}
|
||||
if req.Category != "" {
|
||||
existing.Category = req.Category
|
||||
}
|
||||
if req.Summary != "" {
|
||||
existing.Summary = req.Summary
|
||||
}
|
||||
existing.Body = req.Body
|
||||
if req.Confidence != "" {
|
||||
existing.Confidence = req.Confidence
|
||||
}
|
||||
existing.Pinned = req.Pinned
|
||||
existing.RelatedVulnerabilityID = req.RelatedVulnerabilityID
|
||||
updated, err := h.db.UpsertProjectFact(existing)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) DeleteFact(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteProjectFact(existing.ID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type deprecateFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
}
|
||||
|
||||
// DeprecateFact POST /api/projects/:id/facts/deprecate
|
||||
func (h *ProjectHandler) DeprecateFact(c *gin.Context) {
|
||||
var req deprecateFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/project"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
|
||||
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||
if h == nil || h.db == nil || h.config == nil {
|
||||
return ""
|
||||
}
|
||||
if !h.config.Project.Enabled {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
projectID, err := h.db.GetConversationProjectID(conversationID)
|
||||
if err != nil || projectID == "" {
|
||||
return ""
|
||||
}
|
||||
block, err := project.BuildFactIndexBlock(h.db, projectID, h.config.Project)
|
||||
if err != nil {
|
||||
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(block)
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。
|
||||
func effectiveProjectID(cfg *config.Config, explicit string) string {
|
||||
if pid := strings.TrimSpace(explicit); pid != "" {
|
||||
return pid
|
||||
}
|
||||
if cfg != nil {
|
||||
return strings.TrimSpace(cfg.Project.DefaultProjectID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -133,7 +133,9 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
|
||||
} else {
|
||||
t = safeTruncateString(t, 50)
|
||||
}
|
||||
conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform})
|
||||
meta := database.ConversationCreateMeta{Source: "robot:" + platform}
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, err := h.db.CreateConversation(t, meta)
|
||||
if err != nil {
|
||||
h.logger.Warn("创建机器人会话失败", zap.Error(err))
|
||||
return "", false
|
||||
@@ -188,7 +190,9 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
|
||||
// clearConversation 清空当前会话(切换到新对话)
|
||||
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
|
||||
title := "新对话 " + time.Now().Format("01-02 15:04")
|
||||
conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"})
|
||||
meta := database.ConversationCreateMeta{Source: "robot:" + platform + ":new"}
|
||||
meta.ProjectID = effectiveProjectID(h.config, "")
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Warn("创建新对话失败", zap.Error(err))
|
||||
return ""
|
||||
|
||||
@@ -36,6 +36,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
|
||||
// CreateVulnerabilityRequest 创建漏洞请求
|
||||
type CreateVulnerabilityRequest struct {
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
ProjectID string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
@@ -59,6 +60,7 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: req.ConversationID,
|
||||
ProjectID: strings.TrimSpace(req.ProjectID),
|
||||
ConversationTag: req.ConversationTag,
|
||||
TaskTag: req.TaskTag,
|
||||
Title: req.Title,
|
||||
@@ -116,6 +118,7 @@ func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilt
|
||||
q = strings.TrimSpace(c.Query("search"))
|
||||
}
|
||||
return database.VulnerabilityListFilter{
|
||||
ProjectID: c.Query("project_id"),
|
||||
ID: c.Query("id"),
|
||||
Search: q,
|
||||
ConversationID: c.Query("conversation_id"),
|
||||
@@ -193,17 +196,18 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
|
||||
// UpdateVulnerabilityRequest 更新漏洞请求
|
||||
type UpdateVulnerabilityRequest struct {
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
ProjectID *string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
}
|
||||
|
||||
// UpdateVulnerability 更新漏洞
|
||||
@@ -224,6 +228,9 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.ProjectID != nil {
|
||||
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
|
||||
}
|
||||
if req.ConversationTag != "" {
|
||||
existing.ConversationTag = req.ConversationTag
|
||||
}
|
||||
@@ -274,7 +281,7 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
|
||||
"severity": updated.Severity, "status": updated.Status,
|
||||
"severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
|
||||
@@ -15,7 +15,7 @@ const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `s
|
||||
|
||||
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
|
||||
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
|
||||
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base"
|
||||
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、list_knowledge_risk_types、search_knowledge_base"
|
||||
|
||||
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
|
||||
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// AppendSystemPromptBlock 将附加块追加到 system prompt。
|
||||
func AppendSystemPromptBlock(base, block string) string {
|
||||
base = strings.TrimSpace(base)
|
||||
block = strings.TrimSpace(block)
|
||||
if block == "" {
|
||||
return base
|
||||
}
|
||||
if base == "" {
|
||||
return block
|
||||
}
|
||||
return base + "\n\n" + block
|
||||
}
|
||||
|
||||
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。
|
||||
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
|
||||
if db == nil || !cfg.Enabled {
|
||||
return "", nil
|
||||
}
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
proj, err := db.GetProject(projectID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
facts, err := db.ListProjectFactsForIndex(projectID, cfg.DefaultInjectDeprecated)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(facts) == 0 {
|
||||
return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil
|
||||
}
|
||||
|
||||
sort.SliceStable(facts, func(i, j int) bool {
|
||||
if facts[i].Pinned != facts[j].Pinned {
|
||||
return facts[i].Pinned
|
||||
}
|
||||
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
|
||||
})
|
||||
|
||||
maxRunes := cfg.FactIndexMaxRunesEffective()
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID))
|
||||
used := len([]rune(b.String()))
|
||||
omitted := 0
|
||||
|
||||
for _, f := range facts {
|
||||
line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
|
||||
lineRunes := len([]rune(line))
|
||||
if used+lineRunes > maxRunes {
|
||||
omitted++
|
||||
continue
|
||||
}
|
||||
b.WriteString(line)
|
||||
used += lineRunes
|
||||
}
|
||||
|
||||
if omitted > 0 {
|
||||
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
|
||||
}
|
||||
b.WriteString("需要完整内容(POC、长文本等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n")
|
||||
return b.String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user