Add files via upload

This commit is contained in:
公明
2026-05-04 13:09:43 +08:00
committed by GitHub
parent b27e443d37
commit 57ebc7c04b
4 changed files with 288 additions and 143 deletions
+59 -48
View File
@@ -10,6 +10,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"time"
"cyberstrike-ai/internal/c2"
@@ -20,18 +21,28 @@ import (
"go.uber.org/zap"
)
// C2Handler 处理 C2 相关的 REST API
// C2Handler 处理 C2 相关的 REST APImanager 可在运行时置 nil 以关闭 C2)
type C2Handler struct {
manager *c2.Manager
logger *zap.Logger
mgrPtr atomic.Pointer[c2.Manager]
logger *zap.Logger
}
// NewC2Handler 创建 C2 处理器
// NewC2Handler 创建 C2 处理器manager 可为 nil(功能关闭时)
func NewC2Handler(manager *c2.Manager, logger *zap.Logger) *C2Handler {
return &C2Handler{
manager: manager,
logger: logger,
h := &C2Handler{logger: logger}
if manager != nil {
h.mgrPtr.Store(manager)
}
return h
}
func (h *C2Handler) mgr() *c2.Manager {
return h.mgrPtr.Load()
}
// SetManager 运行时切换或清空 C2 Manager(与 App 启停同步)
func (h *C2Handler) SetManager(m *c2.Manager) {
h.mgrPtr.Store(m)
}
// ============================================================================
@@ -40,7 +51,7 @@ func NewC2Handler(manager *c2.Manager, logger *zap.Logger) *C2Handler {
// ListListeners 获取监听器列表
func (h *C2Handler) ListListeners(c *gin.Context) {
listeners, err := h.manager.DB().ListC2Listeners()
listeners, err := h.mgr().DB().ListC2Listeners()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -81,7 +92,7 @@ func (h *C2Handler) CreateListener(c *gin.Context) {
CallbackHost: strings.TrimSpace(req.CallbackHost),
}
listener, err := h.manager.CreateListener(input)
listener, err := h.mgr().CreateListener(input)
if err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
@@ -99,7 +110,7 @@ func (h *C2Handler) CreateListener(c *gin.Context) {
// GetListener 获取单个监听器
func (h *C2Handler) GetListener(c *gin.Context) {
id := c.Param("id")
listener, err := h.manager.DB().GetC2Listener(id)
listener, err := h.mgr().DB().GetC2Listener(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -116,7 +127,7 @@ func (h *C2Handler) GetListener(c *gin.Context) {
// UpdateListener 更新监听器
func (h *C2Handler) UpdateListener(c *gin.Context) {
id := c.Param("id")
listener, err := h.manager.DB().GetC2Listener(id)
listener, err := h.mgr().DB().GetC2Listener(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -141,7 +152,7 @@ func (h *C2Handler) UpdateListener(c *gin.Context) {
}
// 若监听器在运行,不能修改关键字段
if h.manager.IsListenerRunning(id) {
if h.mgr().IsListenerRunning(id) {
if req.BindHost != listener.BindHost || req.BindPort != listener.BindPort {
c.JSON(http.StatusConflict, gin.H{"error": "cannot modify bind address while listener is running"})
return
@@ -174,7 +185,7 @@ func (h *C2Handler) UpdateListener(c *gin.Context) {
listener.ConfigJSON = string(cfgJSON)
}
if err := h.manager.DB().UpdateC2Listener(listener); err != nil {
if err := h.mgr().DB().UpdateC2Listener(listener); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -186,7 +197,7 @@ func (h *C2Handler) UpdateListener(c *gin.Context) {
// DeleteListener 删除监听器
func (h *C2Handler) DeleteListener(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DeleteListener(id); err != nil {
if err := h.mgr().DeleteListener(id); err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
@@ -200,7 +211,7 @@ func (h *C2Handler) DeleteListener(c *gin.Context) {
// StartListener 启动监听器
func (h *C2Handler) StartListener(c *gin.Context) {
id := c.Param("id")
listener, err := h.manager.StartListener(id)
listener, err := h.mgr().StartListener(id)
if err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
@@ -217,7 +228,7 @@ func (h *C2Handler) StartListener(c *gin.Context) {
// StopListener 停止监听器
func (h *C2Handler) StopListener(c *gin.Context) {
id := c.Param("id")
if err := h.manager.StopListener(id); err != nil {
if err := h.mgr().StopListener(id); err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
@@ -246,7 +257,7 @@ func (h *C2Handler) ListSessions(c *gin.Context) {
}
}
sessions, err := h.manager.DB().ListC2Sessions(filter)
sessions, err := h.mgr().DB().ListC2Sessions(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -257,7 +268,7 @@ func (h *C2Handler) ListSessions(c *gin.Context) {
// GetSession 获取单个会话
func (h *C2Handler) GetSession(c *gin.Context) {
id := c.Param("id")
session, err := h.manager.DB().GetC2Session(id)
session, err := h.mgr().DB().GetC2Session(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -268,7 +279,7 @@ func (h *C2Handler) GetSession(c *gin.Context) {
}
// 获取最近任务
tasks, _ := h.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
tasks, _ := h.mgr().DB().ListC2Tasks(database.ListC2TasksFilter{
SessionID: id,
Limit: 20,
})
@@ -282,7 +293,7 @@ func (h *C2Handler) GetSession(c *gin.Context) {
// DeleteSession 删除会话
func (h *C2Handler) DeleteSession(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DB().DeleteC2Session(id); err != nil {
if err := h.mgr().DB().DeleteC2Session(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -301,7 +312,7 @@ func (h *C2Handler) SetSessionSleep(c *gin.Context) {
return
}
if err := h.manager.DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil {
if err := h.mgr().DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -343,14 +354,14 @@ func (h *C2Handler) ListTasks(c *gin.Context) {
}
}
tasks, err := h.manager.DB().ListC2Tasks(filter)
tasks, err := h.mgr().DB().ListC2Tasks(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 仪表盘「待审任务」为全局 queued/pending 数量,与列表 session 过滤无关
pendingN, _ := h.manager.DB().CountC2TasksQueuedOrPending("")
pendingN, _ := h.mgr().DB().CountC2TasksQueuedOrPending("")
if !paginated {
c.JSON(http.StatusOK, gin.H{
@@ -360,7 +371,7 @@ func (h *C2Handler) ListTasks(c *gin.Context) {
return
}
total, err := h.manager.DB().CountC2Tasks(filter)
total, err := h.mgr().DB().CountC2Tasks(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -387,7 +398,7 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
return
}
n, err := h.manager.DB().DeleteC2TasksByIDs(req.IDs)
n, err := h.mgr().DB().DeleteC2TasksByIDs(req.IDs)
if err != nil {
if errors.Is(err, database.ErrNoValidC2TaskIDs) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -402,7 +413,7 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) {
// GetTask 获取单个任务
func (h *C2Handler) GetTask(c *gin.Context) {
id := c.Param("id")
task, err := h.manager.DB().GetC2Task(id)
task, err := h.mgr().DB().GetC2Task(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -437,7 +448,7 @@ func (h *C2Handler) CreateTask(c *gin.Context) {
UserCtx: c.Request.Context(),
}
task, err := h.manager.EnqueueTask(input)
task, err := h.mgr().EnqueueTask(input)
if err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
@@ -452,7 +463,7 @@ func (h *C2Handler) CreateTask(c *gin.Context) {
// CancelTask 取消任务
func (h *C2Handler) CancelTask(c *gin.Context) {
id := c.Param("id")
if err := h.manager.CancelTask(id); err != nil {
if err := h.mgr().CancelTask(id); err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
@@ -475,7 +486,7 @@ func (h *C2Handler) WaitTask(c *gin.Context) {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
task, err := h.manager.DB().GetC2Task(id)
task, err := h.mgr().DB().GetC2Task(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -509,7 +520,7 @@ func (h *C2Handler) PayloadOneliner(c *gin.Context) {
return
}
listener, err := h.manager.DB().GetC2Listener(req.ListenerID)
listener, err := h.mgr().DB().GetC2Listener(req.ListenerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -572,7 +583,7 @@ func (h *C2Handler) PayloadBuild(c *gin.Context) {
return
}
listener, err := h.manager.DB().GetC2Listener(req.ListenerID)
listener, err := h.mgr().DB().GetC2Listener(req.ListenerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -582,7 +593,7 @@ func (h *C2Handler) PayloadBuild(c *gin.Context) {
return
}
builder := c2.NewPayloadBuilder(h.manager, h.logger, "", "")
builder := c2.NewPayloadBuilder(h.mgr(), h.logger, "", "")
input := c2.PayloadBuilderInput{
ListenerID: req.ListenerID,
OS: req.OS,
@@ -616,7 +627,7 @@ func (h *C2Handler) PayloadDownload(c *gin.Context) {
return
}
builder := c2.NewPayloadBuilder(h.manager, h.logger, "", "")
builder := c2.NewPayloadBuilder(h.mgr(), h.logger, "", "")
storageDir := builder.GetPayloadStoragePath()
targetPath := filepath.Join(storageDir, filename)
@@ -676,7 +687,7 @@ func (h *C2Handler) ListEvents(c *gin.Context) {
}
}
events, err := h.manager.DB().ListC2Events(filter)
events, err := h.mgr().DB().ListC2Events(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -685,7 +696,7 @@ func (h *C2Handler) ListEvents(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"events": events})
return
}
total, err := h.manager.DB().CountC2Events(filter)
total, err := h.mgr().DB().CountC2Events(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -711,7 +722,7 @@ func (h *C2Handler) DeleteEvents(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
return
}
n, err := h.manager.DB().DeleteC2EventsByIDs(req.IDs)
n, err := h.mgr().DB().DeleteC2EventsByIDs(req.IDs)
if err != nil {
if errors.Is(err, database.ErrNoValidC2EventIDs) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -733,14 +744,14 @@ func (h *C2Handler) EventStream(c *gin.Context) {
categoryFilter := c.Query("category")
levels := c.QueryArray("level")
sub := h.manager.EventBus().Subscribe(
sub := h.mgr().EventBus().Subscribe(
"sse-"+uuid.New().String(),
128,
sessionFilter,
categoryFilter,
levels,
)
defer h.manager.EventBus().Unsubscribe(sub.ID)
defer h.mgr().EventBus().Unsubscribe(sub.ID)
c.Stream(func(w io.Writer) bool {
select {
@@ -763,7 +774,7 @@ func (h *C2Handler) EventStream(c *gin.Context) {
// ListProfiles 获取 Malleable Profile 列表
func (h *C2Handler) ListProfiles(c *gin.Context) {
profiles, err := h.manager.DB().ListC2Profiles()
profiles, err := h.mgr().DB().ListC2Profiles()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -774,7 +785,7 @@ func (h *C2Handler) ListProfiles(c *gin.Context) {
// GetProfile 获取单个 Profile
func (h *C2Handler) GetProfile(c *gin.Context) {
id := c.Param("id")
profile, err := h.manager.DB().GetC2Profile(id)
profile, err := h.mgr().DB().GetC2Profile(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -797,7 +808,7 @@ func (h *C2Handler) CreateProfile(c *gin.Context) {
req.ID = "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
req.CreatedAt = time.Now()
if err := h.manager.DB().CreateC2Profile(&req); err != nil {
if err := h.mgr().DB().CreateC2Profile(&req); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -807,7 +818,7 @@ func (h *C2Handler) CreateProfile(c *gin.Context) {
// UpdateProfile 更新 Profile
func (h *C2Handler) UpdateProfile(c *gin.Context) {
id := c.Param("id")
profile, err := h.manager.DB().GetC2Profile(id)
profile, err := h.mgr().DB().GetC2Profile(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -832,7 +843,7 @@ func (h *C2Handler) UpdateProfile(c *gin.Context) {
profile.JitterMinMS = req.JitterMinMS
profile.JitterMaxMS = req.JitterMaxMS
if err := h.manager.DB().UpdateC2Profile(profile); err != nil {
if err := h.mgr().DB().UpdateC2Profile(profile); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -842,7 +853,7 @@ func (h *C2Handler) UpdateProfile(c *gin.Context) {
// DeleteProfile 删除 Profile
func (h *C2Handler) DeleteProfile(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DB().DeleteC2Profile(id); err != nil {
if err := h.mgr().DB().DeleteC2Profile(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -870,7 +881,7 @@ func (h *C2Handler) UploadFileForImplant(c *gin.Context) {
defer file.Close()
fileID := "f_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
dir := filepath.Join(h.manager.StorageDir(), "downstream")
dir := filepath.Join(h.mgr().StorageDir(), "downstream")
if err := osMkdirAll(dir); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -898,7 +909,7 @@ func (h *C2Handler) UploadFileForImplant(c *gin.Context) {
SizeBytes: n,
CreatedAt: time.Now(),
}
_ = h.manager.DB().CreateC2File(dbFile)
_ = h.mgr().DB().CreateC2File(dbFile)
c.JSON(http.StatusOK, gin.H{
"file_id": fileID,
@@ -915,7 +926,7 @@ func (h *C2Handler) ListFiles(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "session_id required"})
return
}
files, err := h.manager.DB().ListC2FilesBySession(sessionID)
files, err := h.mgr().DB().ListC2FilesBySession(sessionID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -926,7 +937,7 @@ func (h *C2Handler) ListFiles(c *gin.Context) {
// DownloadResultFile 下载任务结果文件(截图等 blob 结果)
func (h *C2Handler) DownloadResultFile(c *gin.Context) {
taskID := c.Param("id")
task, err := h.manager.DB().GetC2Task(taskID)
task, err := h.mgr().DB().GetC2Task(taskID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
+62
View File
@@ -41,6 +41,14 @@ type SkillsToolRegistrar func() error
// BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册)
type BatchTaskToolRegistrar func() error
// C2ToolRegistrar C2 MCP 工具注册器(ApplyConfig 时 ClearTools 之后调用)
type C2ToolRegistrar func() error
// C2Runtime ApplyConfig 时按配置启停 C2 子系统(由 internal/app.App 实现)
type C2Runtime interface {
ReconcileC2AfterConfigApply() error
}
// RetrieverUpdater 检索器更新接口
type RetrieverUpdater interface {
UpdateConfig(config *knowledge.RetrievalConfig)
@@ -73,6 +81,8 @@ type ConfigHandler struct {
webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选)
skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选)
batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选)
c2ToolRegistrar C2ToolRegistrar // C2 MCP 工具(可选)
c2Runtime C2Runtime // C2 启停(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选)
@@ -154,6 +164,20 @@ func (h *ConfigHandler) SetBatchTaskToolRegistrar(registrar BatchTaskToolRegistr
h.batchTaskToolRegistrar = registrar
}
// SetC2ToolRegistrar 设置 C2 MCP 工具注册器
func (h *ConfigHandler) SetC2ToolRegistrar(registrar C2ToolRegistrar) {
h.mu.Lock()
defer h.mu.Unlock()
h.c2ToolRegistrar = registrar
}
// SetC2Runtime 设置 C2 运行时(Apply 时启停)
func (h *ConfigHandler) SetC2Runtime(rt C2Runtime) {
h.mu.Lock()
defer h.mu.Unlock()
h.c2Runtime = rt
}
// SetRetrieverUpdater 设置检索器更新器
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
h.mu.Lock()
@@ -193,6 +217,7 @@ type GetConfigResponse struct {
Knowledge config.KnowledgeConfig `json:"knowledge"`
Robots config.RobotsConfig `json:"robots,omitempty"`
MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"`
C2 config.C2Public `json:"c2"`
}
// ToolConfigInfo 工具配置信息
@@ -286,6 +311,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
Agent: h.config.Agent,
Hitl: h.config.Hitl,
Knowledge: h.config.Knowledge,
C2: h.config.C2.Public(),
Robots: h.config.Robots,
MultiAgent: multiPub,
})
@@ -591,6 +617,7 @@ type UpdateConfigRequest struct {
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
Robots *config.RobotsConfig `json:"robots,omitempty"`
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
C2 *config.C2APIUpdate `json:"c2,omitempty"`
}
// ToolEnableStatus 工具启用状态
@@ -676,6 +703,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
)
}
if req.C2 != nil {
v := req.C2.Enabled
h.config.C2.Enabled = &v
h.logger.Info("更新C2配置", zap.Bool("enabled", v))
}
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
if req.MultiAgent != nil {
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
@@ -980,6 +1013,18 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("知识库组件重新初始化完成")
}
// C2:在 ClearTools 之前按配置启停(随后由 c2ToolRegistrar 注册 MCP 工具)
h.mu.RLock()
c2Rt := h.c2Runtime
h.mu.RUnlock()
if c2Rt != nil {
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
h.logger.Error("C2 配置应用失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
return
}
}
// 现在获取写锁,执行快速的操作
h.mu.Lock()
defer h.mu.Unlock()
@@ -1044,6 +1089,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
}
}
// 重新注册 C2 MCP 工具(仅当 C2 已启动)
if h.c2ToolRegistrar != nil {
h.logger.Info("重新注册 C2 MCP 工具")
if err := h.c2ToolRegistrar(); err != nil {
h.logger.Error("重新注册 C2 MCP 工具失败", zap.Error(err))
} else {
h.logger.Info("C2 MCP 工具已处理")
}
}
// 如果知识库启用,重新注册知识库工具
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
h.logger.Info("重新注册知识库工具")
@@ -1131,6 +1186,7 @@ func (h *ConfigHandler) saveConfig() error {
updateOpenAIConfig(root, h.config.OpenAI)
updateFOFAConfig(root, h.config.FOFA)
updateKnowledgeConfig(root, h.config.Knowledge)
updateC2Config(root, h.config.C2)
updateRobotsConfig(root, h.config.Robots)
updateHitlConfig(root, h.config.Hitl)
updateMultiAgentConfig(root, h.config.MultiAgent)
@@ -1309,6 +1365,12 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs)
}
func updateC2Config(doc *yaml.Node, cfg config.C2Config) {
root := doc.Content[0]
c2Node := ensureMap(root, "c2")
setBoolInMap(c2Node, "enabled", cfg.EnabledEffective())
}
func mergeHitlToolWhitelistSlice(existing, add []string) []string {
seen := make(map[string]struct{})
out := make([]string, 0, len(existing)+len(add))