Files
CyberStrikeAI/internal/app/app.go
2026-02-28 00:52:17 +08:00

1212 lines
42 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"context"
"database/sql"
"fmt"
"net/http"
"os"
"path/filepath"
"sync"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/handler"
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/robot"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/skills"
"cyberstrike-ai/internal/storage"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// App 应用
type App struct {
config *config.Config
logger *logger.Logger
router *gin.Engine
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager
agent *agent.Agent
executor *security.Executor
db *database.DB
knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库)
auth *security.AuthManager
knowledgeManager *knowledge.Manager // 知识库管理器(用于动态初始化)
knowledgeRetriever *knowledge.Retriever // 知识库检索器(用于动态初始化)
knowledgeIndexer *knowledge.Indexer // 知识库索引器(用于动态初始化)
knowledgeHandler *handler.KnowledgeHandler // 知识库处理器(用于动态初始化)
agentHandler *handler.AgentHandler // Agent处理器用于更新知识库管理器
robotHandler *handler.RobotHandler // 机器人处理器(钉钉/飞书/企业微信)
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
}
// New 创建新应用
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
gin.SetMode(gin.ReleaseMode)
router := gin.Default()
// CORS中间件
router.Use(corsMiddleware())
// 认证管理器
authManager, err := security.NewAuthManager(cfg.Auth.Password, cfg.Auth.SessionDurationHours)
if err != nil {
return nil, fmt.Errorf("初始化认证失败: %w", err)
}
// 初始化数据库
dbPath := cfg.Database.Path
if dbPath == "" {
dbPath = "data/conversations.db"
}
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
}
db, err := database.NewDB(dbPath, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化数据库失败: %w", err)
}
// 创建MCP服务器带数据库持久化
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
// 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
// 注册工具
executor.RegisterTools(mcpServer)
// 注册漏洞记录工具
registerVulnerabilityTool(mcpServer, db, log.Logger)
if cfg.Auth.GeneratedPassword != "" {
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
cfg.Auth.GeneratedPassword = ""
cfg.Auth.GeneratedPasswordPersisted = false
cfg.Auth.GeneratedPasswordPersistErr = ""
}
// 创建外部MCP管理器使用与内部MCP服务器相同的存储
externalMCPMgr := mcp.NewExternalMCPManagerWithStorage(log.Logger, db)
if cfg.ExternalMCP.Servers != nil {
externalMCPMgr.LoadConfigs(&cfg.ExternalMCP)
// 启动所有启用的外部MCP客户端
externalMCPMgr.StartAllEnabled()
}
// 初始化结果存储
resultStorageDir := "tmp"
if cfg.Agent.ResultStorageDir != "" {
resultStorageDir = cfg.Agent.ResultStorageDir
}
// 确保存储目录存在
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
return nil, fmt.Errorf("创建结果存储目录失败: %w", err)
}
// 创建结果存储实例
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化结果存储失败: %w", err)
}
// 创建Agent
maxIterations := cfg.Agent.MaxIterations
if maxIterations <= 0 {
maxIterations = 30 // 默认值
}
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
// 设置结果存储到Agent
agent.SetResultStorage(resultStorage)
// 设置结果存储到Executor用于查询工具
executor.SetResultStorage(resultStorage)
// 初始化知识库模块(如果启用)
var knowledgeManager *knowledge.Manager
var knowledgeRetriever *knowledge.Retriever
var knowledgeIndexer *knowledge.Indexer
var knowledgeHandler *handler.KnowledgeHandler
var knowledgeDBConn *database.DB
log.Logger.Info("检查知识库配置", zap.Bool("enabled", cfg.Knowledge.Enabled))
if cfg.Knowledge.Enabled {
// 确定知识库数据库路径
knowledgeDBPath := cfg.Database.KnowledgeDBPath
var knowledgeDB *sql.DB
if knowledgeDBPath != "" {
// 使用独立的知识库数据库
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil {
return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err)
}
var err error
knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, log.Logger)
if err != nil {
return nil, fmt.Errorf("初始化知识库数据库失败: %w", err)
}
knowledgeDB = knowledgeDBConn.DB
log.Logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath))
} else {
// 向后兼容:使用会话数据库
knowledgeDB = db.DB
log.Logger.Info("使用会话数据库存储知识库数据建议配置knowledge_db_path以分离数据")
}
// 创建知识库管理器
knowledgeManager = knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, log.Logger)
// 创建嵌入器
// 使用OpenAI配置的API Key如果知识库配置中没有指定
if cfg.Knowledge.Embedding.APIKey == "" {
cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey
}
if cfg.Knowledge.Embedding.BaseURL == "" {
cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
}
httpClient := &http.Client{
Timeout: 30 * time.Minute,
}
openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, log.Logger)
embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, log.Logger)
// 创建检索器
retrievalConfig := &knowledge.RetrievalConfig{
TopK: cfg.Knowledge.Retrieval.TopK,
SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: cfg.Knowledge.Retrieval.HybridWeight,
}
knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger)
// 创建索引器
knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)
// 创建知识库API处理器
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 扫描知识库并建立索引(异步)
go func() {
itemsToIndex, err := knowledgeManager.ScanKnowledgeBase()
if err != nil {
log.Logger.Warn("扫描知识库失败", zap.Error(err))
return
}
// 检查是否已有索引
hasIndex, err := knowledgeIndexer.HasIndex()
if err != nil {
log.Logger.Warn("检查索引状态失败", zap.Error(err))
return
}
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
log.Logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
log.Logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次立即停止增量索引
if consecutiveFailures >= 2 {
log.Logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
}
log.Logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
log.Logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
return
}
// 只有在没有索引时才自动重建
log.Logger.Info("未检测到知识库索引,开始自动构建索引")
ctx := context.Background()
if err := knowledgeIndexer.RebuildIndex(ctx); err != nil {
log.Logger.Warn("重建知识库索引失败", zap.Error(err))
}
}()
}
// 获取配置文件路径
configPath := "config.yaml"
if len(os.Args) > 1 {
configPath = os.Args[1]
}
// 初始化Skills管理器
skillsDir := cfg.SkillsDir
if skillsDir == "" {
skillsDir = "skills" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
configDir := filepath.Dir(configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
skillsManager := skills.NewManager(skillsDir, log.Logger)
log.Logger.Info("Skills管理器已初始化", zap.String("skillsDir", skillsDir))
// 注册Skills工具到MCP服务器让AI可以按需调用带数据库存储支持统计
// 创建一个适配器将database.DB适配为SkillStatsStorage接口
var skillStatsStorage skills.SkillStatsStorage
if db != nil {
skillStatsStorage = &skillStatsDBAdapter{db: db}
}
skills.RegisterSkillsToolWithStorage(mcpServer, skillsManager, skillStatsStorage, log.Logger)
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
agentHandler.SetSkillsManager(skillsManager) // 设置Skills管理器
// 如果知识库已启用设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager)
}
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器以便获取外部MCP执行记录
groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
roleHandler.SetSkillsManager(skillsManager) // 设置Skills管理器到RoleHandler
skillsHandler := handler.NewSkillsHandler(skillsManager, cfg, configPath, log.Logger)
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
if db != nil {
skillsHandler.SetDB(db) // 设置数据库连接以便获取调用统计
}
// 创建OpenAPI处理器
conversationHandler := handler.NewConversationHandler(db, log.Logger)
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
// 创建 App 实例(部分字段稍后填充)
app := &App{
config: cfg,
logger: log,
router: router,
mcpServer: mcpServer,
externalMCPMgr: externalMCPMgr,
agent: agent,
executor: executor,
db: db,
knowledgeDB: knowledgeDBConn,
auth: authManager,
knowledgeManager: knowledgeManager,
knowledgeRetriever: knowledgeRetriever,
knowledgeIndexer: knowledgeIndexer,
knowledgeHandler: knowledgeHandler,
agentHandler: agentHandler,
robotHandler: robotHandler,
}
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
app.startRobotConnections()
// 设置漏洞工具注册器(内置工具,必须设置)
vulnerabilityRegistrar := func() error {
registerVulnerabilityTool(mcpServer, db, log.Logger)
return nil
}
configHandler.SetVulnerabilityToolRegistrar(vulnerabilityRegistrar)
// 设置Skills工具注册器内置工具必须设置
skillsRegistrar := func() error {
// 创建一个适配器将database.DB适配为SkillStatsStorage接口
var skillStatsStorage skills.SkillStatsStorage
if db != nil {
skillStatsStorage = &skillStatsDBAdapter{db: db}
}
skills.RegisterSkillsToolWithStorage(mcpServer, skillsManager, skillStatsStorage, log.Logger)
return nil
}
configHandler.SetSkillsToolRegistrar(skillsRegistrar)
// 设置知识库初始化器(用于动态初始化,需要在 App 创建后设置)
configHandler.SetKnowledgeInitializer(func() (*handler.KnowledgeHandler, error) {
knowledgeHandler, err := initializeKnowledge(cfg, db, knowledgeDBConn, mcpServer, agentHandler, app, log.Logger)
if err != nil {
return nil, err
}
// 动态初始化后,设置知识库工具注册器和检索器更新器
// 这样后续 ApplyConfig 时就能重新注册工具了
if app.knowledgeRetriever != nil && app.knowledgeManager != nil {
// 创建闭包捕获knowledgeRetriever和knowledgeManager的引用
registrar := func() error {
knowledge.RegisterKnowledgeTool(mcpServer, app.knowledgeRetriever, app.knowledgeManager, log.Logger)
return nil
}
configHandler.SetKnowledgeToolRegistrar(registrar)
// 设置检索器更新器以便在ApplyConfig时更新检索器配置
configHandler.SetRetrieverUpdater(app.knowledgeRetriever)
log.Logger.Info("动态初始化后已设置知识库工具注册器和检索器更新器")
}
return knowledgeHandler, nil
})
// 如果知识库已启用,设置知识库工具注册器和检索器更新器
if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil {
// 创建闭包捕获knowledgeRetriever和knowledgeManager的引用
registrar := func() error {
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)
return nil
}
configHandler.SetKnowledgeToolRegistrar(registrar)
// 设置检索器更新器以便在ApplyConfig时更新检索器配置
configHandler.SetRetrieverUpdater(knowledgeRetriever)
}
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效
configHandler.SetRobotRestarter(app)
// 设置路由(使用 App 实例以便动态获取 handler
setupRoutes(
router,
authHandler,
agentHandler,
monitorHandler,
conversationHandler,
robotHandler,
groupHandler,
configHandler,
externalMCPHandler,
attackChainHandler,
app, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler,
roleHandler,
skillsHandler,
fofaHandler,
mcpServer,
authManager,
openAPIHandler,
)
return app, nil
}
// Run 启动应用
func (a *App) Run() error {
// 启动MCP服务器如果启用
if a.config.MCP.Enabled {
go func() {
mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
mux := http.NewServeMux()
mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP)
if err := http.ListenAndServe(mcpAddr, mux); err != nil {
a.logger.Error("MCP服务器启动失败", zap.Error(err))
}
}()
}
// 启动主服务器
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
return a.router.Run(addr)
}
// Shutdown 关闭应用
func (a *App) Shutdown() {
// 停止钉钉/飞书长连接
a.robotMu.Lock()
if a.dingCancel != nil {
a.dingCancel()
a.dingCancel = nil
}
if a.larkCancel != nil {
a.larkCancel()
a.larkCancel = nil
}
a.robotMu.Unlock()
// 停止所有外部MCP客户端
if a.externalMCPMgr != nil {
a.externalMCPMgr.StopAll()
}
// 关闭知识库数据库连接(如果使用独立数据库)
if a.knowledgeDB != nil {
if err := a.knowledgeDB.Close(); err != nil {
a.logger.Logger.Warn("关闭知识库数据库连接失败", zap.Error(err))
}
}
}
// startRobotConnections 根据当前配置启动钉钉/飞书长连接(不先关闭已有连接,仅用于首次启动)
func (a *App) startRobotConnections() {
a.robotMu.Lock()
defer a.robotMu.Unlock()
cfg := a.config
if cfg.Robots.Lark.Enabled && cfg.Robots.Lark.AppID != "" && cfg.Robots.Lark.AppSecret != "" {
ctx, cancel := context.WithCancel(context.Background())
a.larkCancel = cancel
go robot.StartLark(ctx, cfg.Robots.Lark, a.robotHandler, a.logger.Logger)
}
if cfg.Robots.Dingtalk.Enabled && cfg.Robots.Dingtalk.ClientID != "" && cfg.Robots.Dingtalk.ClientSecret != "" {
ctx, cancel := context.WithCancel(context.Background())
a.dingCancel = cancel
go robot.StartDing(ctx, cfg.Robots.Dingtalk, a.robotHandler, a.logger.Logger)
}
}
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter
func (a *App) RestartRobotConnections() {
a.robotMu.Lock()
if a.dingCancel != nil {
a.dingCancel()
a.dingCancel = nil
}
if a.larkCancel != nil {
a.larkCancel()
a.larkCancel = nil
}
a.robotMu.Unlock()
// 给旧 goroutine 一点时间退出
time.Sleep(200 * time.Millisecond)
a.startRobotConnections()
}
// setupRoutes 设置路由
func setupRoutes(
router *gin.Engine,
authHandler *handler.AuthHandler,
agentHandler *handler.AgentHandler,
monitorHandler *handler.MonitorHandler,
conversationHandler *handler.ConversationHandler,
robotHandler *handler.RobotHandler,
groupHandler *handler.GroupHandler,
configHandler *handler.ConfigHandler,
externalMCPHandler *handler.ExternalMCPHandler,
attackChainHandler *handler.AttackChainHandler,
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler *handler.VulnerabilityHandler,
roleHandler *handler.RoleHandler,
skillsHandler *handler.SkillsHandler,
fofaHandler *handler.FofaHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
openAPIHandler *handler.OpenAPIHandler,
) {
// API路由
api := router.Group("/api")
// 认证相关路由
authRoutes := api.Group("/auth")
{
authRoutes.POST("/login", authHandler.Login)
authRoutes.POST("/logout", security.AuthMiddleware(authManager), authHandler.Logout)
authRoutes.POST("/change-password", security.AuthMiddleware(authManager), authHandler.ChangePassword)
authRoutes.GET("/validate", security.AuthMiddleware(authManager), authHandler.Validate)
}
// 机器人回调(无需登录,供企业微信/钉钉/飞书服务器调用)
api.GET("/robot/wecom", robotHandler.HandleWecomGET)
api.POST("/robot/wecom", robotHandler.HandleWecomPOST)
api.POST("/robot/dingtalk", robotHandler.HandleDingtalkPOST)
api.POST("/robot/lark", robotHandler.HandleLarkPOST)
protected := api.Group("")
protected.Use(security.AuthMiddleware(authManager))
{
// 机器人测试需登录POST /api/robot/testbody: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
protected.POST("/robot/test", robotHandler.HandleRobotTest)
// Agent Loop
protected.POST("/agent-loop", agentHandler.AgentLoop)
// Agent Loop 流式输出
protected.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
// Agent Loop 取消与任务列表
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
protected.GET("/agent-loop/tasks/completed", agentHandler.ListCompletedTasks)
// 信息收集 - FOFA 查询(后端代理)
protected.POST("/fofa/search", fofaHandler.Search)
// 信息收集 - 自然语言解析为 FOFA 语法(需人工确认后再查询)
protected.POST("/fofa/parse", fofaHandler.ParseNaturalLanguage)
// 批量任务管理
protected.POST("/batch-tasks", agentHandler.CreateBatchQueue)
protected.GET("/batch-tasks", agentHandler.ListBatchQueues)
protected.GET("/batch-tasks/:queueId", agentHandler.GetBatchQueue)
protected.POST("/batch-tasks/:queueId/start", agentHandler.StartBatchQueue)
protected.POST("/batch-tasks/:queueId/pause", agentHandler.PauseBatchQueue)
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
// 对话历史
protected.POST("/conversations", conversationHandler.CreateConversation)
protected.GET("/conversations", conversationHandler.ListConversations)
protected.GET("/conversations/:id", conversationHandler.GetConversation)
protected.PUT("/conversations/:id", conversationHandler.UpdateConversation)
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)
// 对话分组
protected.POST("/groups", groupHandler.CreateGroup)
protected.GET("/groups", groupHandler.ListGroups)
protected.GET("/groups/:id", groupHandler.GetGroup)
protected.PUT("/groups/:id", groupHandler.UpdateGroup)
protected.DELETE("/groups/:id", groupHandler.DeleteGroup)
protected.PUT("/groups/:id/pinned", groupHandler.UpdateGroupPinned)
protected.GET("/groups/:id/conversations", groupHandler.GetGroupConversations)
protected.POST("/groups/conversations", groupHandler.AddConversationToGroup)
protected.DELETE("/groups/:id/conversations/:conversationId", groupHandler.RemoveConversationFromGroup)
protected.PUT("/groups/:id/conversations/:conversationId/pinned", groupHandler.UpdateConversationPinnedInGroup)
// 监控
protected.GET("/monitor", monitorHandler.Monitor)
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
protected.GET("/monitor/stats", monitorHandler.GetStats)
// 配置管理
protected.GET("/config", configHandler.GetConfig)
protected.GET("/config/tools", configHandler.GetTools)
protected.PUT("/config", configHandler.UpdateConfig)
protected.POST("/config/apply", configHandler.ApplyConfig)
// 外部MCP管理
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
protected.GET("/external-mcp/:name", externalMCPHandler.GetExternalMCP)
protected.PUT("/external-mcp/:name", externalMCPHandler.AddOrUpdateExternalMCP)
protected.DELETE("/external-mcp/:name", externalMCPHandler.DeleteExternalMCP)
protected.POST("/external-mcp/:name/start", externalMCPHandler.StartExternalMCP)
protected.POST("/external-mcp/:name/stop", externalMCPHandler.StopExternalMCP)
// 攻击链可视化
protected.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain)
protected.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain)
// 知识库管理(始终注册路由,通过 App 实例动态获取 handler
knowledgeRoutes := protected.Group("/knowledge")
{
knowledgeRoutes.GET("/categories", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"categories": []string{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetCategories(c)
})
knowledgeRoutes.GET("/items", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"items": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetItems(c)
})
knowledgeRoutes.GET("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetItem(c)
})
knowledgeRoutes.POST("/items", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.CreateItem(c)
})
knowledgeRoutes.PUT("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.UpdateItem(c)
})
knowledgeRoutes.DELETE("/items/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.DeleteItem(c)
})
knowledgeRoutes.GET("/index-status", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"total_items": 0,
"indexed_items": 0,
"progress_percent": 0,
"is_complete": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetIndexStatus(c)
})
knowledgeRoutes.POST("/index", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.RebuildIndex(c)
})
knowledgeRoutes.POST("/scan", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.ScanKnowledgeBase(c)
})
knowledgeRoutes.GET("/retrieval-logs", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"logs": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetRetrievalLogs(c)
})
knowledgeRoutes.DELETE("/retrieval-logs/:id", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"error": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.DeleteRetrievalLog(c)
})
knowledgeRoutes.POST("/search", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"results": []interface{}{},
"enabled": false,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.Search(c)
})
knowledgeRoutes.GET("/stats", func(c *gin.Context) {
if app.knowledgeHandler == nil {
c.JSON(http.StatusOK, gin.H{
"enabled": false,
"total_categories": 0,
"total_items": 0,
"message": "知识库功能未启用,请前往系统设置启用知识检索功能",
})
return
}
app.knowledgeHandler.GetStats(c)
})
}
// 漏洞管理
protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability)
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
// 角色管理
protected.GET("/roles", roleHandler.GetRoles)
protected.GET("/roles/:name", roleHandler.GetRole)
protected.GET("/roles/skills/list", roleHandler.GetSkills)
protected.POST("/roles", roleHandler.CreateRole)
protected.PUT("/roles/:name", roleHandler.UpdateRole)
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
// Skills管理
protected.GET("/skills", skillsHandler.GetSkills)
protected.GET("/skills/stats", skillsHandler.GetSkillStats)
protected.DELETE("/skills/stats", skillsHandler.ClearSkillStats)
protected.GET("/skills/:name", skillsHandler.GetSkill)
protected.GET("/skills/:name/bound-roles", skillsHandler.GetSkillBoundRoles)
protected.POST("/skills", skillsHandler.CreateSkill)
protected.PUT("/skills/:name", skillsHandler.UpdateSkill)
protected.DELETE("/skills/:name", skillsHandler.DeleteSkill)
protected.DELETE("/skills/:name/stats", skillsHandler.ClearSkillStatsByName)
// MCP端点
protected.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request)
})
// OpenAPI结果聚合端点可选用于获取对话的完整结果
protected.GET("/conversations/:id/results", openAPIHandler.GetConversationResults)
}
// OpenAPI规范需要认证避免暴露API结构信息
protected.GET("/openapi/spec", openAPIHandler.GetOpenAPISpec)
// API文档页面公开访问但需要登录后才能使用API
router.GET("/api-docs", func(c *gin.Context) {
c.HTML(http.StatusOK, "api-docs.html", nil)
})
// 静态文件
router.Static("/static", "./web/static")
router.LoadHTMLGlob("web/templates/*")
// 前端页面
router.GET("/", func(c *gin.Context) {
version := app.config.Version
if version == "" {
version = "v1.0.0"
}
c.HTML(http.StatusOK, "index.html", gin.H{"Version": version})
})
}
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: builtin.ToolRecordVulnerability,
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"title": map[string]interface{}{
"type": "string",
"description": "漏洞标题(必需)",
},
"description": map[string]interface{}{
"type": "string",
"description": "漏洞详细描述",
},
"severity": map[string]interface{}{
"type": "string",
"description": "漏洞严重程度critical严重、high、medium、low、info信息",
"enum": []string{"critical", "high", "medium", "low", "info"},
},
"vulnerability_type": map[string]interface{}{
"type": "string",
"description": "漏洞类型SQL注入、XSS、CSRF、命令注入等",
},
"target": map[string]interface{}{
"type": "string",
"description": "受影响的目标URL、IP地址、服务等",
},
"proof": map[string]interface{}{
"type": "string",
"description": "漏洞证明POC、截图、请求/响应等)",
},
"impact": map[string]interface{}{
"type": "string",
"description": "漏洞影响说明",
},
"recommendation": map[string]interface{}{
"type": "string",
"description": "修复建议",
},
},
"required": []string{"title", "severity"},
},
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
// 从参数中获取conversation_id由Agent自动添加
conversationID, _ := args["conversation_id"].(string)
if conversationID == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: conversation_id 未设置。这是系统错误,请重试。",
},
},
IsError: true,
}, nil
}
title, ok := args["title"].(string)
if !ok || title == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: title 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
severity, ok := args["severity"].(string)
if !ok || severity == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: severity 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
// 验证严重程度
validSeverities := map[string]bool{
"critical": true,
"high": true,
"medium": true,
"low": true,
"info": true,
}
if !validSeverities[severity] {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity),
},
},
IsError: true,
}, nil
}
// 获取可选参数
description := ""
if d, ok := args["description"].(string); ok {
description = d
}
vulnType := ""
if t, ok := args["vulnerability_type"].(string); ok {
vulnType = t
}
target := ""
if t, ok := args["target"].(string); ok {
target = t
}
proof := ""
if p, ok := args["proof"].(string); ok {
proof = p
}
impact := ""
if i, ok := args["impact"].(string); ok {
impact = i
}
recommendation := ""
if r, ok := args["recommendation"].(string); ok {
recommendation = r
}
// 创建漏洞记录
vuln := &database.Vulnerability{
ConversationID: conversationID,
Title: title,
Description: description,
Severity: severity,
Status: "open",
Type: vulnType,
Target: target,
Proof: proof,
Impact: impact,
Recommendation: recommendation,
}
created, err := db.CreateVulnerability(vuln)
if err != nil {
logger.Error("记录漏洞失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("记录漏洞失败: %v", err),
},
},
IsError: true,
}, nil
}
logger.Info("漏洞记录成功",
zap.String("id", created.ID),
zap.String("title", created.Title),
zap.String("severity", created.Severity),
zap.String("conversation_id", conversationID),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status),
},
},
IsError: false,
}, nil
}
mcpServer.RegisterTool(tool, handler)
logger.Info("漏洞记录工具注册成功")
}
// initializeKnowledge 初始化知识库组件(用于动态初始化)
func initializeKnowledge(
cfg *config.Config,
db *database.DB,
knowledgeDBConn *database.DB,
mcpServer *mcp.Server,
agentHandler *handler.AgentHandler,
app *App, // 传递 App 引用以便更新知识库组件
logger *zap.Logger,
) (*handler.KnowledgeHandler, error) {
// 确定知识库数据库路径
knowledgeDBPath := cfg.Database.KnowledgeDBPath
var knowledgeDB *sql.DB
if knowledgeDBPath != "" {
// 使用独立的知识库数据库
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil {
return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err)
}
var err error
knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, logger)
if err != nil {
return nil, fmt.Errorf("初始化知识库数据库失败: %w", err)
}
knowledgeDB = knowledgeDBConn.DB
logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath))
} else {
// 向后兼容:使用会话数据库
knowledgeDB = db.DB
logger.Info("使用会话数据库存储知识库数据建议配置knowledge_db_path以分离数据")
}
// 创建知识库管理器
knowledgeManager := knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, logger)
// 创建嵌入器
// 使用OpenAI配置的API Key如果知识库配置中没有指定
if cfg.Knowledge.Embedding.APIKey == "" {
cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey
}
if cfg.Knowledge.Embedding.BaseURL == "" {
cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
}
httpClient := &http.Client{
Timeout: 30 * time.Minute,
}
openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, logger)
embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, logger)
// 创建检索器
retrievalConfig := &knowledge.RetrievalConfig{
TopK: cfg.Knowledge.Retrieval.TopK,
SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
HybridWeight: cfg.Knowledge.Retrieval.HybridWeight,
}
knowledgeRetriever := knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, logger)
// 创建索引器
knowledgeIndexer := knowledge.NewIndexer(knowledgeDB, embedder, logger)
// 注册知识检索工具到MCP服务器
knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, logger)
// 创建知识库API处理器
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 设置知识库管理器到AgentHandler以便记录检索日志
agentHandler.SetKnowledgeManager(knowledgeManager)
// 更新 App 中的知识库组件(如果 App 不为 nil说明是动态初始化
if app != nil {
app.knowledgeManager = knowledgeManager
app.knowledgeRetriever = knowledgeRetriever
app.knowledgeIndexer = knowledgeIndexer
app.knowledgeHandler = knowledgeHandler
// 如果使用独立数据库,更新 knowledgeDB
if knowledgeDBPath != "" {
app.knowledgeDB = knowledgeDBConn
}
logger.Info("App 中的知识库组件已更新")
}
// 扫描知识库并建立索引(异步)
go func() {
itemsToIndex, err := knowledgeManager.ScanKnowledgeBase()
if err != nil {
logger.Warn("扫描知识库失败", zap.Error(err))
return
}
// 检查是否已有索引
hasIndex, err := knowledgeIndexer.HasIndex()
if err != nil {
logger.Warn("检查索引状态失败", zap.Error(err))
return
}
if hasIndex {
// 如果已有索引,只索引新添加或更新的项
if len(itemsToIndex) > 0 {
logger.Info("检测到已有知识库索引,开始增量索引", zap.Int("count", len(itemsToIndex)))
ctx := context.Background()
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
failedCount := 0
for _, itemID := range itemsToIndex {
if err := knowledgeIndexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
logger.Warn("索引知识项失败", zap.String("itemId", itemID), zap.Error(err))
}
// 如果连续失败2次立即停止增量索引
if consecutiveFailures >= 2 {
logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
}
logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
} else {
logger.Info("检测到已有知识库索引,没有需要索引的新项或更新项")
}
return
}
// 只有在没有索引时才自动重建
logger.Info("未检测到知识库索引,开始自动构建索引")
ctx := context.Background()
if err := knowledgeIndexer.RebuildIndex(ctx); err != nil {
logger.Warn("重建知识库索引失败", zap.Error(err))
}
}()
return knowledgeHandler, nil
}
// corsMiddleware CORS中间件
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}