diff --git a/internal/app/app.go b/internal/app/app.go index b041a838..811a608c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -52,9 +52,10 @@ type App struct { robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启 larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启 - c2Manager *c2.Manager // C2 管理器 + c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil) c2Watchdog *c2.SessionWatchdog // C2 会话看门狗 c2WatchdogCancel context.CancelFunc // 看门狗取消函数 + c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步) } // New 创建新应用 @@ -343,50 +344,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { } // ============================================================================ - // 初始化 C2 模块 + // 初始化 C2 模块(可按配置关闭,节省本机部署资源) // ============================================================================ - c2Manager := c2.NewManager(db, log.Logger, "tmp/c2") - // 注册 Listener 工厂 - c2Manager.Registry().Register(string(c2.ListenerTypeTCPReverse), c2.NewTCPReverseListener) - c2Manager.Registry().Register(string(c2.ListenerTypeHTTPBeacon), c2.NewHTTPBeaconListener) - c2Manager.Registry().Register(string(c2.ListenerTypeHTTPSBeacon), c2.NewHTTPSBeaconListener) - c2Manager.Registry().Register(string(c2.ListenerTypeWebSocket), c2.NewWebSocketListener) - // 设置 HITL 桥(仅当会话开启人机协同且 c2_task 不在免审批白名单时,危险任务才走桥) - c2HITLBridge := NewC2HITLBridge(db, log.Logger) - c2Manager.SetHITLBridge(c2HITLBridge) - c2Manager.SetHITLDangerousGate(func(conversationID, toolName string) bool { - return agentHandler.HITLNeedsToolApproval(conversationID, toolName) - }) - // 设置业务钩子 - c2Hooks := SetupC2Hooks(&C2HooksConfig{ - DB: db, - Logger: log.Logger, - AttackChainRecord: func(session *database.C2Session, phase string, description string) { - // 通过攻击链处理器记录(简化版,实际需要完整实现) - log.Logger.Info("C2 Attack Chain", - zap.String("session_id", session.ID), - zap.String("phase", phase), - zap.String("desc", description), - ) - }, - VulnRecord: func(session *database.C2Session, title string, severity string) { - // 记录漏洞(简化版) - log.Logger.Info("C2 Vulnerability", - zap.String("session_id", session.ID), - zap.String("title", title), - zap.String("severity", severity), - ) - }, - }) - c2Manager.SetHooks(c2Hooks) - // 恢复运行中的监听器 - c2Manager.RestoreRunningListeners() - // 启动会话看门狗 - c2Watchdog := c2.NewSessionWatchdog(c2Manager) - watchdogCtx, watchdogCancel := context.WithCancel(context.Background()) - go c2Watchdog.Run(watchdogCtx) - // 注册 C2 MCP 工具 - registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port) + c2Manager, c2Watchdog, watchdogCancel := setupC2Runtime(cfg, db, agentHandler, log.Logger) + if c2Manager != nil { + registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port) + } + c2Handler := handler.NewC2Handler(c2Manager, log.Logger) // 创建OpenAPI处理器 conversationHandler := handler.NewConversationHandler(db, log.Logger) @@ -414,6 +378,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { c2Manager: c2Manager, c2Watchdog: c2Watchdog, c2WatchdogCancel: watchdogCancel, + c2Handler: c2Handler, } // 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启 app.startRobotConnections() @@ -482,8 +447,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { // 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效 configHandler.SetRobotRestarter(app) - // 创建 C2 Handler - c2Handler := handler.NewC2Handler(c2Manager, log.Logger) + configHandler.SetC2Runtime(app) + configHandler.SetC2ToolRegistrar(func() error { + if app.config.C2.EnabledEffective() && app.c2Manager != nil { + registerC2Tools(mcpServer, app.c2Manager, log.Logger, app.config.Server.Port) + } + return nil + }) // 设置路由(使用 App 实例以便动态获取 handler) setupRoutes( @@ -507,7 +477,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { markdownAgentsHandler, fofaHandler, terminalHandler, - c2Handler, + app.c2Handler, mcpServer, authManager, openAPIHandler, @@ -599,14 +569,7 @@ func (a *App) Shutdown() { } a.robotMu.Unlock() - // 停止 C2 看门狗 - if a.c2WatchdogCancel != nil { - a.c2WatchdogCancel() - } - // 关闭 C2 Manager(停止所有监听器) - if a.c2Manager != nil { - a.c2Manager.Close() - } + a.shutdownC2() // 停止所有外部MCP客户端 if a.externalMCPMgr != nil { @@ -994,46 +957,51 @@ func setupRoutes( protected.POST("/webshell/exec", webshellHandler.Exec) protected.POST("/webshell/file", webshellHandler.FileOp) - // C2 管理(AI-Native 轻量级 C2 框架) - // 监听器 - protected.GET("/c2/listeners", c2Handler.ListListeners) - protected.POST("/c2/listeners", c2Handler.CreateListener) - protected.GET("/c2/listeners/:id", c2Handler.GetListener) - protected.PUT("/c2/listeners/:id", c2Handler.UpdateListener) - protected.DELETE("/c2/listeners/:id", c2Handler.DeleteListener) - protected.POST("/c2/listeners/:id/start", c2Handler.StartListener) - protected.POST("/c2/listeners/:id/stop", c2Handler.StopListener) - // 会话 - protected.GET("/c2/sessions", c2Handler.ListSessions) - protected.GET("/c2/sessions/:id", c2Handler.GetSession) - protected.DELETE("/c2/sessions/:id", c2Handler.DeleteSession) - protected.PUT("/c2/sessions/:id/sleep", c2Handler.SetSessionSleep) - // 任务 - protected.GET("/c2/tasks", c2Handler.ListTasks) - protected.DELETE("/c2/tasks", c2Handler.DeleteTasks) - protected.GET("/c2/tasks/:id", c2Handler.GetTask) - protected.POST("/c2/tasks", c2Handler.CreateTask) - protected.POST("/c2/tasks/:id/cancel", c2Handler.CancelTask) - protected.GET("/c2/tasks/:id/wait", c2Handler.WaitTask) - protected.POST("/c2/sessions/:id/tasks", c2Handler.CreateTask) // 快捷方式:直接对会话下发任务 - // Payload - protected.POST("/c2/payloads/oneliner", c2Handler.PayloadOneliner) - protected.POST("/c2/payloads/build", c2Handler.PayloadBuild) - protected.GET("/c2/payloads/:id/download", c2Handler.PayloadDownload) - // 事件 & SSE - protected.GET("/c2/events", c2Handler.ListEvents) - protected.DELETE("/c2/events", c2Handler.DeleteEvents) - protected.GET("/c2/events/stream", c2Handler.EventStream) - // 文件管理 - protected.POST("/c2/files/upload", c2Handler.UploadFileForImplant) - protected.GET("/c2/files", c2Handler.ListFiles) - protected.GET("/c2/tasks/:id/result-file", c2Handler.DownloadResultFile) - // Malleable Profile - protected.GET("/c2/profiles", c2Handler.ListProfiles) - protected.GET("/c2/profiles/:id", c2Handler.GetProfile) - protected.POST("/c2/profiles", c2Handler.CreateProfile) - protected.PUT("/c2/profiles/:id", c2Handler.UpdateProfile) - protected.DELETE("/c2/profiles/:id", c2Handler.DeleteProfile) + // C2 管理(未启用时返回 503,避免 Handler 空指针) + c2Routes := protected.Group("/c2") + c2Routes.Use(func(c *gin.Context) { + if app.c2Manager == nil { + c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{ + "error": "c2_disabled", + "message": "C2 功能已在系统设置中关闭", + "enabled": false, + }) + return + } + c.Next() + }) + c2Routes.GET("/listeners", c2Handler.ListListeners) + c2Routes.POST("/listeners", c2Handler.CreateListener) + c2Routes.GET("/listeners/:id", c2Handler.GetListener) + c2Routes.PUT("/listeners/:id", c2Handler.UpdateListener) + c2Routes.DELETE("/listeners/:id", c2Handler.DeleteListener) + c2Routes.POST("/listeners/:id/start", c2Handler.StartListener) + c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener) + c2Routes.GET("/sessions", c2Handler.ListSessions) + c2Routes.GET("/sessions/:id", c2Handler.GetSession) + c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession) + c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep) + c2Routes.GET("/tasks", c2Handler.ListTasks) + c2Routes.DELETE("/tasks", c2Handler.DeleteTasks) + c2Routes.GET("/tasks/:id", c2Handler.GetTask) + c2Routes.POST("/tasks", c2Handler.CreateTask) + c2Routes.POST("/tasks/:id/cancel", c2Handler.CancelTask) + c2Routes.GET("/tasks/:id/wait", c2Handler.WaitTask) + c2Routes.POST("/sessions/:id/tasks", c2Handler.CreateTask) + c2Routes.POST("/payloads/oneliner", c2Handler.PayloadOneliner) + c2Routes.POST("/payloads/build", c2Handler.PayloadBuild) + c2Routes.GET("/payloads/:id/download", c2Handler.PayloadDownload) + c2Routes.GET("/events", c2Handler.ListEvents) + c2Routes.DELETE("/events", c2Handler.DeleteEvents) + c2Routes.GET("/events/stream", c2Handler.EventStream) + c2Routes.POST("/files/upload", c2Handler.UploadFileForImplant) + c2Routes.GET("/files", c2Handler.ListFiles) + c2Routes.GET("/tasks/:id/result-file", c2Handler.DownloadResultFile) + c2Routes.GET("/profiles", c2Handler.ListProfiles) + c2Routes.GET("/profiles/:id", c2Handler.GetProfile) + c2Routes.POST("/profiles", c2Handler.CreateProfile) + c2Routes.PUT("/profiles/:id", c2Handler.UpdateProfile) + c2Routes.DELETE("/profiles/:id", c2Handler.DeleteProfile) // 对话附件(chat_uploads)管理 protected.GET("/chat-uploads", chatUploadsHandler.List) diff --git a/internal/app/c2_lifecycle.go b/internal/app/c2_lifecycle.go new file mode 100644 index 00000000..af651c39 --- /dev/null +++ b/internal/app/c2_lifecycle.go @@ -0,0 +1,104 @@ +package app + +import ( + "context" + + "cyberstrike-ai/internal/c2" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/handler" + + "go.uber.org/zap" +) + +// setupC2Runtime 创建 C2 Manager、看门狗与取消函数;不注册 MCP 工具(由 Apply 统一 ClearTools 后注册)。 +func setupC2Runtime( + cfg *config.Config, + db *database.DB, + agentHandler *handler.AgentHandler, + logger *zap.Logger, +) (*c2.Manager, *c2.SessionWatchdog, context.CancelFunc) { + if !cfg.C2.EnabledEffective() { + return nil, nil, nil + } + c2Manager := c2.NewManager(db, logger, "tmp/c2") + c2Manager.Registry().Register(string(c2.ListenerTypeTCPReverse), c2.NewTCPReverseListener) + c2Manager.Registry().Register(string(c2.ListenerTypeHTTPBeacon), c2.NewHTTPBeaconListener) + c2Manager.Registry().Register(string(c2.ListenerTypeHTTPSBeacon), c2.NewHTTPSBeaconListener) + c2Manager.Registry().Register(string(c2.ListenerTypeWebSocket), c2.NewWebSocketListener) + c2HITLBridge := NewC2HITLBridge(db, logger) + c2Manager.SetHITLBridge(c2HITLBridge) + c2Manager.SetHITLDangerousGate(func(conversationID, toolName string) bool { + return agentHandler.HITLNeedsToolApproval(conversationID, toolName) + }) + c2Hooks := SetupC2Hooks(&C2HooksConfig{ + DB: db, + Logger: logger, + AttackChainRecord: func(session *database.C2Session, phase string, description string) { + logger.Info("C2 Attack Chain", + zap.String("session_id", session.ID), + zap.String("phase", phase), + zap.String("desc", description), + ) + }, + VulnRecord: func(session *database.C2Session, title string, severity string) { + logger.Info("C2 Vulnerability", + zap.String("session_id", session.ID), + zap.String("title", title), + zap.String("severity", severity), + ) + }, + }) + c2Manager.SetHooks(c2Hooks) + c2Manager.RestoreRunningListeners() + c2Watchdog := c2.NewSessionWatchdog(c2Manager) + watchdogCtx, watchdogCancel := context.WithCancel(context.Background()) + go c2Watchdog.Run(watchdogCtx) + return c2Manager, c2Watchdog, watchdogCancel +} + +// ReconcileC2AfterConfigApply 根据当前内存配置启停 C2(不写盘;在 Apply 中 ClearTools 之前调用)。 +func (a *App) ReconcileC2AfterConfigApply() error { + if !a.config.C2.EnabledEffective() { + a.shutdownC2() + return nil + } + if a.c2Manager != nil { + return nil + } + if a.db == nil || a.agentHandler == nil { + return nil + } + m, wd, cancel := setupC2Runtime(a.config, a.db, a.agentHandler, a.logger.Logger) + if m == nil { + return nil + } + a.c2Manager = m + a.c2Watchdog = wd + a.c2WatchdogCancel = cancel + if a.c2Handler != nil { + a.c2Handler.SetManager(m) + } + a.logger.Info("C2 子系统已按配置启动") + return nil +} + +// shutdownC2 停止看门狗与所有监听器,并断开 Handler 引用。 +func (a *App) shutdownC2() { + had := a.c2WatchdogCancel != nil || a.c2Manager != nil + if a.c2WatchdogCancel != nil { + a.c2WatchdogCancel() + a.c2WatchdogCancel = nil + } + a.c2Watchdog = nil + if a.c2Manager != nil { + a.c2Manager.Close() + a.c2Manager = nil + } + if a.c2Handler != nil { + a.c2Handler.SetManager(nil) + } + if had { + a.logger.Info("C2 子系统已关闭") + } +} diff --git a/internal/handler/c2.go b/internal/handler/c2.go index a835db1b..22639b50 100644 --- a/internal/handler/c2.go +++ b/internal/handler/c2.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "time" "cyberstrike-ai/internal/c2" @@ -20,18 +21,28 @@ import ( "go.uber.org/zap" ) -// C2Handler 处理 C2 相关的 REST API +// C2Handler 处理 C2 相关的 REST API(manager 可在运行时置 nil 以关闭 C2) type C2Handler struct { - manager *c2.Manager - logger *zap.Logger + mgrPtr atomic.Pointer[c2.Manager] + logger *zap.Logger } -// NewC2Handler 创建 C2 处理器 +// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时) func NewC2Handler(manager *c2.Manager, logger *zap.Logger) *C2Handler { - return &C2Handler{ - manager: manager, - logger: logger, + h := &C2Handler{logger: logger} + if manager != nil { + h.mgrPtr.Store(manager) } + return h +} + +func (h *C2Handler) mgr() *c2.Manager { + return h.mgrPtr.Load() +} + +// SetManager 运行时切换或清空 C2 Manager(与 App 启停同步) +func (h *C2Handler) SetManager(m *c2.Manager) { + h.mgrPtr.Store(m) } // ============================================================================ @@ -40,7 +51,7 @@ func NewC2Handler(manager *c2.Manager, logger *zap.Logger) *C2Handler { // ListListeners 获取监听器列表 func (h *C2Handler) ListListeners(c *gin.Context) { - listeners, err := h.manager.DB().ListC2Listeners() + listeners, err := h.mgr().DB().ListC2Listeners() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -81,7 +92,7 @@ func (h *C2Handler) CreateListener(c *gin.Context) { CallbackHost: strings.TrimSpace(req.CallbackHost), } - listener, err := h.manager.CreateListener(input) + listener, err := h.mgr().CreateListener(input) if err != nil { code := http.StatusInternalServerError if e, ok := err.(*c2.CommonError); ok { @@ -99,7 +110,7 @@ func (h *C2Handler) CreateListener(c *gin.Context) { // GetListener 获取单个监听器 func (h *C2Handler) GetListener(c *gin.Context) { id := c.Param("id") - listener, err := h.manager.DB().GetC2Listener(id) + listener, err := h.mgr().DB().GetC2Listener(id) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -116,7 +127,7 @@ func (h *C2Handler) GetListener(c *gin.Context) { // UpdateListener 更新监听器 func (h *C2Handler) UpdateListener(c *gin.Context) { id := c.Param("id") - listener, err := h.manager.DB().GetC2Listener(id) + listener, err := h.mgr().DB().GetC2Listener(id) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -141,7 +152,7 @@ func (h *C2Handler) UpdateListener(c *gin.Context) { } // 若监听器在运行,不能修改关键字段 - if h.manager.IsListenerRunning(id) { + if h.mgr().IsListenerRunning(id) { if req.BindHost != listener.BindHost || req.BindPort != listener.BindPort { c.JSON(http.StatusConflict, gin.H{"error": "cannot modify bind address while listener is running"}) return @@ -174,7 +185,7 @@ func (h *C2Handler) UpdateListener(c *gin.Context) { listener.ConfigJSON = string(cfgJSON) } - if err := h.manager.DB().UpdateC2Listener(listener); err != nil { + if err := h.mgr().DB().UpdateC2Listener(listener); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -186,7 +197,7 @@ func (h *C2Handler) UpdateListener(c *gin.Context) { // DeleteListener 删除监听器 func (h *C2Handler) DeleteListener(c *gin.Context) { id := c.Param("id") - if err := h.manager.DeleteListener(id); err != nil { + if err := h.mgr().DeleteListener(id); err != nil { code := http.StatusInternalServerError if e, ok := err.(*c2.CommonError); ok { code = e.HTTP @@ -200,7 +211,7 @@ func (h *C2Handler) DeleteListener(c *gin.Context) { // StartListener 启动监听器 func (h *C2Handler) StartListener(c *gin.Context) { id := c.Param("id") - listener, err := h.manager.StartListener(id) + listener, err := h.mgr().StartListener(id) if err != nil { code := http.StatusInternalServerError if e, ok := err.(*c2.CommonError); ok { @@ -217,7 +228,7 @@ func (h *C2Handler) StartListener(c *gin.Context) { // StopListener 停止监听器 func (h *C2Handler) StopListener(c *gin.Context) { id := c.Param("id") - if err := h.manager.StopListener(id); err != nil { + if err := h.mgr().StopListener(id); err != nil { code := http.StatusInternalServerError if e, ok := err.(*c2.CommonError); ok { code = e.HTTP @@ -246,7 +257,7 @@ func (h *C2Handler) ListSessions(c *gin.Context) { } } - sessions, err := h.manager.DB().ListC2Sessions(filter) + sessions, err := h.mgr().DB().ListC2Sessions(filter) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -257,7 +268,7 @@ func (h *C2Handler) ListSessions(c *gin.Context) { // GetSession 获取单个会话 func (h *C2Handler) GetSession(c *gin.Context) { id := c.Param("id") - session, err := h.manager.DB().GetC2Session(id) + session, err := h.mgr().DB().GetC2Session(id) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -268,7 +279,7 @@ func (h *C2Handler) GetSession(c *gin.Context) { } // 获取最近任务 - tasks, _ := h.manager.DB().ListC2Tasks(database.ListC2TasksFilter{ + tasks, _ := h.mgr().DB().ListC2Tasks(database.ListC2TasksFilter{ SessionID: id, Limit: 20, }) @@ -282,7 +293,7 @@ func (h *C2Handler) GetSession(c *gin.Context) { // DeleteSession 删除会话 func (h *C2Handler) DeleteSession(c *gin.Context) { id := c.Param("id") - if err := h.manager.DB().DeleteC2Session(id); err != nil { + if err := h.mgr().DB().DeleteC2Session(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -301,7 +312,7 @@ func (h *C2Handler) SetSessionSleep(c *gin.Context) { return } - if err := h.manager.DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil { + if err := h.mgr().DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -343,14 +354,14 @@ func (h *C2Handler) ListTasks(c *gin.Context) { } } - tasks, err := h.manager.DB().ListC2Tasks(filter) + tasks, err := h.mgr().DB().ListC2Tasks(filter) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 仪表盘「待审任务」为全局 queued/pending 数量,与列表 session 过滤无关 - pendingN, _ := h.manager.DB().CountC2TasksQueuedOrPending("") + pendingN, _ := h.mgr().DB().CountC2TasksQueuedOrPending("") if !paginated { c.JSON(http.StatusOK, gin.H{ @@ -360,7 +371,7 @@ func (h *C2Handler) ListTasks(c *gin.Context) { return } - total, err := h.manager.DB().CountC2Tasks(filter) + total, err := h.mgr().DB().CountC2Tasks(filter) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -387,7 +398,7 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"}) return } - n, err := h.manager.DB().DeleteC2TasksByIDs(req.IDs) + n, err := h.mgr().DB().DeleteC2TasksByIDs(req.IDs) if err != nil { if errors.Is(err, database.ErrNoValidC2TaskIDs) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -402,7 +413,7 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) { // GetTask 获取单个任务 func (h *C2Handler) GetTask(c *gin.Context) { id := c.Param("id") - task, err := h.manager.DB().GetC2Task(id) + task, err := h.mgr().DB().GetC2Task(id) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -437,7 +448,7 @@ func (h *C2Handler) CreateTask(c *gin.Context) { UserCtx: c.Request.Context(), } - task, err := h.manager.EnqueueTask(input) + task, err := h.mgr().EnqueueTask(input) if err != nil { code := http.StatusInternalServerError if e, ok := err.(*c2.CommonError); ok { @@ -452,7 +463,7 @@ func (h *C2Handler) CreateTask(c *gin.Context) { // CancelTask 取消任务 func (h *C2Handler) CancelTask(c *gin.Context) { id := c.Param("id") - if err := h.manager.CancelTask(id); err != nil { + if err := h.mgr().CancelTask(id); err != nil { code := http.StatusInternalServerError if e, ok := err.(*c2.CommonError); ok { code = e.HTTP @@ -475,7 +486,7 @@ func (h *C2Handler) WaitTask(c *gin.Context) { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { - task, err := h.manager.DB().GetC2Task(id) + task, err := h.mgr().DB().GetC2Task(id) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -509,7 +520,7 @@ func (h *C2Handler) PayloadOneliner(c *gin.Context) { return } - listener, err := h.manager.DB().GetC2Listener(req.ListenerID) + listener, err := h.mgr().DB().GetC2Listener(req.ListenerID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -572,7 +583,7 @@ func (h *C2Handler) PayloadBuild(c *gin.Context) { return } - listener, err := h.manager.DB().GetC2Listener(req.ListenerID) + listener, err := h.mgr().DB().GetC2Listener(req.ListenerID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -582,7 +593,7 @@ func (h *C2Handler) PayloadBuild(c *gin.Context) { return } - builder := c2.NewPayloadBuilder(h.manager, h.logger, "", "") + builder := c2.NewPayloadBuilder(h.mgr(), h.logger, "", "") input := c2.PayloadBuilderInput{ ListenerID: req.ListenerID, OS: req.OS, @@ -616,7 +627,7 @@ func (h *C2Handler) PayloadDownload(c *gin.Context) { return } - builder := c2.NewPayloadBuilder(h.manager, h.logger, "", "") + builder := c2.NewPayloadBuilder(h.mgr(), h.logger, "", "") storageDir := builder.GetPayloadStoragePath() targetPath := filepath.Join(storageDir, filename) @@ -676,7 +687,7 @@ func (h *C2Handler) ListEvents(c *gin.Context) { } } - events, err := h.manager.DB().ListC2Events(filter) + events, err := h.mgr().DB().ListC2Events(filter) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -685,7 +696,7 @@ func (h *C2Handler) ListEvents(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"events": events}) return } - total, err := h.manager.DB().CountC2Events(filter) + total, err := h.mgr().DB().CountC2Events(filter) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -711,7 +722,7 @@ func (h *C2Handler) DeleteEvents(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"}) return } - n, err := h.manager.DB().DeleteC2EventsByIDs(req.IDs) + n, err := h.mgr().DB().DeleteC2EventsByIDs(req.IDs) if err != nil { if errors.Is(err, database.ErrNoValidC2EventIDs) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -733,14 +744,14 @@ func (h *C2Handler) EventStream(c *gin.Context) { categoryFilter := c.Query("category") levels := c.QueryArray("level") - sub := h.manager.EventBus().Subscribe( + sub := h.mgr().EventBus().Subscribe( "sse-"+uuid.New().String(), 128, sessionFilter, categoryFilter, levels, ) - defer h.manager.EventBus().Unsubscribe(sub.ID) + defer h.mgr().EventBus().Unsubscribe(sub.ID) c.Stream(func(w io.Writer) bool { select { @@ -763,7 +774,7 @@ func (h *C2Handler) EventStream(c *gin.Context) { // ListProfiles 获取 Malleable Profile 列表 func (h *C2Handler) ListProfiles(c *gin.Context) { - profiles, err := h.manager.DB().ListC2Profiles() + profiles, err := h.mgr().DB().ListC2Profiles() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -774,7 +785,7 @@ func (h *C2Handler) ListProfiles(c *gin.Context) { // GetProfile 获取单个 Profile func (h *C2Handler) GetProfile(c *gin.Context) { id := c.Param("id") - profile, err := h.manager.DB().GetC2Profile(id) + profile, err := h.mgr().DB().GetC2Profile(id) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -797,7 +808,7 @@ func (h *C2Handler) CreateProfile(c *gin.Context) { req.ID = "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] req.CreatedAt = time.Now() - if err := h.manager.DB().CreateC2Profile(&req); err != nil { + if err := h.mgr().DB().CreateC2Profile(&req); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -807,7 +818,7 @@ func (h *C2Handler) CreateProfile(c *gin.Context) { // UpdateProfile 更新 Profile func (h *C2Handler) UpdateProfile(c *gin.Context) { id := c.Param("id") - profile, err := h.manager.DB().GetC2Profile(id) + profile, err := h.mgr().DB().GetC2Profile(id) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -832,7 +843,7 @@ func (h *C2Handler) UpdateProfile(c *gin.Context) { profile.JitterMinMS = req.JitterMinMS profile.JitterMaxMS = req.JitterMaxMS - if err := h.manager.DB().UpdateC2Profile(profile); err != nil { + if err := h.mgr().DB().UpdateC2Profile(profile); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -842,7 +853,7 @@ func (h *C2Handler) UpdateProfile(c *gin.Context) { // DeleteProfile 删除 Profile func (h *C2Handler) DeleteProfile(c *gin.Context) { id := c.Param("id") - if err := h.manager.DB().DeleteC2Profile(id); err != nil { + if err := h.mgr().DB().DeleteC2Profile(id); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -870,7 +881,7 @@ func (h *C2Handler) UploadFileForImplant(c *gin.Context) { defer file.Close() fileID := "f_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] - dir := filepath.Join(h.manager.StorageDir(), "downstream") + dir := filepath.Join(h.mgr().StorageDir(), "downstream") if err := osMkdirAll(dir); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -898,7 +909,7 @@ func (h *C2Handler) UploadFileForImplant(c *gin.Context) { SizeBytes: n, CreatedAt: time.Now(), } - _ = h.manager.DB().CreateC2File(dbFile) + _ = h.mgr().DB().CreateC2File(dbFile) c.JSON(http.StatusOK, gin.H{ "file_id": fileID, @@ -915,7 +926,7 @@ func (h *C2Handler) ListFiles(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "session_id required"}) return } - files, err := h.manager.DB().ListC2FilesBySession(sessionID) + files, err := h.mgr().DB().ListC2FilesBySession(sessionID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -926,7 +937,7 @@ func (h *C2Handler) ListFiles(c *gin.Context) { // DownloadResultFile 下载任务结果文件(截图等 blob 结果) func (h *C2Handler) DownloadResultFile(c *gin.Context) { taskID := c.Param("id") - task, err := h.manager.DB().GetC2Task(taskID) + task, err := h.mgr().DB().GetC2Task(taskID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/internal/handler/config.go b/internal/handler/config.go index 9f48397f..b5497327 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -41,6 +41,14 @@ type SkillsToolRegistrar func() error // BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册) type BatchTaskToolRegistrar func() error +// C2ToolRegistrar C2 MCP 工具注册器(ApplyConfig 时 ClearTools 之后调用) +type C2ToolRegistrar func() error + +// C2Runtime ApplyConfig 时按配置启停 C2 子系统(由 internal/app.App 实现) +type C2Runtime interface { + ReconcileC2AfterConfigApply() error +} + // RetrieverUpdater 检索器更新接口 type RetrieverUpdater interface { UpdateConfig(config *knowledge.RetrievalConfig) @@ -73,6 +81,8 @@ type ConfigHandler struct { webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选) skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选) batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选) + c2ToolRegistrar C2ToolRegistrar // C2 MCP 工具(可选) + c2Runtime C2Runtime // C2 启停(可选) retrieverUpdater RetrieverUpdater // 检索器更新器(可选) knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) appUpdater AppUpdater // App更新器(可选) @@ -154,6 +164,20 @@ func (h *ConfigHandler) SetBatchTaskToolRegistrar(registrar BatchTaskToolRegistr h.batchTaskToolRegistrar = registrar } +// SetC2ToolRegistrar 设置 C2 MCP 工具注册器 +func (h *ConfigHandler) SetC2ToolRegistrar(registrar C2ToolRegistrar) { + h.mu.Lock() + defer h.mu.Unlock() + h.c2ToolRegistrar = registrar +} + +// SetC2Runtime 设置 C2 运行时(Apply 时启停) +func (h *ConfigHandler) SetC2Runtime(rt C2Runtime) { + h.mu.Lock() + defer h.mu.Unlock() + h.c2Runtime = rt +} + // SetRetrieverUpdater 设置检索器更新器 func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) { h.mu.Lock() @@ -193,6 +217,7 @@ type GetConfigResponse struct { Knowledge config.KnowledgeConfig `json:"knowledge"` Robots config.RobotsConfig `json:"robots,omitempty"` MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"` + C2 config.C2Public `json:"c2"` } // ToolConfigInfo 工具配置信息 @@ -286,6 +311,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { Agent: h.config.Agent, Hitl: h.config.Hitl, Knowledge: h.config.Knowledge, + C2: h.config.C2.Public(), Robots: h.config.Robots, MultiAgent: multiPub, }) @@ -591,6 +617,7 @@ type UpdateConfigRequest struct { Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"` Robots *config.RobotsConfig `json:"robots,omitempty"` MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"` + C2 *config.C2APIUpdate `json:"c2,omitempty"` } // ToolEnableStatus 工具启用状态 @@ -676,6 +703,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { ) } + if req.C2 != nil { + v := req.C2.Enabled + h.config.C2.Enabled = &v + h.logger.Info("更新C2配置", zap.Bool("enabled", v)) + } + // 多代理标量(sub_agents 等仍由 config.yaml 维护) if req.MultiAgent != nil { h.config.MultiAgent.Enabled = req.MultiAgent.Enabled @@ -980,6 +1013,18 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { h.logger.Info("知识库组件重新初始化完成") } + // C2:在 ClearTools 之前按配置启停(随后由 c2ToolRegistrar 注册 MCP 工具) + h.mu.RLock() + c2Rt := h.c2Runtime + h.mu.RUnlock() + if c2Rt != nil { + if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil { + h.logger.Error("C2 配置应用失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()}) + return + } + } + // 现在获取写锁,执行快速的操作 h.mu.Lock() defer h.mu.Unlock() @@ -1044,6 +1089,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) { } } + // 重新注册 C2 MCP 工具(仅当 C2 已启动) + if h.c2ToolRegistrar != nil { + h.logger.Info("重新注册 C2 MCP 工具") + if err := h.c2ToolRegistrar(); err != nil { + h.logger.Error("重新注册 C2 MCP 工具失败", zap.Error(err)) + } else { + h.logger.Info("C2 MCP 工具已处理") + } + } + // 如果知识库启用,重新注册知识库工具 if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil { h.logger.Info("重新注册知识库工具") @@ -1131,6 +1186,7 @@ func (h *ConfigHandler) saveConfig() error { updateOpenAIConfig(root, h.config.OpenAI) updateFOFAConfig(root, h.config.FOFA) updateKnowledgeConfig(root, h.config.Knowledge) + updateC2Config(root, h.config.C2) updateRobotsConfig(root, h.config.Robots) updateHitlConfig(root, h.config.Hitl) updateMultiAgentConfig(root, h.config.MultiAgent) @@ -1309,6 +1365,12 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) { setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs) } +func updateC2Config(doc *yaml.Node, cfg config.C2Config) { + root := doc.Content[0] + c2Node := ensureMap(root, "c2") + setBoolInMap(c2Node, "enabled", cfg.EnabledEffective()) +} + func mergeHitlToolWhitelistSlice(existing, add []string) []string { seen := make(map[string]struct{}) out := make([]string, 0, len(existing)+len(add))