Add files via upload

This commit is contained in:
公明
2026-05-04 03:44:08 +08:00
committed by GitHub
parent 4fca4a85c2
commit 8158b38f48
12 changed files with 3661 additions and 2 deletions
+8
View File
@@ -184,6 +184,14 @@ func (h *AgentHandler) SetHitlToolWhitelistSaver(s HitlToolWhitelistSaver) {
h.hitlWhitelistSaver = s
}
// HITLNeedsToolApproval 供 C2 危险任务门控:与会话侧人机协同及免审批白名单判定一致。
func (h *AgentHandler) HITLNeedsToolApproval(conversationID, toolName string) bool {
if h == nil || h.hitlManager == nil {
return false
}
return h.hitlManager.NeedsToolApproval(conversationID, toolName)
}
// ChatAttachment 聊天附件(用户上传的文件)
type ChatAttachment struct {
FileName string `json:"fileName"` // 展示用文件名
+955
View File
@@ -0,0 +1,955 @@
package handler
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// C2Handler 处理 C2 相关的 REST API
type C2Handler struct {
manager *c2.Manager
logger *zap.Logger
}
// NewC2Handler 创建 C2 处理器
func NewC2Handler(manager *c2.Manager, logger *zap.Logger) *C2Handler {
return &C2Handler{
manager: manager,
logger: logger,
}
}
// ============================================================================
// 监听器 API
// ============================================================================
// ListListeners 获取监听器列表
func (h *C2Handler) ListListeners(c *gin.Context) {
listeners, err := h.manager.DB().ListC2Listeners()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 移除敏感字段
for _, l := range listeners {
l.EncryptionKey = ""
l.ImplantToken = ""
}
c.JSON(http.StatusOK, gin.H{"listeners": listeners})
}
// CreateListener 创建监听器
func (h *C2Handler) CreateListener(c *gin.Context) {
var req struct {
Name string `json:"name"`
Type string `json:"type"`
BindHost string `json:"bind_host"`
BindPort int `json:"bind_port"`
ProfileID string `json:"profile_id,omitempty"`
Remark string `json:"remark,omitempty"`
CallbackHost string `json:"callback_host,omitempty"`
Config *c2.ListenerConfig `json:"config,omitempty"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
input := c2.CreateListenerInput{
Name: req.Name,
Type: req.Type,
BindHost: req.BindHost,
BindPort: req.BindPort,
ProfileID: req.ProfileID,
Remark: req.Remark,
Config: req.Config,
CallbackHost: strings.TrimSpace(req.CallbackHost),
}
listener, err := h.manager.CreateListener(input)
if err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
}
c.JSON(code, gin.H{"error": err.Error()})
return
}
implantToken := listener.ImplantToken
listener.EncryptionKey = ""
listener.ImplantToken = ""
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
}
// GetListener 获取单个监听器
func (h *C2Handler) GetListener(c *gin.Context) {
id := c.Param("id")
listener, err := h.manager.DB().GetC2Listener(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if listener == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"})
return
}
listener.EncryptionKey = ""
listener.ImplantToken = ""
c.JSON(http.StatusOK, gin.H{"listener": listener})
}
// UpdateListener 更新监听器
func (h *C2Handler) UpdateListener(c *gin.Context) {
id := c.Param("id")
listener, err := h.manager.DB().GetC2Listener(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if listener == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"})
return
}
var req struct {
Name string `json:"name"`
BindHost string `json:"bind_host"`
BindPort int `json:"bind_port"`
ProfileID string `json:"profile_id"`
Remark string `json:"remark"`
CallbackHost *string `json:"callback_host"`
Config *c2.ListenerConfig `json:"config,omitempty"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 若监听器在运行,不能修改关键字段
if h.manager.IsListenerRunning(id) {
if req.BindHost != listener.BindHost || req.BindPort != listener.BindPort {
c.JSON(http.StatusConflict, gin.H{"error": "cannot modify bind address while listener is running"})
return
}
}
listener.Name = req.Name
listener.BindHost = req.BindHost
listener.BindPort = req.BindPort
listener.ProfileID = req.ProfileID
listener.Remark = req.Remark
if req.Config != nil {
cfgJSON, _ := json.Marshal(req.Config)
listener.ConfigJSON = string(cfgJSON)
}
if req.CallbackHost != nil {
cfg := &c2.ListenerConfig{}
raw := strings.TrimSpace(listener.ConfigJSON)
if raw == "" {
raw = "{}"
}
_ = json.Unmarshal([]byte(raw), cfg)
cfg.CallbackHost = strings.TrimSpace(*req.CallbackHost)
cfg.ApplyDefaults()
cfgJSON, err := json.Marshal(cfg)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
listener.ConfigJSON = string(cfgJSON)
}
if err := h.manager.DB().UpdateC2Listener(listener); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
listener.EncryptionKey = ""
listener.ImplantToken = ""
c.JSON(http.StatusOK, gin.H{"listener": listener})
}
// DeleteListener 删除监听器
func (h *C2Handler) DeleteListener(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DeleteListener(id); err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
}
c.JSON(code, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}
// StartListener 启动监听器
func (h *C2Handler) StartListener(c *gin.Context) {
id := c.Param("id")
listener, err := h.manager.StartListener(id)
if err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
}
c.JSON(code, gin.H{"error": err.Error()})
return
}
listener.EncryptionKey = ""
listener.ImplantToken = ""
c.JSON(http.StatusOK, gin.H{"listener": listener})
}
// StopListener 停止监听器
func (h *C2Handler) StopListener(c *gin.Context) {
id := c.Param("id")
if err := h.manager.StopListener(id); err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
}
c.JSON(code, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"stopped": true})
}
// ============================================================================
// 会话 API
// ============================================================================
// ListSessions 获取会话列表
func (h *C2Handler) ListSessions(c *gin.Context) {
filter := database.ListC2SessionsFilter{
ListenerID: c.Query("listener_id"),
Status: c.Query("status"),
OS: c.Query("os"),
Search: c.Query("search"),
}
if limit := c.Query("limit"); limit != "" {
if n, err := strconv.Atoi(limit); err == nil && n > 0 {
filter.Limit = n
}
}
sessions, err := h.manager.DB().ListC2Sessions(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"sessions": sessions})
}
// GetSession 获取单个会话
func (h *C2Handler) GetSession(c *gin.Context) {
id := c.Param("id")
session, err := h.manager.DB().GetC2Session(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
return
}
// 获取最近任务
tasks, _ := h.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
SessionID: id,
Limit: 20,
})
c.JSON(http.StatusOK, gin.H{
"session": session,
"tasks": tasks,
})
}
// DeleteSession 删除会话
func (h *C2Handler) DeleteSession(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DB().DeleteC2Session(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}
// SetSessionSleep 设置会话的 sleep/jitter
func (h *C2Handler) SetSessionSleep(c *gin.Context) {
id := c.Param("id")
var req struct {
SleepSeconds int `json:"sleep_seconds"`
JitterPercent int `json:"jitter_percent"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.manager.DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"updated": true})
}
// ============================================================================
// 任务 API
// ============================================================================
// ListTasks 获取任务列表
func (h *C2Handler) ListTasks(c *gin.Context) {
filter := database.ListC2TasksFilter{
SessionID: c.Query("session_id"),
Status: c.Query("status"),
}
paginated := false
page := 1
pageSize := 10
if c.Query("page") != "" || c.Query("page_size") != "" {
paginated = true
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
page = p
}
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "10")); err == nil && ps > 0 {
pageSize = ps
if pageSize > 100 {
pageSize = 100
}
}
filter.Limit = pageSize
filter.Offset = (page - 1) * pageSize
} else {
if limit := c.Query("limit"); limit != "" {
if n, err := strconv.Atoi(limit); err == nil && n > 0 {
filter.Limit = n
}
}
}
tasks, err := h.manager.DB().ListC2Tasks(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 仪表盘「待审任务」为全局 queued/pending 数量,与列表 session 过滤无关
pendingN, _ := h.manager.DB().CountC2TasksQueuedOrPending("")
if !paginated {
c.JSON(http.StatusOK, gin.H{
"tasks": tasks,
"pending_queued_count": pendingN,
})
return
}
total, err := h.manager.DB().CountC2Tasks(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"tasks": tasks,
"total": total,
"page": page,
"page_size": pageSize,
"pending_queued_count": pendingN,
})
}
// DeleteTasks 批量删除任务(请求体 JSON: {"ids":["t_xxx",...]}
func (h *C2Handler) DeleteTasks(c *gin.Context) {
var req struct {
IDs []string `json:"ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()})
return
}
if len(req.IDs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
return
}
n, err := h.manager.DB().DeleteC2TasksByIDs(req.IDs)
if err != nil {
if errors.Is(err, database.ErrNoValidC2TaskIDs) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": n})
}
// GetTask 获取单个任务
func (h *C2Handler) GetTask(c *gin.Context) {
id := c.Param("id")
task, err := h.manager.DB().GetC2Task(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if task == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "task not found"})
return
}
c.JSON(http.StatusOK, gin.H{"task": task})
}
// CreateTask 创建任务
func (h *C2Handler) CreateTask(c *gin.Context) {
var req struct {
SessionID string `json:"session_id"`
TaskType string `json:"task_type"`
Payload map[string]interface{} `json:"payload"`
Source string `json:"source"`
ConversationID string `json:"conversation_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
input := c2.EnqueueTaskInput{
SessionID: req.SessionID,
TaskType: c2.TaskType(req.TaskType),
Payload: req.Payload,
Source: firstNonEmpty(req.Source, "manual"),
ConversationID: req.ConversationID,
UserCtx: c.Request.Context(),
}
task, err := h.manager.EnqueueTask(input)
if err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
}
c.JSON(code, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"task": task})
}
// CancelTask 取消任务
func (h *C2Handler) CancelTask(c *gin.Context) {
id := c.Param("id")
if err := h.manager.CancelTask(id); err != nil {
code := http.StatusInternalServerError
if e, ok := err.(*c2.CommonError); ok {
code = e.HTTP
}
c.JSON(code, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"cancelled": true})
}
// WaitTask 等待任务完成
func (h *C2Handler) WaitTask(c *gin.Context) {
id := c.Param("id")
timeout := 60 * time.Second
if t := c.Query("timeout"); t != "" {
if n, err := strconv.Atoi(t); err == nil && n > 0 {
timeout = time.Duration(n) * time.Second
}
}
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
task, err := h.manager.DB().GetC2Task(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if task == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "task not found"})
return
}
if task.Status == "success" || task.Status == "failed" || task.Status == "cancelled" {
c.JSON(http.StatusOK, gin.H{"task": task})
return
}
time.Sleep(500 * time.Millisecond)
}
c.JSON(http.StatusRequestTimeout, gin.H{"error": "timeout waiting for task completion"})
}
// ============================================================================
// Payload API
// ============================================================================
// PayloadOneliner 生成单行 payload
func (h *C2Handler) PayloadOneliner(c *gin.Context) {
var req struct {
ListenerID string `json:"listener_id"`
Kind string `json:"kind"` // bash, python, powershell, curl_beacon
Host string `json:"host"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
listener, err := h.manager.DB().GetC2Listener(req.ListenerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if listener == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"})
return
}
host := c2.ResolveBeaconDialHost(listener, strings.TrimSpace(req.Host), h.logger, listener.ID)
kind := c2.OnelinerKind(req.Kind)
if !c2.IsOnelinerCompatible(listener.Type, kind) {
compatible := c2.OnelinerKindsForListener(listener.Type)
names := make([]string, len(compatible))
for i, k := range compatible {
names[i] = string(k)
}
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("监听器类型 %s 不支持 %s 类型的 oneliner,请选择兼容的类型", listener.Type, req.Kind),
"compatible_kinds": names,
})
return
}
input := c2.OnelinerInput{
Kind: kind,
Host: host,
Port: listener.BindPort,
HTTPBaseURL: fmt.Sprintf("http://%s:%d", host, listener.BindPort),
ImplantToken: listener.ImplantToken,
}
oneliner, err := c2.GenerateOneliner(input)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"oneliner": oneliner,
"kind": req.Kind,
"host": host,
"port": listener.BindPort,
})
}
// PayloadBuild 构建 beacon 二进制
func (h *C2Handler) PayloadBuild(c *gin.Context) {
var req struct {
ListenerID string `json:"listener_id"`
OS string `json:"os"`
Arch string `json:"arch"`
SleepSeconds int `json:"sleep_seconds"`
JitterPercent int `json:"jitter_percent"`
Host string `json:"host"` // 可选:编译进 Beacon 的回连地址,覆盖监听器 bind_host
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
listener, err := h.manager.DB().GetC2Listener(req.ListenerID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if listener == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"})
return
}
builder := c2.NewPayloadBuilder(h.manager, h.logger, "", "")
input := c2.PayloadBuilderInput{
ListenerID: req.ListenerID,
OS: req.OS,
Arch: req.Arch,
SleepSeconds: req.SleepSeconds,
JitterPercent: req.JitterPercent,
Host: strings.TrimSpace(req.Host),
}
result, err := builder.BuildBeacon(input)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"payload": result,
})
}
// PayloadDownload 下载 payload
func (h *C2Handler) PayloadDownload(c *gin.Context) {
id := c.Param("id")
filename := id
if !strings.HasPrefix(filename, "beacon_") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"})
return
}
if strings.Contains(filename, "/") || strings.Contains(filename, "\\") || strings.Contains(filename, "..") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"})
return
}
builder := c2.NewPayloadBuilder(h.manager, h.logger, "", "")
storageDir := builder.GetPayloadStoragePath()
targetPath := filepath.Join(storageDir, filename)
absTarget, err := filepath.Abs(targetPath)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path"})
return
}
absDir, err := filepath.Abs(storageDir)
if err != nil || !strings.HasPrefix(absTarget, absDir+string(filepath.Separator)) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"})
return
}
c.FileAttachment(absTarget, filepath.Base(absTarget))
}
// ============================================================================
// 事件 API
// ============================================================================
// ListEvents 获取事件列表
func (h *C2Handler) ListEvents(c *gin.Context) {
filter := database.ListC2EventsFilter{
Level: c.Query("level"),
Category: c.Query("category"),
SessionID: c.Query("session_id"),
TaskID: c.Query("task_id"),
}
if since := c.Query("since"); since != "" {
if t, err := time.Parse(time.RFC3339, since); err == nil {
filter.Since = &t
}
}
paginated := false
page := 1
pageSize := 10
if c.Query("page") != "" || c.Query("page_size") != "" {
paginated = true
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
page = p
}
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "10")); err == nil && ps > 0 {
pageSize = ps
if pageSize > 100 {
pageSize = 100
}
}
filter.Limit = pageSize
filter.Offset = (page - 1) * pageSize
} else {
if limit := c.Query("limit"); limit != "" {
if n, err := strconv.Atoi(limit); err == nil && n > 0 {
filter.Limit = n
}
}
}
events, err := h.manager.DB().ListC2Events(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !paginated {
c.JSON(http.StatusOK, gin.H{"events": events})
return
}
total, err := h.manager.DB().CountC2Events(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"events": events,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// DeleteEvents 批量删除事件(请求体 JSON: {"ids":["e_xxx",...]}
func (h *C2Handler) DeleteEvents(c *gin.Context) {
var req struct {
IDs []string `json:"ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()})
return
}
if len(req.IDs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
return
}
n, err := h.manager.DB().DeleteC2EventsByIDs(req.IDs)
if err != nil {
if errors.Is(err, database.ErrNoValidC2EventIDs) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": n})
}
// EventStream SSE 实时事件流
func (h *C2Handler) EventStream(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
sessionFilter := c.Query("session_id")
categoryFilter := c.Query("category")
levels := c.QueryArray("level")
sub := h.manager.EventBus().Subscribe(
"sse-"+uuid.New().String(),
128,
sessionFilter,
categoryFilter,
levels,
)
defer h.manager.EventBus().Unsubscribe(sub.ID)
c.Stream(func(w io.Writer) bool {
select {
case e, ok := <-sub.Ch:
if !ok {
return false
}
data, _ := json.Marshal(e)
fmt.Fprintf(w, "data: %s\n\n", data)
return true
case <-c.Request.Context().Done():
return false
}
})
}
// ============================================================================
// Profile API
// ============================================================================
// ListProfiles 获取 Malleable Profile 列表
func (h *C2Handler) ListProfiles(c *gin.Context) {
profiles, err := h.manager.DB().ListC2Profiles()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"profiles": profiles})
}
// GetProfile 获取单个 Profile
func (h *C2Handler) GetProfile(c *gin.Context) {
id := c.Param("id")
profile, err := h.manager.DB().GetC2Profile(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if profile == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "profile not found"})
return
}
c.JSON(http.StatusOK, gin.H{"profile": profile})
}
// CreateProfile 创建 Profile
func (h *C2Handler) CreateProfile(c *gin.Context) {
var req database.C2Profile
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.ID = "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
req.CreatedAt = time.Now()
if err := h.manager.DB().CreateC2Profile(&req); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"profile": req})
}
// UpdateProfile 更新 Profile
func (h *C2Handler) UpdateProfile(c *gin.Context) {
id := c.Param("id")
profile, err := h.manager.DB().GetC2Profile(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if profile == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "profile not found"})
return
}
var req database.C2Profile
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
profile.Name = req.Name
profile.UserAgent = req.UserAgent
profile.URIs = req.URIs
profile.RequestHeaders = req.RequestHeaders
profile.ResponseHeaders = req.ResponseHeaders
profile.BodyTemplate = req.BodyTemplate
profile.JitterMinMS = req.JitterMinMS
profile.JitterMaxMS = req.JitterMaxMS
if err := h.manager.DB().UpdateC2Profile(profile); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"profile": profile})
}
// DeleteProfile 删除 Profile
func (h *C2Handler) DeleteProfile(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DB().DeleteC2Profile(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}
// ============================================================================
// 文件管理 API(C2 Upload 任务需要先通过此 API 上传文件到 downstream 目录)
// ============================================================================
// UploadFileForImplant 操作员上传文件,供 upload 任务推送给 implant
func (h *C2Handler) UploadFileForImplant(c *gin.Context) {
sessionID := strings.TrimSpace(c.PostForm("session_id"))
remotePath := strings.TrimSpace(c.PostForm("remote_path"))
if sessionID == "" || remotePath == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session_id and remote_path required"})
return
}
file, header, err := c.Request.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "file field required: " + err.Error()})
return
}
defer file.Close()
fileID := "f_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
dir := filepath.Join(h.manager.StorageDir(), "downstream")
if err := osMkdirAll(dir); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
dstPath := filepath.Join(dir, fileID+".bin")
dst, err := osCreate(dstPath)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
n, err := io.Copy(dst, file)
dst.Close()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Record in DB
dbFile := &database.C2File{
ID: fileID,
SessionID: sessionID,
Direction: "upload",
RemotePath: remotePath,
LocalPath: dstPath,
SizeBytes: n,
CreatedAt: time.Now(),
}
_ = h.manager.DB().CreateC2File(dbFile)
c.JSON(http.StatusOK, gin.H{
"file_id": fileID,
"size": n,
"filename": header.Filename,
"remote_path": remotePath,
})
}
// ListFiles 列出某会话的文件记录
func (h *C2Handler) ListFiles(c *gin.Context) {
sessionID := c.Query("session_id")
if sessionID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "session_id required"})
return
}
files, err := h.manager.DB().ListC2FilesBySession(sessionID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"files": files})
}
// DownloadResultFile 下载任务结果文件(截图等 blob 结果)
func (h *C2Handler) DownloadResultFile(c *gin.Context) {
taskID := c.Param("id")
task, err := h.manager.DB().GetC2Task(taskID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if task == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "task not found"})
return
}
if task.ResultBlobPath == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "no result file for this task"})
return
}
c.FileAttachment(task.ResultBlobPath, filepath.Base(task.ResultBlobPath))
}
func osMkdirAll(path string) error {
return os.MkdirAll(path, 0o755)
}
func osCreate(path string) (*os.File, error) {
return os.Create(path)
}
// ============================================================================
// 辅助函数(firstNonEmpty 已在 vulnerability.go 中定义)
// ============================================================================
+9
View File
@@ -233,6 +233,15 @@ func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRunt
return cfg, !inWhitelist
}
// NeedsToolApproval 与 Agent 工具层 shouldInterrupt 语义一致:仅当该会话已开启人机协同且工具不在免审批白名单时为 true。
func (m *HITLManager) NeedsToolApproval(conversationID, toolName string) bool {
if m == nil {
return false
}
_, need := m.shouldInterrupt(conversationID, toolName)
return need
}
func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) {
now := time.Now()
id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "")
+58 -1
View File
@@ -38,6 +38,7 @@ type NotificationSummaryItem struct {
VulnerabilityID string `json:"vulnerabilityId,omitempty"`
ExecutionID string `json:"executionId,omitempty"`
InterruptID string `json:"interruptId,omitempty"`
SessionID string `json:"sessionId,omitempty"` // C2 会话(如新会话上线)
}
// NotificationSummaryResponse 聚合响应
@@ -239,6 +240,52 @@ func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, e
return items, counts, nil
}
// loadC2SessionOnlineEvents 新会话上线(c2_eventssession + critical,与 Manager.IngestCheckIn 一致)
func (h *NotificationHandler) loadC2SessionOnlineEvents(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT id, message, COALESCE(session_id, ''),
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM c2_events
WHERE category = 'session' AND level = 'critical'
AND CAST(strftime('%s', created_at) AS INTEGER) > ?
ORDER BY created_at DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, 0, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
for rows.Next() {
var id, message, sessionID string
var createdSec int64
if err := rows.Scan(&id, &message, &sessionID, &createdSec); err != nil {
continue
}
desc := strings.TrimSpace(message)
if len(desc) > 220 {
desc = desc[:200] + "…"
}
if desc == "" {
desc = i18nText(english, "新会话已建立", "A new session was created")
}
items = append(items, NotificationSummaryItem{
ID: "c2evt:" + id,
Level: "p0",
Type: "c2_session_online",
Title: i18nText(english, "C2 新会话上线", "C2 new session online"),
Desc: desc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: false,
Read: false,
SessionID: sessionID,
})
}
return items, len(items), rows.Err()
}
func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
@@ -492,6 +539,7 @@ func normalizeMarkableEventID(id string) (string, bool) {
"vuln:",
"exec_failed:",
"task_completed:",
"c2evt:",
}
for _, prefix := range allowedPrefixes {
if strings.HasPrefix(v, prefix) {
@@ -593,12 +641,20 @@ func (h *NotificationHandler) GetSummary(c *gin.Context) {
return
}
c2OnlineItems, c2OnlineCount, err := h.loadC2SessionOnlineEvents(sinceMs, limit, english)
if err != nil {
h.logger.Warn("加载 C2 会话上线通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize c2 session events"})
return
}
longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english)
completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english)
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(longRunningItems)+len(completedItems))
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(c2OnlineItems)+len(longRunningItems)+len(completedItems))
items = append(items, hitlItems...)
items = append(items, vulnItems...)
items = append(items, c2OnlineItems...)
items = append(items, longRunningItems...)
items = append(items, completedItems...)
@@ -636,6 +692,7 @@ func (h *NotificationHandler) GetSummary(c *gin.Context) {
"failedExecutions": 0,
"longRunningTasks": longRunningCount,
"completedTasks": completedCount,
"c2SessionOnline": c2OnlineCount,
},
Items: items,
})