From 1b1aed1699c8ffd525485eb068e933cd44d40fa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Tue, 26 May 2026 14:27:44 +0800 Subject: [PATCH] Add files via upload --- internal/agent/agent.go | 8 +- .../agent/default_single_system_prompt.go | 10 +- internal/config/config.go | 26 + internal/database/batch_task.go | 22 +- internal/database/conversation.go | 57 ++- internal/database/conversation_create_meta.go | 1 + internal/database/database.go | 98 ++++ internal/database/project.go | 451 ++++++++++++++++++ internal/database/project_time_test.go | 93 ++++ internal/database/vulnerability.go | 35 +- internal/mcp/builtin/constants.go | 25 +- 11 files changed, 790 insertions(+), 36 deletions(-) create mode 100644 internal/database/project.go create mode 100644 internal/database/project_time_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index b72106fa..026ecd70 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -17,6 +17,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp/builtin" + "cyberstrike-ai/internal/project" "cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/security" "cyberstrike-ai/internal/storage" @@ -365,12 +366,12 @@ type ProgressCallback func(eventType, message string, data interface{}) // AgentLoop 执行Agent循环 func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) { - return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil) + return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil, "") } // AgentLoopWithConversationID 执行Agent循环(带对话ID) func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) { - return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil) + return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil, "") } // EinoSingleAgentSystemInstruction 供 Eino adk.ChatModelAgent.Instruction 使用,与 AgentLoopWithProgress 首条 system 对齐(含 system_prompt_path)。 @@ -396,7 +397,7 @@ func (a *Agent) EinoSingleAgentSystemInstruction() string { } // AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID) -func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) { +func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string, systemPromptExtra string) (*AgentLoopResult, error) { ctx = withAgentConversationID(ctx, conversationID) // 设置当前对话ID(兼容未走 context 的旧路径;并发会话应以 context 为准) a.mu.Lock() @@ -426,6 +427,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his } } } + systemPrompt = project.AppendSystemPromptBlock(systemPrompt, systemPromptExtra) messages := []ChatMessage{ { diff --git a/internal/agent/default_single_system_prompt.go b/internal/agent/default_single_system_prompt.go index 6300ea1e..12453c93 100644 --- a/internal/agent/default_single_system_prompt.go +++ b/internal/agent/default_single_system_prompt.go @@ -105,11 +105,15 @@ func DefaultSingleAgentSystemPrompt() string { - 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。 - 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。 -## 漏洞记录 +## 项目黑板(事实)与漏洞记录(分离) -发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。 +当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ` + builtin.ToolGetProjectFact + `(fact_key) 获取 body,禁止凭摘要臆造细节。** -严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。记录后可在授权范围内继续测试。 +- **环境/目标/认证等认知**(非正式漏洞条目):使用 ` + builtin.ToolUpsertProjectFact + `,fact_key 建议 ` + "`category/slug`" + `(如 target/primary_domain),同 key 覆盖更新。 +- **可交付漏洞**:使用 ` + builtin.ToolRecordVulnerability + `,含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ` + builtin.ToolListVulnerabilities + ` 查重,详情用 ` + builtin.ToolGetVulnerability + `(id)(默认仅当前项目/会话)。 +- 同一发现可能需**各记一次**(事实记上下文,漏洞记正式 findings)。误报用 ` + builtin.ToolDeprecateProjectFact + ` 或漏洞状态 false_positive。 + +严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。 ## 技能库(Skills)与知识库 diff --git a/internal/config/config.go b/internal/config/config.go index 69191bbf..fb501ee5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,6 +36,32 @@ type Config struct { SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录 AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.md,YAML front matter) MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"` + Project ProjectConfig `yaml:"project,omitempty" json:"project,omitempty"` +} + +// ProjectConfig 项目黑板(跨对话共享事实)配置。 +type ProjectConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目 + FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"` + FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"` + DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"` +} + +// FactIndexMaxRunesEffective 自动注入黑板索引的最大 rune 数。 +func (c ProjectConfig) FactIndexMaxRunesEffective() int { + if c.FactIndexMaxRunes <= 0 { + return 3500 + } + return c.FactIndexMaxRunes +} + +// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数。 +func (c ProjectConfig) FactSummaryMaxRunesEffective() int { + if c.FactSummaryMaxRunes <= 0 { + return 120 + } + return c.FactSummaryMaxRunes } // MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。 diff --git a/internal/database/batch_task.go b/internal/database/batch_task.go index c774be65..fa22a31f 100644 --- a/internal/database/batch_task.go +++ b/internal/database/batch_task.go @@ -22,6 +22,7 @@ type BatchTaskQueueRow struct { LastScheduleTriggerAt sql.NullTime LastScheduleError sql.NullString LastRunError sql.NullString + ProjectID sql.NullString Status string CreatedAt time.Time StartedAt sql.NullTime @@ -51,6 +52,7 @@ func (db *DB) CreateBatchQueue( scheduleMode string, cronExpr string, nextRunAt *time.Time, + projectID string, tasks []map[string]interface{}, ) error { tx, err := db.Begin() @@ -65,9 +67,13 @@ func (db *DB) CreateBatchQueue( nextRunAtValue = *nextRunAt } + var projectIDVal interface{} + if strings.TrimSpace(projectID) != "" { + projectIDVal = strings.TrimSpace(projectID) + } _, err = tx.Exec( - "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0, + "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0, ) if err != nil { return fmt.Errorf("创建批量任务队列失败: %w", err) @@ -101,9 +107,9 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { var row BatchTaskQueueRow var createdAt string err := db.QueryRow( - "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", + "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", queueID, - ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) + ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) if err == sql.ErrNoRows { return nil, nil } @@ -127,7 +133,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { // GetAllBatchQueues 获取所有批量任务队列 func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { rows, err := db.Query( - "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", + "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", ) if err != nil { return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) @@ -138,7 +144,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { for rows.Next() { var row BatchTaskQueueRow var createdAt string - if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { + if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) } parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) @@ -158,7 +164,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { // ListBatchQueues 列出批量任务队列(支持筛选和分页) func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { - query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" + query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" args := []interface{}{} // 状态筛选 @@ -186,7 +192,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat for rows.Next() { var row BatchTaskQueueRow var createdAt string - if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { + if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) } parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) diff --git a/internal/database/conversation.go b/internal/database/conversation.go index e51b9a6e..c0f1c422 100644 --- a/internal/database/conversation.go +++ b/internal/database/conversation.go @@ -17,6 +17,7 @@ import ( type Conversation struct { ID string `json:"id"` Title string `json:"title"` + ProjectID string `json:"projectId,omitempty"` Pinned bool `json:"pinned"` CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` @@ -46,13 +47,32 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, id := uuid.New().String() now := time.Now() + projectID := strings.TrimSpace(meta.ProjectID) + if projectID != "" { + if _, err := db.GetProject(projectID); err != nil { + return nil, err + } + } + var err error - if webshellConnectionID != "" { + wsID := strings.TrimSpace(webshellConnectionID) + switch { + case wsID != "" && projectID != "": + _, err = db.Exec( + "INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id, project_id) VALUES (?, ?, ?, ?, ?, ?)", + id, title, now, now, wsID, projectID, + ) + case wsID != "": _, err = db.Exec( "INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)", - id, title, now, now, webshellConnectionID, + id, title, now, now, wsID, ) - } else { + case projectID != "": + _, err = db.Exec( + "INSERT INTO conversations (id, title, created_at, updated_at, project_id) VALUES (?, ?, ?, ?, ?)", + id, title, now, now, projectID, + ) + default: _, err = db.Exec( "INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)", id, title, now, now, @@ -65,11 +85,12 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, conv := &Conversation{ ID: id, Title: title, + ProjectID: projectID, CreatedAt: now, UpdatedAt: now, } - if webshellConnectionID != "" { - meta.WebShellConnectionID = webshellConnectionID + if wsID != "" { + meta.WebShellConnectionID = wsID } notifyConversationCreated(conv, meta) return conv, nil @@ -210,16 +231,20 @@ func (db *DB) GetConversation(id string) (*Conversation, error) { var createdAt, updatedAt string var pinned int + var projectID sql.NullString err := db.QueryRow( - "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", + "SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?", id, - ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) + ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("对话不存在") } return nil, fmt.Errorf("查询对话失败: %w", err) } + if projectID.Valid { + conv.ProjectID = strings.TrimSpace(projectID.String) + } // 尝试多种时间格式解析 var err1, err2 error @@ -292,16 +317,20 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) { var createdAt, updatedAt string var pinned int + var projectID sql.NullString err := db.QueryRow( - "SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?", + "SELECT id, title, pinned, created_at, updated_at, project_id FROM conversations WHERE id = ?", id, - ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt) + ).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID) if err != nil { if err == sql.ErrNoRows { return nil, fmt.Errorf("对话不存在") } return nil, fmt.Errorf("查询对话失败: %w", err) } + if projectID.Valid { + conv.ProjectID = strings.TrimSpace(projectID.String) + } // 尝试多种时间格式解析 var err1, err2 error @@ -341,7 +370,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 searchPattern := "%" + search + "%" rows, err = db.Query( - `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at + `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id FROM conversations c WHERE c.title LIKE ? OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?) @@ -351,7 +380,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati ) } else { rows, err = db.Query( - "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?", + "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?", limit, offset, ) } @@ -366,10 +395,14 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati var conv Conversation var createdAt, updatedAt string var pinned int + var projectID sql.NullString - if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil { + if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &projectID); err != nil { return nil, fmt.Errorf("扫描对话失败: %w", err) } + if projectID.Valid { + conv.ProjectID = strings.TrimSpace(projectID.String) + } // 尝试多种时间格式解析 var err1, err2 error diff --git a/internal/database/conversation_create_meta.go b/internal/database/conversation_create_meta.go index 4ba96530..8f94dc8e 100644 --- a/internal/database/conversation_create_meta.go +++ b/internal/database/conversation_create_meta.go @@ -4,6 +4,7 @@ package database type ConversationCreateMeta struct { Source string WebShellConnectionID string + ProjectID string ClientIP string SessionHint string } diff --git a/internal/database/database.go b/internal/database/database.go index 26d44fd6..5f62cac1 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -213,6 +213,40 @@ func (db *DB) initTables() error { FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE );` + // 创建项目表 + createProjectsTable := ` + CREATE TABLE IF NOT EXISTS projects ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + scope_json TEXT, + status TEXT NOT NULL DEFAULT 'active', + pinned INTEGER NOT NULL DEFAULT 0, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + );` + + // 创建项目事实表(黑板) + createProjectFactsTable := ` + CREATE TABLE IF NOT EXISTS project_facts ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + fact_key TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'note', + summary TEXT NOT NULL DEFAULT '', + body TEXT, + confidence TEXT NOT NULL DEFAULT 'tentative', + source_conversation_id TEXT, + source_message_id TEXT, + pinned INTEGER NOT NULL DEFAULT 0, + supersedes_fact_id TEXT, + related_vulnerability_id TEXT, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + UNIQUE(project_id, fact_key) + );` + // 创建漏洞表 createVulnerabilitiesTable := ` CREATE TABLE IF NOT EXISTS vulnerabilities ( @@ -445,6 +479,12 @@ func (db *DB) initTables() error { CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at); + CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status); + CREATE INDEX IF NOT EXISTS idx_projects_updated_at ON projects(updated_at); + CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id); + CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence); + CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id); + CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id); CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id); CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at); CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title); @@ -516,6 +556,14 @@ func (db *DB) initTables() error { return fmt.Errorf("创建robot_user_sessions表失败: %w", err) } + if _, err := db.Exec(createProjectsTable); err != nil { + return fmt.Errorf("创建projects表失败: %w", err) + } + + if _, err := db.Exec(createProjectFactsTable); err != nil { + return fmt.Errorf("创建project_facts表失败: %w", err) + } + if _, err := db.Exec(createVulnerabilitiesTable); err != nil { return fmt.Errorf("创建vulnerabilities表失败: %w", err) } @@ -583,6 +631,10 @@ func (db *DB) initTables() error { // 不返回错误,允许继续运行 } + if err := db.migrateProjectsTable(); err != nil { + db.logger.Warn("迁移projects相关表失败", zap.Error(err)) + } + if err := db.migrateWebshellConnectionsTable(); err != nil { db.logger.Warn("迁移webshell_connections表失败", zap.Error(err)) // 不返回错误,允许继续运行 @@ -930,6 +982,51 @@ func (db *DB) migrateBatchTaskQueuesTable() error { } } + var projectIDCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='project_id'").Scan(&projectIDCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(addErr)) + } + } + } else if projectIDCount == 0 { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN project_id TEXT"); err != nil { + db.logger.Warn("添加batch_task_queues.project_id字段失败", zap.Error(err)) + } + } + + return nil +} + +// migrateProjectsTable 迁移 projects / conversations / vulnerabilities 的项目关联字段。 +func (db *DB) migrateProjectsTable() error { + for _, col := range []struct { + table string + name string + stmt string + }{ + {"conversations", "project_id", "ALTER TABLE conversations ADD COLUMN project_id TEXT REFERENCES projects(id) ON DELETE SET NULL"}, + {"vulnerabilities", "project_id", "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"}, + } { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info(?) WHERE name=?", col.table, col.name).Scan(&count) + if err != nil { + if _, addErr := db.Exec(col.stmt); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr)) + } + } + continue + } + if count == 0 { + if _, addErr := db.Exec(col.stmt); addErr != nil { + db.logger.Warn("添加字段失败", zap.String("table", col.table), zap.String("field", col.name), zap.Error(addErr)) + } + } + } return nil } @@ -941,6 +1038,7 @@ func (db *DB) migrateVulnerabilitiesTable() error { }{ {name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"}, {name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"}, + {name: "project_id", stmt: "ALTER TABLE vulnerabilities ADD COLUMN project_id TEXT"}, } for _, col := range columns { diff --git a/internal/database/project.go b/internal/database/project.go new file mode 100644 index 00000000..f300d82e --- /dev/null +++ b/internal/database/project.go @@ -0,0 +1,451 @@ +package database + +import ( + "database/sql" + "fmt" + "regexp" + "strings" + "time" + + "github.com/google/uuid" +) + +var factKeyPattern = regexp.MustCompile(`^[a-z0-9][a-z0-9._/-]*$`) + +// ValidateFactKey 校验事实 key(项目内唯一标识)。 +func ValidateFactKey(key string) error { + key = strings.TrimSpace(key) + if key == "" { + return fmt.Errorf("fact_key 不能为空") + } + if len(key) > 128 { + return fmt.Errorf("fact_key 过长(最多 128 字符)") + } + if !factKeyPattern.MatchString(key) { + return fmt.Errorf("fact_key 格式无效,仅允许小写字母、数字及 . _ / -,且须以小写字母或数字开头") + } + return nil +} + +// Project 渗透测试项目(跨对话共享黑板)。 +type Project struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + ScopeJSON string `json:"scope_json,omitempty"` + Status string `json:"status"` // active | archived + Pinned bool `json:"pinned"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ProjectFact 项目事实(黑板条目)。 +type ProjectFact struct { + ID string `json:"id"` + ProjectID string `json:"project_id"` + FactKey string `json:"fact_key"` + Category string `json:"category"` + Summary string `json:"summary"` + Body string `json:"body"` + Confidence string `json:"confidence"` // confirmed | tentative | deprecated + SourceConversationID string `json:"source_conversation_id,omitempty"` + SourceMessageID string `json:"source_message_id,omitempty"` + Pinned bool `json:"pinned"` + SupersedesFactID string `json:"supersedes_fact_id,omitempty"` + RelatedVulnerabilityID string `json:"related_vulnerability_id,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// ProjectFactListFilter 事实列表筛选。 +type ProjectFactListFilter struct { + Category string + Confidence string + Search string +} + +// CreateProject 创建项目。 +func (db *DB) CreateProject(p *Project) (*Project, error) { + if p.ID == "" { + p.ID = uuid.New().String() + } + if strings.TrimSpace(p.Status) == "" { + p.Status = "active" + } + now := time.Now() + if p.CreatedAt.IsZero() { + p.CreatedAt = now + } + p.UpdatedAt = now + + _, err := db.Exec( + `INSERT INTO projects (id, name, description, scope_json, status, pinned, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + p.ID, p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.CreatedAt, p.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("创建项目失败: %w", err) + } + return p, nil +} + +// GetProject 获取项目。 +func (db *DB) GetProject(id string) (*Project, error) { + var p Project + var pinned int + var createdAt, updatedAt string + err := db.QueryRow( + `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at + FROM projects WHERE id = ?`, id, + ).Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("项目不存在") + } + return nil, fmt.Errorf("获取项目失败: %w", err) + } + p.Pinned = pinned != 0 + p.CreatedAt = parseDBTime(createdAt) + p.UpdatedAt = parseDBTime(updatedAt) + return &p, nil +} + +// ListProjects 列出项目。 +func (db *DB) ListProjects(status string, limit, offset int) ([]*Project, error) { + if limit <= 0 { + limit = 200 + } + query := `SELECT id, name, COALESCE(description,''), COALESCE(scope_json,''), status, pinned, created_at, updated_at + FROM projects WHERE 1=1` + args := []interface{}{} + if s := strings.TrimSpace(status); s != "" { + query += " AND status = ?" + args = append(args, s) + } + query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?" + args = append(args, limit, offset) + + rows, err := db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("列出项目失败: %w", err) + } + defer rows.Close() + + var out []*Project + for rows.Next() { + var p Project + var pinned int + var createdAt, updatedAt string + if err := rows.Scan(&p.ID, &p.Name, &p.Description, &p.ScopeJSON, &p.Status, &pinned, &createdAt, &updatedAt); err != nil { + return nil, err + } + p.Pinned = pinned != 0 + p.CreatedAt = parseDBTime(createdAt) + p.UpdatedAt = parseDBTime(updatedAt) + out = append(out, &p) + } + return out, rows.Err() +} + +// UpdateProject 更新项目。 +func (db *DB) UpdateProject(p *Project) error { + p.UpdatedAt = time.Now() + _, err := db.Exec( + `UPDATE projects SET name = ?, description = ?, scope_json = ?, status = ?, pinned = ?, updated_at = ? WHERE id = ?`, + p.Name, p.Description, p.ScopeJSON, p.Status, boolToInt(p.Pinned), p.UpdatedAt, p.ID, + ) + if err != nil { + return fmt.Errorf("更新项目失败: %w", err) + } + return nil +} + +// DeleteProject 删除项目(级联删除事实;对话 project_id 置空由 FK 处理)。 +func (db *DB) DeleteProject(id string) error { + _, err := db.Exec(`DELETE FROM projects WHERE id = ?`, id) + if err != nil { + return fmt.Errorf("删除项目失败: %w", err) + } + return nil +} + +// GetConversationProjectID 返回对话绑定的项目 ID。 +func (db *DB) GetConversationProjectID(conversationID string) (string, error) { + var pid sql.NullString + err := db.QueryRow(`SELECT project_id FROM conversations WHERE id = ?`, conversationID).Scan(&pid) + if err != nil { + if err == sql.ErrNoRows { + return "", fmt.Errorf("对话不存在") + } + return "", err + } + if pid.Valid { + return strings.TrimSpace(pid.String), nil + } + return "", nil +} + +// SetConversationProjectID 设置对话所属项目(空字符串表示解除绑定)。 +func (db *DB) SetConversationProjectID(conversationID, projectID string) error { + projectID = strings.TrimSpace(projectID) + if projectID != "" { + if _, err := db.GetProject(projectID); err != nil { + return err + } + } + var val interface{} + if projectID == "" { + val = nil + } else { + val = projectID + } + _, err := db.Exec(`UPDATE conversations SET project_id = ?, updated_at = ? WHERE id = ?`, val, time.Now(), conversationID) + if err != nil { + return fmt.Errorf("设置对话项目失败: %w", err) + } + return nil +} + +// ListProjectFactsForIndex 列出用于黑板索引注入的事实(不含 deprecated,除非 includeDeprecated)。 +func (db *DB) ListProjectFactsForIndex(projectID string, includeDeprecated bool) ([]*ProjectFact, error) { + query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, + COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, + COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at + FROM project_facts WHERE project_id = ?` + args := []interface{}{projectID} + if !includeDeprecated { + query += " AND confidence != 'deprecated'" + } + query += " ORDER BY pinned DESC, updated_at DESC" + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + return scanProjectFacts(rows) +} + +// ListProjectFacts 分页列出项目事实。 +func (db *DB) ListProjectFacts(projectID string, filter ProjectFactListFilter, limit, offset int) ([]*ProjectFact, error) { + if limit <= 0 { + limit = 100 + } + query := `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, + COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, + COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at + FROM project_facts WHERE project_id = ?` + args := []interface{}{projectID} + if c := strings.TrimSpace(filter.Category); c != "" { + query += " AND category = ?" + args = append(args, c) + } + if c := strings.TrimSpace(filter.Confidence); c != "" { + query += " AND confidence = ?" + args = append(args, c) + } + if s := strings.TrimSpace(filter.Search); s != "" { + pat := "%" + s + "%" + query += " AND (fact_key LIKE ? OR summary LIKE ? OR body LIKE ?)" + args = append(args, pat, pat, pat) + } + query += " ORDER BY pinned DESC, updated_at DESC LIMIT ? OFFSET ?" + args = append(args, limit, offset) + + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + return scanProjectFacts(rows) +} + +// GetProjectFactByKey 按 key 获取事实。 +func (db *DB) GetProjectFactByKey(projectID, factKey string) (*ProjectFact, error) { + row := db.QueryRow( + `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, + COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, + COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at + FROM project_facts WHERE project_id = ? AND fact_key = ?`, + projectID, factKey, + ) + return scanProjectFactRow(row) +} + +// GetProjectFact 按 ID 获取事实。 +func (db *DB) GetProjectFact(id string) (*ProjectFact, error) { + row := db.QueryRow( + `SELECT id, project_id, fact_key, category, summary, COALESCE(body,''), confidence, + COALESCE(source_conversation_id,''), COALESCE(source_message_id,''), pinned, + COALESCE(supersedes_fact_id,''), COALESCE(related_vulnerability_id,''), created_at, updated_at + FROM project_facts WHERE id = ?`, id, + ) + return scanProjectFactRow(row) +} + +// UpsertProjectFact 创建或更新事实(按 project_id + fact_key)。 +func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) { + if err := ValidateFactKey(f.FactKey); err != nil { + return nil, err + } + if strings.TrimSpace(f.Category) == "" { + f.Category = "note" + } + if strings.TrimSpace(f.Confidence) == "" { + f.Confidence = "tentative" + } + now := time.Now() + + existing, err := db.GetProjectFactByKey(f.ProjectID, f.FactKey) + if err == nil && existing != nil { + f.ID = existing.ID + f.CreatedAt = existing.CreatedAt + f.UpdatedAt = now + _, err = db.Exec( + `UPDATE project_facts SET category = ?, summary = ?, body = ?, confidence = ?, + source_conversation_id = ?, source_message_id = ?, pinned = ?, + supersedes_fact_id = ?, related_vulnerability_id = ?, updated_at = ? + WHERE id = ?`, + f.Category, f.Summary, f.Body, f.Confidence, + nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned), + nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), f.UpdatedAt, f.ID, + ) + if err != nil { + return nil, fmt.Errorf("更新事实失败: %w", err) + } + return f, nil + } + + if f.ID == "" { + f.ID = uuid.New().String() + } + f.CreatedAt = now + f.UpdatedAt = now + _, err = db.Exec( + `INSERT INTO project_facts ( + id, project_id, fact_key, category, summary, body, confidence, + source_conversation_id, source_message_id, pinned, supersedes_fact_id, related_vulnerability_id, + created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + f.ID, f.ProjectID, f.FactKey, f.Category, f.Summary, f.Body, f.Confidence, + nullIfEmpty(f.SourceConversationID), nullIfEmpty(f.SourceMessageID), boolToInt(f.Pinned), + nullIfEmpty(f.SupersedesFactID), nullIfEmpty(f.RelatedVulnerabilityID), + f.CreatedAt, f.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("创建事实失败: %w", err) + } + return f, nil +} + +// DeprecateProjectFact 将事实标记为 deprecated。 +func (db *DB) DeprecateProjectFact(projectID, factKey string) error { + res, err := db.Exec( + `UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`, + time.Now(), projectID, factKey, + ) + if err != nil { + return err + } + n, _ := res.RowsAffected() + if n == 0 { + return fmt.Errorf("事实不存在") + } + return nil +} + +// DeleteProjectFact 删除事实。 +func (db *DB) DeleteProjectFact(id string) error { + _, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id) + return err +} + +func scanProjectFacts(rows *sql.Rows) ([]*ProjectFact, error) { + var out []*ProjectFact + for rows.Next() { + f, err := scanProjectFactFromRows(rows) + if err != nil { + return nil, err + } + out = append(out, f) + } + return out, rows.Err() +} + +func scanProjectFactRow(row *sql.Row) (*ProjectFact, error) { + var f ProjectFact + var pinned int + var createdAt, updatedAt string + err := row.Scan( + &f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence, + &f.SourceConversationID, &f.SourceMessageID, &pinned, + &f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("事实不存在") + } + return nil, err + } + f.Pinned = pinned != 0 + f.CreatedAt = parseDBTime(createdAt) + f.UpdatedAt = parseDBTime(updatedAt) + return &f, nil +} + +func scanProjectFactFromRows(rows *sql.Rows) (*ProjectFact, error) { + var f ProjectFact + var pinned int + var createdAt, updatedAt string + err := rows.Scan( + &f.ID, &f.ProjectID, &f.FactKey, &f.Category, &f.Summary, &f.Body, &f.Confidence, + &f.SourceConversationID, &f.SourceMessageID, &pinned, + &f.SupersedesFactID, &f.RelatedVulnerabilityID, &createdAt, &updatedAt, + ) + if err != nil { + return nil, err + } + f.Pinned = pinned != 0 + f.CreatedAt = parseDBTime(createdAt) + f.UpdatedAt = parseDBTime(updatedAt) + return &f, nil +} + +func boolToInt(b bool) int { + if b { + return 1 + } + return 0 +} + +func nullIfEmpty(s string) interface{} { + if strings.TrimSpace(s) == "" { + return nil + } + return s +} + +func parseDBTime(s string) time.Time { + s = strings.TrimSpace(s) + if s == "" { + return time.Time{} + } + // go-sqlite3 读 DATETIME 常返回 RFC3339(含 T),写入时可能是空格分隔格式,需兼容多种形态 + layouts := []string{ + time.RFC3339Nano, + time.RFC3339, + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05-07:00", + "2006-01-02T15:04:05.999999999-07:00", + "2006-01-02T15:04:05-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05.999999999", + "2006-01-02T15:04:05", + } + for _, layout := range layouts { + if t, e := time.Parse(layout, s); e == nil { + return t + } + } + return time.Time{} +} diff --git a/internal/database/project_time_test.go b/internal/database/project_time_test.go new file mode 100644 index 00000000..c064ee49 --- /dev/null +++ b/internal/database/project_time_test.go @@ -0,0 +1,93 @@ +package database + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "go.uber.org/zap" +) + +func TestParseDBTime_projectFactFormats(t *testing.T) { + cases := []string{ + "2026-05-26 11:13:07.442143+08:00", + "2026-05-26 11:13:07", + "2026-05-26T11:13:07.442143+08:00", + } + for _, s := range cases { + got := parseDBTime(s) + if got.IsZero() { + t.Fatalf("parseDBTime(%q) returned zero", s) + } + } +} + +func TestListProjectFacts_updatedAtJSON(t *testing.T) { + root, err := os.Getwd() + if err != nil { + t.Skip(err) + } + dbPath := filepath.Join(root, "..", "..", "data", "conversations.db") + if _, err := os.Stat(dbPath); err != nil { + t.Skip("conversations.db not found") + } + db, err := NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + projects, err := db.ListProjects("", 1, 0) + if err != nil || len(projects) == 0 { + t.Skip("no projects") + } + pid := projects[0].ID + + list, err := db.ListProjectFacts(pid, ProjectFactListFilter{}, 5, 0) + if err != nil { + t.Fatal(err) + } + if len(list) == 0 { + t.Skip("no facts") + } + for _, f := range list { + if f.UpdatedAt.IsZero() { + t.Fatalf("fact %s UpdatedAt is zero after ListProjectFacts", f.FactKey) + } + b, err := json.Marshal(f) + if err != nil { + t.Fatal(err) + } + var m map[string]interface{} + if err := json.Unmarshal(b, &m); err != nil { + t.Fatal(err) + } + raw, ok := m["updated_at"].(string) + if !ok || raw == "" || raw[:4] == "0001" { + t.Fatalf("bad updated_at in JSON: %v", m["updated_at"]) + } + } +} + +func TestParseDBTime_zeroOnGarbage(t *testing.T) { + if !parseDBTime("").IsZero() { + t.Fatal("expected zero for empty") + } +} + +// Ensure RFC3339 round-trip used by API is after year 2000. +func TestParseDBTime_marshalRoundTrip(t *testing.T) { + s := "2026-05-26 11:13:07.442143+08:00" + tm := parseDBTime(s) + b, err := json.Marshal(tm) + if err != nil { + t.Fatal(err) + } + var back time.Time + if err := json.Unmarshal(b, &back); err != nil { + t.Fatal(err) + } + if back.IsZero() { + t.Fatalf("unmarshal zero from %s", string(b)) + } +} diff --git a/internal/database/vulnerability.go b/internal/database/vulnerability.go index 06d360f2..de67dbeb 100644 --- a/internal/database/vulnerability.go +++ b/internal/database/vulnerability.go @@ -15,6 +15,7 @@ type VulnerabilityListFilter struct { ID string Search string // 关键词模糊匹配(标题、描述、类型、目标等) ConversationID string + ProjectID string Severity string Status string TaskID string @@ -38,6 +39,10 @@ func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) ( query += " AND conversation_id = ?" args = append(args, f.ConversationID) } + if f.ProjectID != "" { + query += " AND project_id = ?" + args = append(args, f.ProjectID) + } if f.TaskID != "" { query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" args = append(args, f.TaskID, f.TaskID) @@ -85,6 +90,7 @@ func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) ( type Vulnerability struct { ID string `json:"id"` ConversationID string `json:"conversation_id"` + ProjectID string `json:"project_id,omitempty"` ConversationTag string `json:"conversation_tag,omitempty"` TaskTag string `json:"task_tag,omitempty"` TaskID string `json:"task_id,omitempty"` @@ -116,17 +122,23 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) { } vuln.UpdatedAt = now + if strings.TrimSpace(vuln.ProjectID) == "" && vuln.ConversationID != "" { + if pid, err := db.GetConversationProjectID(vuln.ConversationID); err == nil { + vuln.ProjectID = pid + } + } + query := ` INSERT INTO vulnerabilities ( - id, conversation_id, conversation_tag, task_tag, title, description, severity, status, + id, conversation_id, project_id, conversation_tag, task_tag, title, description, severity, status, vulnerability_type, target, proof, impact, recommendation, created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` _, err := db.Exec( query, - vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, + vuln.ID, vuln.ConversationID, nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.CreatedAt, vuln.UpdatedAt, @@ -142,7 +154,7 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) { func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { var vuln Vulnerability query := ` - SELECT id, conversation_id, title, description, severity, status, + SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, @@ -152,7 +164,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { ` err := db.QueryRow(query, id).Scan( - &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, + &vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, &vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.TaskID, &vuln.TaskQueueID, @@ -171,7 +183,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) { // ListVulnerabilities 列出漏洞 func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) { query := ` - SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag, + SELECT id, conversation_id, COALESCE(project_id,''), title, description, severity, status, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation, COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id, COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id, @@ -195,7 +207,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFil for rows.Next() { var vuln Vulnerability err := rows.Scan( - &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, + &vuln.ID, &vuln.ConversationID, &vuln.ProjectID, &vuln.Title, &vuln.Description, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target, &vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.TaskID, &vuln.TaskQueueID, @@ -232,7 +244,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error { query := ` UPDATE vulnerabilities - SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?, + SET project_id = ?, conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?, vulnerability_type = ?, target = ?, proof = ?, impact = ?, recommendation = ?, updated_at = ? WHERE id = ? @@ -240,7 +252,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error { _, err := db.Exec( query, - vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, + nullIfEmpty(vuln.ProjectID), vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.UpdatedAt, id, ) @@ -366,10 +378,15 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) { if err != nil { return nil, fmt.Errorf("查询任务标签建议失败: %w", err) } + projectIDs, err := collect(`SELECT DISTINCT project_id FROM vulnerabilities WHERE project_id IS NOT NULL AND project_id <> '' ORDER BY created_at DESC LIMIT 200`) + if err != nil { + return nil, fmt.Errorf("查询项目ID建议失败: %w", err) + } return map[string][]string{ "vulnerability_ids": vulnIDs, "conversation_ids": conversationIDs, + "project_ids": projectIDs, "task_ids": taskIDs, "queue_ids": queueIDs, "conversation_tags": conversationTags, diff --git a/internal/mcp/builtin/constants.go b/internal/mcp/builtin/constants.go index 29d2fad7..bc178049 100644 --- a/internal/mcp/builtin/constants.go +++ b/internal/mcp/builtin/constants.go @@ -4,7 +4,16 @@ package builtin // 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串 const ( // 漏洞管理工具 - ToolRecordVulnerability = "record_vulnerability" + ToolRecordVulnerability = "record_vulnerability" + ToolListVulnerabilities = "list_vulnerabilities" + ToolGetVulnerability = "get_vulnerability" + + // 项目黑板(事实)工具 + ToolUpsertProjectFact = "upsert_project_fact" + ToolGetProjectFact = "get_project_fact" + ToolListProjectFacts = "list_project_facts" + ToolSearchProjectFacts = "search_project_facts" + ToolDeprecateProjectFact = "deprecate_project_fact" // 知识库工具 ToolListKnowledgeRiskTypes = "list_knowledge_risk_types" @@ -53,6 +62,13 @@ const ( func IsBuiltinTool(toolName string) bool { switch toolName { case ToolRecordVulnerability, + ToolListVulnerabilities, + ToolGetVulnerability, + ToolUpsertProjectFact, + ToolGetProjectFact, + ToolListProjectFacts, + ToolSearchProjectFacts, + ToolDeprecateProjectFact, ToolListKnowledgeRiskTypes, ToolSearchKnowledgeBase, ToolWebshellExec, @@ -96,6 +112,13 @@ func IsBuiltinTool(toolName string) bool { func GetAllBuiltinTools() []string { return []string{ ToolRecordVulnerability, + ToolListVulnerabilities, + ToolGetVulnerability, + ToolUpsertProjectFact, + ToolGetProjectFact, + ToolListProjectFacts, + ToolSearchProjectFacts, + ToolDeprecateProjectFact, ToolListKnowledgeRiskTypes, ToolSearchKnowledgeBase, ToolWebshellExec,