Add files via upload

This commit is contained in:
公明
2026-06-18 12:44:42 +08:00
committed by GitHub
parent 01b361e4a7
commit 7eadccbff6
85 changed files with 33500 additions and 0 deletions
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,99 @@
package handler
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"testing"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/openai"
"go.uber.org/zap"
)
// TestCreateProgressCallback_ConcurrentToolEvents 回归 issue #142:并行 tool 回调不得 concurrent map panic。
func TestCreateProgressCallback_ConcurrentToolEvents(t *testing.T) {
logger := zap.NewNop()
h := &AgentHandler{
logger: logger,
config: &config.Config{},
}
cb := h.createProgressCallback(context.Background(), nil, "conv-race-test", "", nil)
const workers = 64
var wg sync.WaitGroup
wg.Add(workers * 2)
for i := 0; i < workers; i++ {
i := i
go func() {
defer wg.Done()
toolCallID := fmt.Sprintf("tc-%d", i)
cb("tool_call", "calling skill", map[string]interface{}{
"toolCallId": toolCallID,
"toolName": "skill",
"argumentsObj": map[string]interface{}{"skill_name": "demo-skill"},
})
}()
go func() {
defer wg.Done()
toolCallID := fmt.Sprintf("tc-%d", i)
cb("tool_result", "skill done", map[string]interface{}{
"toolCallId": toolCallID,
"toolName": "skill",
"success": true,
})
}()
}
wg.Wait()
}
// TestCreateProgressCallback_FlushesReasoningOnDone 流式推理聚合须在 done/response 时落库,刷新后可回放。
func TestCreateProgressCallback_FlushesReasoningOnDone(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer os.RemoveAll(tmp)
conv, err := db.CreateConversation("test", database.ConversationCreateMeta{})
if err != nil {
t.Fatalf("CreateConversation: %v", err)
}
asst, err := db.AddMessage(conv.ID, "assistant", "处理中...", nil)
if err != nil {
t.Fatalf("AddMessage: %v", err)
}
h := &AgentHandler{logger: zap.NewNop(), db: db}
cb := h.createProgressCallback(context.Background(), nil, conv.ID, asst.ID, nil)
streamID := "eino-reasoning-test-1"
cb("reasoning_chain_stream_start", " ", map[string]interface{}{
"streamId": streamID,
"source": "eino",
})
cb("reasoning_chain_stream_delta", "step one", openai.WithSSEAccumulated(map[string]interface{}{
"streamId": streamID,
}, "step one"))
cb("done", "", map[string]interface{}{"conversationId": conv.ID})
details, err := db.GetProcessDetails(asst.ID)
if err != nil {
t.Fatalf("GetProcessDetails: %v", err)
}
found := false
for _, d := range details {
if d.EventType == "reasoning_chain" && d.Message == "step one" {
found = true
break
}
}
if !found {
t.Fatalf("expected reasoning_chain persisted on done, got %+v", details)
}
}
+172
View File
@@ -0,0 +1,172 @@
package handler
import (
"context"
"net/http"
"sync"
"time"
"cyberstrike-ai/internal/attackchain"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AttackChainHandler 攻击链处理器
type AttackChainHandler struct {
db *database.DB
logger *zap.Logger
openAIConfig *config.OpenAIConfig
mu sync.RWMutex // 保护 openAIConfig 的并发访问
// 用于防止同一对话的并发生成
generatingLocks sync.Map // map[string]*sync.Mutex
}
// NewAttackChainHandler 创建新的攻击链处理器
func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *AttackChainHandler {
return &AttackChainHandler{
db: db,
logger: logger,
openAIConfig: openAIConfig,
}
}
// UpdateConfig 更新OpenAI配置
func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) {
h.mu.Lock()
defer h.mu.Unlock()
h.openAIConfig = cfg
h.logger.Info("AttackChainHandler配置已更新",
zap.String("base_url", cfg.BaseURL),
zap.String("model", cfg.Model),
)
}
// getOpenAIConfig 获取OpenAI配置(线程安全)
func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig {
h.mu.RLock()
defer h.mu.RUnlock()
return h.openAIConfig
}
// GetAttackChain 获取攻击链(按需生成)
// GET /api/attack-chain/:conversationId
func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
conversationID := c.Param("conversationId")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
return
}
// 检查对话是否存在
_, err := h.db.GetConversation(conversationID)
if err != nil {
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
// 先尝试从数据库加载(如果已生成过)
openAIConfig := h.getOpenAIConfig()
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
chain, err := builder.LoadChainFromDatabase(conversationID)
if err == nil && len(chain.Nodes) > 0 {
// 如果已存在,直接返回
h.logger.Info("返回已存在的攻击链", zap.String("conversationId", conversationID))
c.JSON(http.StatusOK, chain)
return
}
// 如果不存在,则生成新的攻击链(按需生成)
// 使用锁机制防止同一对话的并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
// 尝试获取锁,如果正在生成则返回错误
acquired := lock.TryLock()
if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
return
}
defer lock.Unlock()
// 再次检查是否已生成(可能在等待锁的过程中已经生成完成)
chain, err = builder.LoadChainFromDatabase(conversationID)
if err == nil && len(chain.Nodes) > 0 {
h.logger.Info("返回已存在的攻击链(在锁等待期间已生成)", zap.String("conversationId", conversationID))
c.JSON(http.StatusOK, chain)
return
}
h.logger.Info("开始生成攻击链", zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
chain, err = builder.BuildChainFromConversation(ctx, conversationID)
if err != nil {
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
return
}
// 生成完成后,从锁映射中删除(可选,保留也可以用于防止短时间内重复生成)
// h.generatingLocks.Delete(conversationID)
c.JSON(http.StatusOK, chain)
}
// RegenerateAttackChain 重新生成攻击链
// POST /api/attack-chain/:conversationId/regenerate
func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
conversationID := c.Param("conversationId")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
return
}
// 检查对话是否存在
_, err := h.db.GetConversation(conversationID)
if err != nil {
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
// 删除旧的攻击链
if err := h.db.DeleteAttackChain(conversationID); err != nil {
h.logger.Warn("删除旧攻击链失败", zap.Error(err))
}
// 使用锁机制防止并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
acquired := lock.TryLock()
if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
return
}
defer lock.Unlock()
// 生成新的攻击链
h.logger.Info("重新生成攻击链", zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
openAIConfig := h.getOpenAIConfig()
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
chain, err := builder.BuildChainFromConversation(ctx, conversationID)
if err != nil {
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
return
}
c.JSON(http.StatusOK, chain)
}
+147
View File
@@ -0,0 +1,147 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AuditHandler serves platform audit log APIs.
type AuditHandler struct {
db *database.DB
audit *audit.Service
logger *zap.Logger
}
// NewAuditHandler creates an audit log handler.
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
}
// Meta GET /api/audit/meta
func (h *AuditHandler) Meta(c *gin.Context) {
enabled := false
retentionDays := 0
if h.audit != nil {
enabled = h.audit.Enabled()
retentionDays = h.audit.RetentionDays()
}
c.JSON(http.StatusOK, gin.H{
"enabled": enabled,
"retention_days": retentionDays,
"default_page_size": 20,
"max_page_size": 100,
"max_export": 5000,
})
}
// Summary GET /api/audit/summary
func (h *AuditHandler) Summary(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
base := auditFilterFromQuery(c)
total, err := h.db.CountAuditLogs(base)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
failFilter := base
failFilter.Result = "failure"
failures, err := h.db.CountAuditLogs(failFilter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
since := time.Now().AddDate(0, 0, -7)
recentFilter := base
recentFilter.Since = &since
recent7d, err := h.db.CountAuditLogs(recentFilter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"total": total,
"failures": failures,
"recent_7d": recent7d,
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
})
}
// ListLogs GET /api/audit/logs
func (h *AuditHandler) ListLogs(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
filter := auditFilterFromQuery(c)
page, pageSize := auditPaginationFromQuery(c)
filter.Limit = pageSize
filter.Offset = (page - 1) * pageSize
logs, err := h.db.ListAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
total, err := h.db.CountAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"logs": logs,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// GetLog GET /api/audit/logs/:id
func (h *AuditHandler) GetLog(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
row, err := h.db.GetAuditLogByID(c.Param("id"))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
return
}
audit.ApplyResourceAvailability(h.db, row)
c.JSON(http.StatusOK, gin.H{"log": row})
}
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
func (h *AuditHandler) ExportLogs(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
filter := auditFilterFromQuery(c)
filter.Limit = 5000
filter.Offset = 0
logs, err := h.db.ListAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if c.Query("format") == "csv" {
writeAuditLogsCSV(c, logs)
return
}
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
c.JSON(http.StatusOK, gin.H{
"exported_at": time.Now().UTC().Format(time.RFC3339),
"logs": logs,
})
}
+42
View File
@@ -0,0 +1,42 @@
package handler
import (
"encoding/csv"
"fmt"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
)
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
c.Header("Content-Type", "text/csv; charset=utf-8")
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
w := csv.NewWriter(c.Writer)
_ = w.Write([]string{
"id", "created_at", "level", "category", "action", "result", "actor",
"session_hint", "client_ip", "resource_type", "resource_id", "message",
})
for _, row := range logs {
if row == nil {
continue
}
_ = w.Write([]string{
row.ID,
row.CreatedAt.UTC().Format(time.RFC3339),
row.Level,
row.Category,
row.Action,
row.Result,
row.Actor,
row.SessionHint,
row.ClientIP,
row.ResourceType,
row.ResourceID,
row.Message,
})
}
w.Flush()
}
+47
View File
@@ -0,0 +1,47 @@
package handler
import (
"strconv"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
)
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
filter := database.ListAuditLogsFilter{
Level: c.Query("level"),
Category: c.Query("category"),
Action: c.Query("action"),
Result: c.Query("result"),
Query: c.Query("q"),
ResourceType: c.Query("resource_type"),
ResourceID: c.Query("resource_id"),
}
if since := c.Query("since"); since != "" {
if t, err := database.ParseRFC3339Time(since); err == nil {
filter.Since = &t
}
}
if until := c.Query("until"); until != "" {
if t, err := database.ParseRFC3339Time(until); err == nil {
filter.Until = &t
}
}
return filter
}
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
page = p
}
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
pageSize = ps
if pageSize > 100 {
pageSize = 100
}
}
return page, pageSize
}
+211
View File
@@ -0,0 +1,211 @@
package handler
import (
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AuthHandler handles authentication-related endpoints.
type AuthHandler struct {
manager *security.AuthManager
config *config.Config
configPath string
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *AuthHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewAuthHandler creates a new AuthHandler.
func NewAuthHandler(manager *security.AuthManager, cfg *config.Config, configPath string, logger *zap.Logger) *AuthHandler {
return &AuthHandler{
manager: manager,
config: cfg,
configPath: configPath,
logger: logger,
}
}
type loginRequest struct {
Password string `json:"password" binding:"required"`
}
type changePasswordRequest struct {
OldPassword string `json:"oldPassword"`
NewPassword string `json:"newPassword"`
}
// Login verifies password and returns a session token.
func (h *AuthHandler) Login(c *gin.Context) {
var req loginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"})
return
}
token, expiresAt, err := h.manager.Authenticate(req.Password)
if err != nil {
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Level: "warn",
Category: "auth",
Action: "login",
Result: "failure",
Message: "登录失败:密码错误",
})
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
return
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "login",
Result: "success",
SessionHint: audit.HintFromToken(token),
Message: "登录成功",
Detail: map[string]interface{}{
"expires_at": expiresAt.UTC().Format(time.RFC3339),
},
})
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"expires_at": expiresAt.UTC().Format(time.RFC3339),
"session_duration_hr": h.manager.SessionDurationHours(),
})
}
// Logout revokes the current session token.
func (h *AuthHandler) Logout(c *gin.Context) {
token := c.GetString(security.ContextAuthTokenKey)
if token == "" {
authHeader := c.GetHeader("Authorization")
if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") {
token = strings.TrimSpace(authHeader[7:])
} else {
token = strings.TrimSpace(authHeader)
}
}
h.manager.RevokeToken(token)
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "logout",
Result: "success",
Message: "退出登录",
})
}
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
}
// ChangePassword updates the login password.
func (h *AuthHandler) ChangePassword(c *gin.Context) {
var req changePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "参数无效"})
return
}
oldPassword := strings.TrimSpace(req.OldPassword)
newPassword := strings.TrimSpace(req.NewPassword)
if oldPassword == "" || newPassword == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码和新密码均不能为空"})
return
}
if len(newPassword) < 8 {
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码长度至少需要 8 位"})
return
}
if oldPassword == newPassword {
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码不能与旧密码相同"})
return
}
if !h.manager.CheckPassword(oldPassword) {
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Level: "warn",
Category: "auth",
Action: "change_password",
Result: "failure",
Message: "修改密码失败:当前密码不正确",
})
}
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
return
}
if err := config.PersistAuthPassword(h.configPath, newPassword); err != nil {
if h.logger != nil {
h.logger.Error("保存新密码失败", zap.Error(err))
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存新密码失败,请重试"})
return
}
if err := h.manager.UpdateConfig(newPassword, h.config.Auth.SessionDurationHours); err != nil {
if h.logger != nil {
h.logger.Error("更新认证配置失败", zap.Error(err))
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新认证配置失败"})
return
}
h.config.Auth.Password = newPassword
h.config.Auth.GeneratedPassword = ""
h.config.Auth.GeneratedPasswordPersisted = false
h.config.Auth.GeneratedPasswordPersistErr = ""
if h.logger != nil {
h.logger.Info("登录密码已更新,所有会话已失效")
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "change_password",
Result: "success",
Message: "登录密码已修改",
})
}
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
}
// Validate returns the current session status.
func (h *AuthHandler) Validate(c *gin.Context) {
token := c.GetString(security.ContextAuthTokenKey)
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话无效"})
return
}
session, ok := h.manager.ValidateToken(token)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话已过期"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": session.Token,
"expires_at": session.ExpiresAt.UTC().Format(time.RFC3339),
})
}
File diff suppressed because it is too large Load Diff
+831
View File
@@ -0,0 +1,831 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterBatchTaskMCPTools 注册批量任务队列相关 MCP 工具(需传入已初始化 DB 的 AgentHandler
func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) {
if mcpServer == nil || h == nil || logger == nil {
return
}
reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) {
mcpServer.RegisterTool(tool, fn)
}
// --- list ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskList,
Description: "列出批量任务队列(精简摘要,省上下文)。含队列元数据、子任务 id/status/截断后的 message、各状态计数。完整子任务(含 result/error/conversationId/时间等)请用 batch_task_get(queue_id)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确提及查看/管理批量任务、任务队列时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "列出批量任务队列",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"status": map[string]interface{}{
"type": "string",
"description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled",
"enum": []string{"all", "pending", "running", "paused", "completed", "cancelled"},
},
"keyword": map[string]interface{}{
"type": "string",
"description": "按队列 ID 或标题模糊搜索",
},
"page": map[string]interface{}{
"type": "integer",
"description": "页码,从 1 开始,默认 1",
},
"page_size": map[string]interface{}{
"type": "integer",
"description": "每页条数,默认 20,最大 100",
},
},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
status := mcpArgString(args, "status")
if status == "" {
status = "all"
}
keyword := mcpArgString(args, "keyword")
page := int(mcpArgFloat(args, "page"))
if page <= 0 {
page = 1
}
pageSize := int(mcpArgFloat(args, "page_size"))
if pageSize <= 0 {
pageSize = 20
}
if pageSize > 100 {
pageSize = 100
}
offset := (page - 1) * pageSize
if offset > 100000 {
offset = 100000
}
queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword)
if err != nil {
return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil
}
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
slim := make([]batchTaskQueueMCPListItem, 0, len(queues))
for _, q := range queues {
if q == nil {
continue
}
slim = append(slim, toBatchTaskQueueMCPListItem(q))
}
payload := map[string]interface{}{
"queues": slim,
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": totalPages,
}
logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total))
return batchMCPJSONResult(payload)
})
// --- get ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskGet,
Description: "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确提及查看/管理批量任务、任务队列时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "获取批量任务队列详情",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
queue, ok := h.batchTaskManager.GetBatchQueue(qid)
if !ok {
return batchMCPTextResult("队列不存在: "+qid, true), nil
}
return batchMCPJSONResult(queue)
})
// --- create ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskCreate,
Description: `⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求创建批量任务、任务队列时才可调用。禁止在用户未提及”批量任务””任务队列””定时任务”等关键词时自行调用。如果用户只是让你做某件事,请在当前对话中直接完成,不要自作主张创建任务队列。
【用途】应用内「任务管理 / 批量任务队列」:把多条彼此独立的用户指令登记成一条队列,便于在界面里查看进度、暂停/继续、定时重跑等。这是队列数据与调度入口,不是再开一个”子代理会话”替你探索当前问题。
【何时用】用户明确要批量排队执行、Cron 周期跑同一批指令、或需要与任务管理页面对齐时调用。需要即时追问、强依赖当前对话上下文的分析/编码,应在本对话内直接完成,不要为了”委派”而创建队列。
【参数】tasks(字符串数组)或 tasks_text(多行,每行一条)二选一;每项是一条将来由系统按队列顺序执行的指令文案。agent_modeeino_singleEino ADK 单代理,默认)、deep / plan_execute / supervisor(需系统启用多代理)。非”把主对话拆给子代理”。schedule_modemanual(默认)或 croncron 须填 cron_expr5 段,如 “0 */6 * * *”)。
【执行】默认创建后为 pending,不自动跑。execute_now=true 可创建后立即跑;否则之后调用 batch_task_start。Cron 自动下一轮需 schedule_enabled 为 true(可用 batch_task_schedule_enabled)。`,
ShortDescription: "任务管理:创建批量任务队列(登记多条指令,可选立即或 Cron)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"title": map[string]interface{}{
"type": "string",
"description": "可选队列标题,便于在任务管理中识别",
},
"role": map[string]interface{}{
"type": "string",
"description": "队列使用的角色名,空表示默认",
},
"tasks": map[string]interface{}{
"type": "array",
"description": "队列中的子任务指令,每项一条独立待执行文案(与 tasks_text 二选一)",
"items": map[string]interface{}{"type": "string"},
},
"tasks_text": map[string]interface{}{
"type": "string",
"description": "多行文本,每行一条子任务指令(与 tasks 二选一)",
},
"agent_mode": map[string]interface{}{
"type": "string",
"description": "执行模式:eino_singleEino ADK,默认)、deep/plan_execute/supervisorEino 编排,需启用多代理)",
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
},
"schedule_mode": map[string]interface{}{
"type": "string",
"description": "manual(仅手工/启动后跑)或 cron(按表达式触发)",
"enum": []string{"manual", "cron"},
},
"cron_expr": map[string]interface{}{
"type": "string",
"description": "schedule_mode 为 cron 时必填。标准 5 段:分钟 小时 日 月 星期,例如 \"0 */6 * * *\"、\"30 2 * * 1-5\"",
},
"execute_now": map[string]interface{}{
"type": "boolean",
"description": "创建后是否立即开始执行队列,默认 falsepending,需 batch_task_start",
},
"project_id": map[string]interface{}{
"type": "string",
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id",
},
},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
tasks, errMsg := batchMCPTasksFromArgs(args)
if errMsg != "" {
return batchMCPTextResult(errMsg, true), nil
}
title := mcpArgString(args, "title")
role := mcpArgString(args, "role")
agentMode := config.NormalizeAgentMode(mcpArgString(args, "agent_mode"))
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
var nextRunAt *time.Time
if scheduleMode == "cron" {
if cronExpr == "" {
return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
}
sch, err := h.batchCronParser.Parse(cronExpr)
if err != nil {
return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
}
n := sch.Next(time.Now())
nextRunAt = &n
}
executeNow, ok := mcpArgBool(args, "execute_now")
if !ok {
executeNow = false
}
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
if createErr != nil {
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
}
started := false
if executeNow {
ok, err := h.startBatchQueueExecution(queue.ID, false)
if !ok {
return batchMCPTextResult("队列不存在: "+queue.ID, true), nil
}
if err != nil {
return batchMCPTextResult("创建成功但启动失败: "+err.Error(), true), nil
}
started = true
if refreshed, exists := h.batchTaskManager.GetBatchQueue(queue.ID); exists {
queue = refreshed
}
}
logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks)))
return batchMCPJSONResult(map[string]interface{}{
"queue_id": queue.ID,
"queue": queue,
"started": started,
"execute_now": executeNow,
"reminder": func() string {
if started {
return "队列已创建并立即启动。"
}
return "队列已创建,当前为 pending。需要开始执行时请调用 MCP 工具 batch_task_startqueue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。"
}(),
})
})
// --- start ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskStart,
Description: `启动或继续执行批量任务队列(pending / paused)。
与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。
⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求启动/继续批量任务时才可调用。不要在用户未要求时自行调用。`,
ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
ok, err := h.startBatchQueueExecution(qid, false)
if !ok {
return batchMCPTextResult("队列不存在: "+qid, true), nil
}
if err != nil {
return batchMCPTextResult("启动失败: "+err.Error(), true), nil
}
logger.Info("MCP batch_task_start", zap.String("queueId", qid))
return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil
})
// --- rerun (reset + start for completed/cancelled queues) ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskRerun,
Description: "重跑已完成或已取消的批量任务队列。会重置所有子任务状态后重新执行一轮。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求重跑批量任务时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "重跑批量任务队列",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
queue, exists := h.batchTaskManager.GetBatchQueue(qid)
if !exists {
return batchMCPTextResult("队列不存在: "+qid, true), nil
}
if queue.Status != "completed" && queue.Status != "cancelled" {
return batchMCPTextResult("仅已完成或已取消的队列可以重跑,当前状态: "+queue.Status, true), nil
}
if !h.batchTaskManager.ResetQueueForRerun(qid) {
return batchMCPTextResult("重置队列失败", true), nil
}
ok, err := h.startBatchQueueExecution(qid, false)
if !ok {
return batchMCPTextResult("启动失败", true), nil
}
if err != nil {
return batchMCPTextResult("启动失败: "+err.Error(), true), nil
}
logger.Info("MCP batch_task_rerun", zap.String("queueId", qid))
return batchMCPTextResult("已重置并重新启动队列。", false), nil
})
// --- pause ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskPause,
Description: "暂停正在运行的批量任务队列(当前子任务会被取消)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求暂停批量任务时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "暂停批量任务队列",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
if !h.batchTaskManager.PauseQueue(qid) {
return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil
}
logger.Info("MCP batch_task_pause", zap.String("queueId", qid))
return batchMCPTextResult("队列已暂停。", false), nil
})
// --- delete queue ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskDelete,
Description: "删除批量任务队列及其子任务记录。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求删除批量任务队列时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "删除批量任务队列",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
if !h.batchTaskManager.DeleteQueue(qid) {
return batchMCPTextResult("删除失败:队列不存在", true), nil
}
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
return batchMCPTextResult("队列已删除。", false), nil
})
// --- update metadata (title/role/agentMode) ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskUpdateMetadata,
Description: "修改批量任务队列的标题、角色和代理模式。仅在队列非 running 状态下可修改。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量任务队列属性时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "修改批量任务队列标题/角色/代理模式",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"title": map[string]interface{}{
"type": "string",
"description": "新标题(空字符串清除标题)",
},
"role": map[string]interface{}{
"type": "string",
"description": "新角色名(空字符串使用默认角色)",
},
"agent_mode": map[string]interface{}{
"type": "string",
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
title := mcpArgString(args, "title")
role := mcpArgString(args, "role")
agentMode := mcpArgString(args, "agent_mode")
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil {
return batchMCPTextResult(err.Error(), true), nil
}
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_update_metadata", zap.String("queueId", qid))
return batchMCPJSONResult(updated)
})
// --- update schedule ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskUpdateSchedule,
Description: `修改批量任务队列的调度方式和 Cron 表达式。仅在队列非 running 状态下可修改。
schedule_mode 为 cron 时必须提供有效 cron_expr;为 manual 时会清除 Cron 配置。
⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量任务调度配置时才可调用。不要在用户未要求时自行调用。`,
ShortDescription: "修改批量任务调度配置(Cron 表达式)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"schedule_mode": map[string]interface{}{
"type": "string",
"description": "manual 或 cron",
"enum": []string{"manual", "cron"},
},
"cron_expr": map[string]interface{}{
"type": "string",
"description": "Cron 表达式(schedule_mode 为 cron 时必填)。标准 5 段格式:分钟 小时 日 月 星期,如 \"0 */6 * * *\"(每6小时)、\"30 2 * * 1-5\"(工作日凌晨2:30",
},
},
"required": []string{"queue_id", "schedule_mode"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
queue, exists := h.batchTaskManager.GetBatchQueue(qid)
if !exists {
return batchMCPTextResult("队列不存在: "+qid, true), nil
}
if queue.Status == "running" {
return batchMCPTextResult("队列正在运行中,无法修改调度配置", true), nil
}
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
var nextRunAt *time.Time
if scheduleMode == "cron" {
if cronExpr == "" {
return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
}
sch, err := h.batchCronParser.Parse(cronExpr)
if err != nil {
return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
}
n := sch.Next(time.Now())
nextRunAt = &n
}
h.batchTaskManager.UpdateQueueSchedule(qid, scheduleMode, cronExpr, nextRunAt)
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_update_schedule", zap.String("queueId", qid), zap.String("scheduleMode", scheduleMode), zap.String("cronExpr", cronExpr))
return batchMCPJSONResult(updated)
})
// --- schedule enabled ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskScheduleEnabled,
Description: `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。
仅对 schedule_mode 为 cron 的队列有意义。
⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求开关批量任务自动调度时才可调用。不要在用户未要求时自行调用。`,
ShortDescription: "开关批量任务 Cron 自动调度",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"schedule_enabled": map[string]interface{}{
"type": "boolean",
"description": "true 允许定时触发,false 仅手工执行",
},
},
"required": []string{"queue_id", "schedule_enabled"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
en, ok := mcpArgBool(args, "schedule_enabled")
if !ok {
return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil
}
if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists {
return batchMCPTextResult("队列不存在", true), nil
}
if !h.batchTaskManager.SetScheduleEnabled(qid, en) {
return batchMCPTextResult("更新失败", true), nil
}
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en))
return batchMCPJSONResult(queue)
})
// --- add task ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskAdd,
Description: "向处于 pending 状态的队列追加一条子任务。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求向批量任务队列添加子任务时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "批量队列添加子任务",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"message": map[string]interface{}{
"type": "string",
"description": "任务指令内容",
},
},
"required": []string{"queue_id", "message"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
msg := strings.TrimSpace(mcpArgString(args, "message"))
if qid == "" || msg == "" {
return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil
}
task, err := h.batchTaskManager.AddTaskToQueue(qid, msg)
if err != nil {
return batchMCPTextResult(err.Error(), true), nil
}
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID))
return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue})
})
// --- update task ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskUpdate,
Description: "修改 pending 队列中仍为 pending 的子任务文案。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量子任务内容时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "更新批量子任务内容",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"task_id": map[string]interface{}{
"type": "string",
"description": "子任务 ID",
},
"message": map[string]interface{}{
"type": "string",
"description": "新的任务指令",
},
},
"required": []string{"queue_id", "task_id", "message"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
tid := mcpArgString(args, "task_id")
msg := strings.TrimSpace(mcpArgString(args, "message"))
if qid == "" || tid == "" || msg == "" {
return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil
}
if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil {
return batchMCPTextResult(err.Error(), true), nil
}
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid))
return batchMCPJSONResult(queue)
})
// --- remove task ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskRemove,
Description: "从 pending 队列中删除仍为 pending 的子任务。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求删除批量子任务时才可调用。不要在用户未要求时自行调用。",
ShortDescription: "删除批量子任务",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"task_id": map[string]interface{}{
"type": "string",
"description": "子任务 ID",
},
},
"required": []string{"queue_id", "task_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
tid := mcpArgString(args, "task_id")
if qid == "" || tid == "" {
return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil
}
if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil {
return batchMCPTextResult(err.Error(), true), nil
}
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid))
return batchMCPJSONResult(queue)
})
logger.Info("批量任务 MCP 工具已注册", zap.Int("count", 12))
}
// --- batch_task_list 精简结构(避免把每条子任务的 result 等大段文本塞进列表上下文) ---
const mcpBatchListTaskMessageMaxRunes = 160
// batchTaskMCPListSummary 列表中的子任务摘要(完整字段用 batch_task_get
type batchTaskMCPListSummary struct {
ID string `json:"id"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}
// batchTaskQueueMCPListItem 列表中的队列摘要
type batchTaskQueueMCPListItem struct {
ID string `json:"id"`
Title string `json:"title,omitempty"`
Role string `json:"role,omitempty"`
AgentMode string `json:"agentMode"`
ScheduleMode string `json:"scheduleMode"`
CronExpr string `json:"cronExpr,omitempty"`
NextRunAt *time.Time `json:"nextRunAt,omitempty"`
ScheduleEnabled bool `json:"scheduleEnabled"`
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
Status string `json:"status"`
CreatedAt time.Time `json:"createdAt"`
StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"`
CurrentIndex int `json:"currentIndex"`
TaskTotal int `json:"task_total"`
TaskCounts map[string]int `json:"task_counts"`
Tasks []batchTaskMCPListSummary `json:"tasks"`
}
func truncateStringRunes(s string, maxRunes int) string {
if maxRunes <= 0 {
return ""
}
n := 0
for i := range s {
if n == maxRunes {
out := strings.TrimSpace(s[:i])
if out == "" {
return "…"
}
return out + "…"
}
n++
}
return s
}
const mcpBatchListMaxTasksPerQueue = 200 // 列表中每个队列最多返回的子任务摘要数
func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
counts := map[string]int{
"pending": 0,
"running": 0,
"completed": 0,
"failed": 0,
"cancelled": 0,
}
tasks := make([]batchTaskMCPListSummary, 0, len(q.Tasks))
for _, t := range q.Tasks {
if t == nil {
continue
}
counts[t.Status]++
// 列表视图限制子任务摘要数量,完整列表通过 batch_task_get 查看
if len(tasks) < mcpBatchListMaxTasksPerQueue {
tasks = append(tasks, batchTaskMCPListSummary{
ID: t.ID,
Status: t.Status,
Message: truncateStringRunes(t.Message, mcpBatchListTaskMessageMaxRunes),
})
}
}
return batchTaskQueueMCPListItem{
ID: q.ID,
Title: q.Title,
Role: q.Role,
AgentMode: q.AgentMode,
ScheduleMode: q.ScheduleMode,
CronExpr: q.CronExpr,
NextRunAt: q.NextRunAt,
ScheduleEnabled: q.ScheduleEnabled,
LastScheduleTriggerAt: q.LastScheduleTriggerAt,
Status: q.Status,
CreatedAt: q.CreatedAt,
StartedAt: q.StartedAt,
CompletedAt: q.CompletedAt,
CurrentIndex: q.CurrentIndex,
TaskTotal: len(tasks),
TaskCounts: counts,
Tasks: tasks,
}
}
func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult {
return &mcp.ToolResult{
Content: []mcp.Content{{Type: "text", Text: text}},
IsError: isErr,
}
}
func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) {
b, err := json.MarshalIndent(v, "", " ")
if err != nil {
return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil
}
return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil
}
func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) {
if raw, ok := args["tasks"]; ok && raw != nil {
switch t := raw.(type) {
case []interface{}:
out := make([]string, 0, len(t))
for _, x := range t {
if s, ok := x.(string); ok {
if tr := strings.TrimSpace(s); tr != "" {
out = append(out, tr)
}
}
}
if len(out) > 0 {
return out, ""
}
}
}
if txt := mcpArgString(args, "tasks_text"); txt != "" {
lines := strings.Split(txt, "\n")
out := make([]string, 0, len(lines))
for _, line := range lines {
if tr := strings.TrimSpace(line); tr != "" {
out = append(out, tr)
}
}
if len(out) > 0 {
return out, ""
}
}
return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)"
}
func mcpArgString(args map[string]interface{}, key string) string {
v, ok := args[key]
if !ok || v == nil {
return ""
}
switch t := v.(type) {
case string:
return strings.TrimSpace(t)
case float64:
return strings.TrimSpace(strconv.FormatFloat(t, 'f', -1, 64))
case json.Number:
return strings.TrimSpace(t.String())
default:
return strings.TrimSpace(fmt.Sprint(t))
}
}
func mcpArgFloat(args map[string]interface{}, key string) float64 {
v, ok := args[key]
if !ok || v == nil {
return 0
}
switch t := v.(type) {
case float64:
return t
case int:
return float64(t)
case int64:
return float64(t)
case json.Number:
f, _ := t.Float64()
return f
case string:
f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64)
return f
default:
return 0
}
}
func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) {
v, exists := args[key]
if !exists {
return false, false
}
switch t := v.(type) {
case bool:
return t, true
case string:
s := strings.ToLower(strings.TrimSpace(t))
if s == "true" || s == "1" || s == "yes" {
return true, true
}
if s == "false" || s == "0" || s == "no" {
return false, true
}
case float64:
return t != 0, true
}
return false, false
}
File diff suppressed because it is too large Load Diff
+528
View File
@@ -0,0 +1,528 @@
package handler
import (
"crypto/rand"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"time"
"unicode/utf8"
"cyberstrike-ai/internal/audit"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const (
chatUploadsRootDirName = "chat_uploads"
maxChatUploadEditBytes = 2 * 1024 * 1024 // 文本编辑上限
)
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
type ChatUploadsHandler struct {
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewChatUploadsHandler 创建处理器
func NewChatUploadsHandler(logger *zap.Logger) *ChatUploadsHandler {
return &ChatUploadsHandler{logger: logger}
}
func (h *ChatUploadsHandler) absRoot() (string, error) {
cwd, err := os.Getwd()
if err != nil {
return "", err
}
return filepath.Abs(filepath.Join(cwd, chatUploadsRootDirName))
}
// resolveUnderChatUploads 校验 relativePath(使用 / 分隔)对应文件必须在 chat_uploads 根下
func (h *ChatUploadsHandler) resolveUnderChatUploads(relativePath string) (abs string, err error) {
root, err := h.absRoot()
if err != nil {
return "", err
}
rel := strings.TrimSpace(relativePath)
if rel == "" {
return "", fmt.Errorf("empty path")
}
rel = filepath.Clean(filepath.FromSlash(rel))
if rel == "." || strings.HasPrefix(rel, "..") {
return "", fmt.Errorf("invalid path")
}
full := filepath.Join(root, rel)
full, err = filepath.Abs(full)
if err != nil {
return "", err
}
rootAbs, _ := filepath.Abs(root)
if full != rootAbs && !strings.HasPrefix(full, rootAbs+string(filepath.Separator)) {
return "", fmt.Errorf("path escapes chat_uploads root")
}
return full, nil
}
// ChatUploadFileItem 列表项
type ChatUploadFileItem struct {
RelativePath string `json:"relativePath"`
AbsolutePath string `json:"absolutePath"` // 服务器上的绝对路径,便于在对话中引用(与附件落盘路径一致)
Name string `json:"name"`
Size int64 `json:"size"`
ModifiedUnix int64 `json:"modifiedUnix"`
Date string `json:"date"`
ConversationID string `json:"conversationId"`
// SubPath 为日期、会话目录之下的子路径(不含文件名),如 date/conv/a/b/file 则为 "a/b";无嵌套则为 ""。
SubPath string `json:"subPath"`
}
// List GET /api/chat-uploads
func (h *ChatUploadsHandler) List(c *gin.Context) {
conversationFilter := strings.TrimSpace(c.Query("conversation"))
root, err := h.absRoot()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 保证根目录存在,否则「按文件夹」浏览时无法 mkdir,且首次列表为空时界面无路径工具栏
if err := os.MkdirAll(root, 0755); err != nil {
h.logger.Warn("创建 chat_uploads 根目录失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var files []ChatUploadFileItem
var folders []string
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
if rel == "." {
return nil
}
relSlash := filepath.ToSlash(rel)
if d.IsDir() {
folders = append(folders, relSlash)
return nil
}
info, err := d.Info()
if err != nil {
return err
}
parts := strings.Split(relSlash, "/")
var dateStr, convID string
if len(parts) >= 2 {
dateStr = parts[0]
}
if len(parts) >= 3 {
convID = parts[1]
}
var subPath string
if len(parts) >= 4 {
subPath = strings.Join(parts[2:len(parts)-1], "/")
}
if conversationFilter != "" && convID != conversationFilter {
return nil
}
absPath, _ := filepath.Abs(path)
files = append(files, ChatUploadFileItem{
RelativePath: relSlash,
AbsolutePath: absPath,
Name: d.Name(),
Size: info.Size(),
ModifiedUnix: info.ModTime().Unix(),
Date: dateStr,
ConversationID: convID,
SubPath: subPath,
})
return nil
})
if err != nil {
h.logger.Warn("列举对话附件失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if conversationFilter != "" {
filteredFolders := make([]string, 0, len(folders))
for _, rel := range folders {
parts := strings.Split(rel, "/")
if len(parts) >= 2 && parts[1] == conversationFilter {
filteredFolders = append(filteredFolders, rel)
continue
}
if len(parts) == 1 {
prefix := rel + "/"
for _, f := range files {
if strings.HasPrefix(f.RelativePath, prefix) {
filteredFolders = append(filteredFolders, rel)
break
}
}
}
}
folders = filteredFolders
}
sort.Strings(folders)
sort.Slice(files, func(i, j int) bool {
return files[i].ModifiedUnix > files[j].ModifiedUnix
})
c.JSON(http.StatusOK, gin.H{"files": files, "folders": folders})
}
// Download GET /api/chat-uploads/download?path=...
func (h *ChatUploadsHandler) Download(c *gin.Context) {
p := c.Query("path")
abs, err := h.resolveUnderChatUploads(p)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(abs)
if err != nil || st.IsDir() {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
c.FileAttachment(abs, filepath.Base(abs))
}
type chatUploadPathBody struct {
Path string `json:"path"`
}
// Delete DELETE /api/chat-uploads
func (h *ChatUploadsHandler) Delete(c *gin.Context) {
var body chatUploadPathBody
if err := c.ShouldBindJSON(&body); err != nil || strings.TrimSpace(body.Path) == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
abs, err := h.resolveUnderChatUploads(body.Path)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(abs)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if st.IsDir() {
if err := os.RemoveAll(abs); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
if err := os.Remove(abs); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if h.audit != nil {
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
type chatUploadMkdirBody struct {
Parent string `json:"parent"`
Name string `json:"name"`
}
// Mkdir POST /api/chat-uploads/mkdir — 在 parent 目录下新建子目录(parent 为 chat_uploads 下相对路径,空表示根目录;name 为单段目录名)
func (h *ChatUploadsHandler) Mkdir(c *gin.Context) {
var body chatUploadMkdirBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
name := strings.TrimSpace(body.Name)
if name == "" || strings.ContainsAny(name, `/\`) || name == "." || name == ".." {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid name"})
return
}
if utf8.RuneCountInString(name) > 200 {
c.JSON(http.StatusBadRequest, gin.H{"error": "name too long"})
return
}
parent := strings.TrimSpace(body.Parent)
parent = filepath.ToSlash(filepath.Clean(filepath.FromSlash(parent)))
parent = strings.Trim(parent, "/")
if parent == "." {
parent = ""
}
root, err := h.absRoot()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if parent != "" {
absParent, err := h.resolveUnderChatUploads(parent)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(absParent)
if err != nil || !st.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "parent not found"})
return
}
}
var rel string
if parent == "" {
rel = name
} else {
rel = parent + "/" + name
}
absNew, err := h.resolveUnderChatUploads(rel)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(absNew); err == nil {
c.JSON(http.StatusConflict, gin.H{"error": "already exists"})
return
}
if err := os.Mkdir(absNew, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
relOut, _ := filepath.Rel(root, absNew)
c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(relOut)})
}
type chatUploadRenameBody struct {
Path string `json:"path"`
NewName string `json:"newName"`
}
// Rename PUT /api/chat-uploads/rename
func (h *ChatUploadsHandler) Rename(c *gin.Context) {
var body chatUploadRenameBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
newName := strings.TrimSpace(body.NewName)
if newName == "" || strings.ContainsAny(newName, `/\`) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid newName"})
return
}
abs, err := h.resolveUnderChatUploads(body.Path)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
dir := filepath.Dir(abs)
newAbs := filepath.Join(dir, filepath.Base(newName))
root, _ := h.absRoot()
newAbs, _ = filepath.Abs(newAbs)
if newAbs != root && !strings.HasPrefix(newAbs, root+string(filepath.Separator)) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid target path"})
return
}
if err := os.Rename(abs, newAbs); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
newRel, _ := filepath.Rel(root, newAbs)
c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(newRel)})
}
type chatUploadContentBody struct {
Path string `json:"path"`
Content string `json:"content"`
}
// GetContent GET /api/chat-uploads/content?path=...
func (h *ChatUploadsHandler) GetContent(c *gin.Context) {
p := c.Query("path")
abs, err := h.resolveUnderChatUploads(p)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(abs)
if err != nil || st.IsDir() {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
if st.Size() > maxChatUploadEditBytes {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "file too large for editor"})
return
}
b, err := os.ReadFile(abs)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !utf8.Valid(b) {
c.JSON(http.StatusBadRequest, gin.H{"error": "binary file not editable in UI"})
return
}
c.JSON(http.StatusOK, gin.H{"content": string(b)})
}
// PutContent PUT /api/chat-uploads/content
func (h *ChatUploadsHandler) PutContent(c *gin.Context) {
var body chatUploadContentBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
if !utf8.ValidString(body.Content) {
c.JSON(http.StatusBadRequest, gin.H{"error": "content must be valid UTF-8"})
return
}
if len(body.Content) > maxChatUploadEditBytes {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "content too large"})
return
}
abs, err := h.resolveUnderChatUploads(body.Path)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := os.WriteFile(abs, []byte(body.Content), 0644); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
func chatUploadShortRand(n int) string {
const letters = "0123456789abcdef"
b := make([]byte, n)
_, _ = rand.Read(b)
for i := range b {
b[i] = letters[int(b[i])%len(letters)]
}
return string(b)
}
// Upload POST /api/chat-uploads multipart: fileconversationId 可选;relativeDir 可选(chat_uploads 下目录的相对路径,将文件直接上传至该目录)
func (h *ChatUploadsHandler) Upload(c *gin.Context) {
fh, err := c.FormFile("file")
if err != nil || fh == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing file"})
return
}
root, err := h.absRoot()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var targetDir string
targetRel := strings.TrimSpace(c.PostForm("relativeDir"))
if targetRel != "" {
absDir, err := h.resolveUnderChatUploads(targetRel)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(absDir)
if err != nil {
if os.IsNotExist(err) {
if err := os.MkdirAll(absDir, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else if !st.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "relativeDir is not a directory"})
return
}
targetDir = absDir
} else {
convID := strings.TrimSpace(c.PostForm("conversationId"))
convDir := convID
if convDir == "" {
convDir = "_manual"
} else {
convDir = strings.ReplaceAll(convDir, string(filepath.Separator), "_")
}
dateStr := time.Now().Format("2006-01-02")
targetDir = filepath.Join(root, dateStr, convDir)
if err := os.MkdirAll(targetDir, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
baseName := filepath.Base(fh.Filename)
if baseName == "" || baseName == "." {
baseName = "file"
}
baseName = strings.ReplaceAll(baseName, string(filepath.Separator), "_")
ext := filepath.Ext(baseName)
nameNoExt := strings.TrimSuffix(baseName, ext)
suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), chatUploadShortRand(6))
var unique string
if ext != "" {
unique = nameNoExt + suffix + ext
} else {
unique = baseName + suffix
}
fullPath := filepath.Join(targetDir, unique)
src, err := fh.Open()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer src.Close()
dst, err := os.Create(fullPath)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer dst.Close()
if _, err := io.Copy(dst, src); err != nil {
_ = os.Remove(fullPath)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rel, _ := filepath.Rel(root, fullPath)
absSaved, _ := filepath.Abs(fullPath)
if h.audit != nil {
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
"name": unique,
})
}
c.JSON(http.StatusOK, gin.H{
"ok": true,
"relativePath": filepath.ToSlash(rel),
"absolutePath": absSaved,
"name": unique,
})
}
File diff suppressed because it is too large Load Diff
+312
View File
@@ -0,0 +1,312 @@
package handler
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ConversationHandler 对话处理器
type ConversationHandler struct {
db *database.DB
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *ConversationHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewConversationHandler 创建新的对话处理器
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
return &ConversationHandler{
db: db,
logger: logger,
}
}
// CreateConversationRequest 创建对话请求
type CreateConversationRequest struct {
Title string `json:"title"`
ProjectID string `json:"projectId,omitempty"`
}
// SetConversationProjectRequest 设置对话所属项目
type SetConversationProjectRequest struct {
ProjectID string `json:"projectId"` // 空字符串表示解除绑定
}
// CreateConversation 创建新对话
func (h *ConversationHandler) CreateConversation(c *gin.Context) {
var req CreateConversationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
title := req.Title
if title == "" {
title = "新对话"
}
meta := audit.ConversationCreateMetaFromGin(c, "api")
meta.ProjectID = strings.TrimSpace(req.ProjectID)
conv, err := h.db.CreateConversation(title, meta)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conv)
}
// SetConversationProject 设置或清除对话绑定的项目
func (h *ConversationHandler) SetConversationProject(c *gin.Context) {
id := c.Param("id")
var req SetConversationProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := h.db.GetConversation(id); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)})
}
// ListConversations 列出对话
func (h *ConversationHandler) ListConversations(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50")
offsetStr := c.DefaultQuery("offset", "0")
search := c.Query("search") // 获取搜索参数
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
if limit <= 0 {
limit = 50
}
if limit > 1000 {
limit = 1000
}
excludeGrouped := strings.TrimSpace(search) == "" &&
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
var conversations []*database.Conversation
var total int
var err error
if excludeGrouped {
conversations, err = h.db.ListUngroupedConversations(limit, offset)
if err == nil {
total, err = h.db.CountUngroupedConversations()
}
} else {
conversations, err = h.db.ListConversations(limit, offset, search)
if err == nil {
total, err = h.db.CountConversations(search)
}
}
if err != nil {
h.logger.Error("获取对话列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if conversations == nil {
conversations = []*database.Conversation{}
}
c.JSON(http.StatusOK, gin.H{
"conversations": conversations,
"total": total,
"limit": limit,
"offset": offset,
})
}
// GetConversation 获取对话
func (h *ConversationHandler) GetConversation(c *gin.Context) {
id := c.Param("id")
// 默认轻量加载,只有用户需要展开详情时再按需拉取
// include_process_details=1/true 时返回全量 processDetails(兼容旧行为)
includeStr := c.DefaultQuery("include_process_details", "0")
include := includeStr == "1" || includeStr == "true" || includeStr == "yes"
var (
conv *database.Conversation
err error
)
if include {
conv, err = h.db.GetConversation(id)
} else {
conv, err = h.db.GetConversationLite(id)
}
if err != nil {
h.logger.Error("获取对话失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
c.JSON(http.StatusOK, conv)
}
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
messageID := c.Param("id")
if messageID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "message id required"})
return
}
details, err := h.db.GetProcessDetails(messageID)
if err != nil {
h.logger.Error("获取过程详情失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
details = database.DedupeConsecutiveProcessDetails(details)
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
out := make([]map[string]interface{}, 0, len(details))
for _, d := range details {
var data interface{}
if d.Data != "" {
if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
h.logger.Warn("解析过程详情数据失败", zap.Error(err))
}
}
out = append(out, map[string]interface{}{
"id": d.ID,
"messageId": d.MessageID,
"conversationId": d.ConversationID,
"eventType": d.EventType,
"message": d.Message,
"data": data,
"createdAt": d.CreatedAt,
})
}
c.JSON(http.StatusOK, gin.H{"processDetails": out})
}
// UpdateConversationRequest 更新对话请求
type UpdateConversationRequest struct {
Title string `json:"title"`
}
// UpdateConversation 更新对话
func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
id := c.Param("id")
var req UpdateConversationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Title == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "标题不能为空"})
return
}
if err := h.db.UpdateConversationTitle(id, req.Title); err != nil {
h.logger.Error("更新对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 返回更新后的对话
conv, err := h.db.GetConversation(id)
if err != nil {
h.logger.Error("获取更新后的对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conv)
}
// DeleteConversation 删除对话
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteConversation(id); err != nil {
h.logger.Error("删除对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "conversation",
Action: "delete",
Result: "success",
ResourceType: "conversation",
ResourceID: id,
Message: "删除对话",
})
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// DeleteTurnRequest 删除一轮对话(POST /api/conversations/:id/delete-turn
type DeleteTurnRequest struct {
MessageID string `json:"messageId"`
}
// DeleteConversationTurn 删除锚点消息所在轮次(从该轮 user 到下一轮 user 之前),并清空 last_react_*。
func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
conversationID := c.Param("id")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversation id required"})
return
}
var req DeleteTurnRequest
if err := c.ShouldBindJSON(&req); err != nil || req.MessageID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "messageId required"})
return
}
if _, err := h.db.GetConversation(conversationID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
deletedIDs, err := h.db.DeleteConversationTurn(conversationID, req.MessageID)
if err != nil {
h.logger.Warn("删除对话轮次失败",
zap.String("conversationId", conversationID),
zap.String("messageId", req.MessageID),
zap.Error(err),
)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
"message_id": req.MessageID,
"deleted": len(deletedIDs),
})
}
c.JSON(http.StatusOK, gin.H{
"deletedMessageIds": deletedIDs,
"message": "ok",
})
}
+180
View File
@@ -0,0 +1,180 @@
package handler
import (
"context"
"errors"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/multiagent"
"go.uber.org/zap"
)
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
if h.config != nil {
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
}
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
}
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
}
return 0
}
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
func (h *AgentHandler) applyEinoTraceResumeSegment(
conversationID string,
result *multiagent.RunResult,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
) {
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
*curHistory = hist
}
if segmentUserMessage != "" {
*curFinalMessage = segmentUserMessage
}
}
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
func (h *AgentHandler) applyEinoTransientRetrySegment(
conversationID string,
result *multiagent.RunResult,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
) {
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
*curHistory = hist
}
if s := strings.TrimSpace(segmentUserMessage); s != "" {
*curFinalMessage = segmentUserMessage
}
}
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
func (h *AgentHandler) handleEinoTransientRetryContinue(
baseCtx context.Context,
conversationID string,
result *multiagent.RunResult,
runErr error,
transientAttempts *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
progressCallback func(eventType, message string, data interface{}),
sendProgress func(msg string, extra map[string]interface{}),
) (handled bool, fatal error) {
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
return false, nil
}
maxAttempts := h.einoRunRetryMaxAttempts()
*transientAttempts++
if *transientAttempts > maxAttempts {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
return false, errors.New("transient retry exhausted: " + runErr.Error())
}
attemptNo := *transientAttempts
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
if progressCallback != nil {
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
"maxAttempts": maxAttempts,
"backoffSec": int(backoff.Seconds()),
})
}
select {
case <-baseCtx.Done():
return false, context.Cause(baseCtx)
case <-time.After(backoff):
}
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
if progressCallback != nil {
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
})
}
if sendProgress != nil {
sendProgress("正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "transient_retry",
})
}
return true, nil
}
// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。
// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。
func (h *AgentHandler) handleEinoEmptyResponseContinue(
baseCtx context.Context,
conversationID string,
result *multiagent.RunResult,
runErr error,
emptyResponseAttempts *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
progressCallback func(eventType, message string, data interface{}),
sendProgress func(msg string, extra map[string]interface{}),
) (handled bool, exhausted bool) {
if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) {
return false, false
}
maxAttempts := h.einoRunRetryMaxAttempts()
*emptyResponseAttempts++
if *emptyResponseAttempts > maxAttempts {
if h.logger != nil {
h.logger.Warn("eino empty response auto resume exhausted",
zap.String("conversationId", conversationID),
zap.Int("maxAttempts", maxAttempts))
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
return false, true
}
attemptNo := *emptyResponseAttempts
if h.logger != nil {
h.logger.Info("eino empty response, auto resume from trace",
zap.String("conversationId", conversationID),
zap.Int("attempt", attemptNo),
zap.Int("maxAttempts", maxAttempts))
}
if progressCallback != nil {
progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
"maxAttempts": maxAttempts,
"resumeKind": "trace_segment",
})
}
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
if sendProgress != nil {
sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{
"conversationId": conversationID,
"source": "empty_response_continue",
})
}
return true, false
}
+511
View File
@@ -0,0 +1,511 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/multiagent"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// EinoSingleAgentLoopStream Eino ADK 单代理(ChatModelAgent + Runner)流式对话;不依赖 multi_agent.enabled。
func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
c.Header("Content-Type", "text/event-stream; charset=utf-8")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
ev := StreamEvent{Type: "error", Message: "请求参数错误: " + err.Error()}
b, _ := json.Marshal(ev)
fmt.Fprintf(c.Writer, "data: %s\n\n", b)
done := StreamEvent{Type: "done", Message: ""}
db, _ := json.Marshal(done)
fmt.Fprintf(c.Writer, "data: %s\n\n", db)
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
return
}
c.Header("X-Accel-Buffering", "no")
var baseCtx context.Context
clientDisconnected := false
var sseWriteMu sync.Mutex
var ssePublishConversationID string
sendEvent := func(eventType, message string, data interface{}) {
if eventType == "error" && baseCtx != nil {
cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
return
}
}
ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, errMarshal := json.Marshal(ev)
if errMarshal != nil {
b = []byte(`{"type":"error","message":"marshal failed"}`)
}
sseLine := make([]byte, 0, len(b)+8)
sseLine = append(sseLine, []byte("data: ")...)
sseLine = append(sseLine, b...)
sseLine = append(sseLine, '\n', '\n')
if ssePublishConversationID != "" && h.taskEventBus != nil {
h.taskEventBus.Publish(ssePublishConversationID, sseLine)
}
if clientDisconnected {
return
}
select {
case <-c.Request.Context().Done():
clientDisconnected = true
return
default:
}
sseWriteMu.Lock()
_, err := c.Writer.Write(sseLine)
if err != nil {
sseWriteMu.Unlock()
clientDisconnected = true
return
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
c.Writer.Flush()
}
sseWriteMu.Unlock()
}
h.logger.Info("收到 Eino ADK 单代理流式请求",
zap.String("conversationId", req.ConversationID),
)
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
if err != nil {
sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil)
return
}
ssePublishConversationID = prep.ConversationID
if prep.CreatedNew {
sendEvent("conversation", "会话已创建", map[string]interface{}{
"conversationId": prep.ConversationID,
})
}
conversationID := prep.ConversationID
assistantMessageID := prep.AssistantMessageID
h.activateHITLForConversation(conversationID, req.Hitl)
if h.hitlManager != nil {
defer h.hitlManager.DeactivateConversation(conversationID)
}
if prep.UserMessageID != "" {
sendEvent("message_saved", "", map[string]interface{}{
"conversationId": conversationID,
"userMessageId": prep.UserMessageID,
})
}
var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History
roleTools := prep.RoleTools
taskStatus := "completed"
// 仅在成功 StartTask 后再 FinishTask。若 StartTask 因 ErrTaskAlreadyRunning 失败仍 defer FinishTask
// 会误删其他连接上正在运行的同会话任务,导致「第一次拦截、第二次却放行」。
taskOwned := false
defer func() {
if taskOwned {
h.tasks.FinishTask(conversationID, taskStatus)
}
}()
sendEvent("progress", "正在启动 Eino ADK 单代理(ChatModelAgent...", map[string]interface{}{
"conversationId": conversationID,
})
stopKeepalive := make(chan struct{})
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
defer close(stopKeepalive)
if h.config == nil {
taskStatus = "failed"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
sendEvent("error", "服务器配置未加载", nil)
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
return
}
var result *multiagent.RunResult
var runErr error
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
var errorMsg string
if errors.Is(err, ErrTaskAlreadyRunning) {
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
sendEvent("error", errorMsg, map[string]interface{}{
"conversationId": conversationID,
"errorType": "task_already_running",
})
} else {
errorMsg = "❌ 无法启动任务: " + err.Error()
sendEvent("error", errorMsg, nil)
}
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
}
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
taskOwned = true
var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
var emptyResponseAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int
for {
segmentMainIterationMax := 0
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
progressCallback := func(eventType, message string, data interface{}) {
if eventType == "iteration" {
if m, ok := data.(map[string]interface{}); ok {
if scope, _ := m["einoScope"].(string); scope == "main" {
raw := 0
switch v := m["iteration"].(type) {
case int:
raw = v
case int32:
raw = int(v)
case int64:
raw = int(v)
case float64:
raw = int(v)
case float32:
raw = int(v)
}
if raw > 0 {
if raw > segmentMainIterationMax {
segmentMainIterationMax = raw
}
m["iteration"] = raw + mainIterationOffset
}
}
}
}
rawProgressCallback(eventType, message, data)
}
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
})
result, runErr = multiagent.RunEinoSingleChatModelAgent(
taskCtxLoop,
h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
conversationID,
h.conversationProjectID(conversationID),
curFinalMessage,
curHistory,
roleTools,
progressCallback,
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(conversationID),
)
if result != nil && len(result.MCPExecutionIDs) > 0 {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
}
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if exhaustedEmpty {
runErr = nil
transientRunAttempts = 0
timeoutCancel()
break
}
if handledEmpty {
mainIterationOffset += segmentMainIterationMax
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
emptyResponseAttempts = 0
timeoutCancel()
break
}
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
note := h.tasks.TakeInterruptContinueNote(conversationID)
icSummary := interruptContinueTimelineSummary(note)
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
"conversationId": conversationID,
"rawReason": strings.TrimSpace(note),
"emptyReason": strings.TrimSpace(note) == "",
"kind": "no_active_mcp_tool",
})
inject := formatInterruptContinueUserMessage(note)
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
curHistory = hist
}
curFinalMessage = inject
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
"conversationId": conversationID,
"source": "interrupt_continue",
})
mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
cancelMsg := "任务已被用户取消,后续操作已停止。"
if assistantMessageID != "" {
if result != nil {
if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil {
h.logger.Warn("合并取消前的部分回复失败", zap.Error(err))
}
}
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
}
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
}
sendEvent("cancelled", cancelMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) {
taskStatus = "timeout"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
timeoutMsg := "任务执行超时,已自动终止。"
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
}
sendEvent("error", timeoutMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
"errorType": "timeout",
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
h.logger.Error("Eino ADK 单代理执行失败", zap.Error(runErr))
taskStatus = "failed"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
errMsg := "执行失败: " + runErr.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
}
sendEvent("error", errMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
timeoutCancel()
if assistantMessageID != "" {
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
}
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
sendEvent("response", result.Response, map[string]interface{}{
"mcpExecutionIds": cumulativeMCPExecutionIDs,
"conversationId": conversationID,
"messageId": assistantMessageID,
"agentMode": "eino_single",
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
}
// EinoSingleAgentLoop Eino ADK 单代理非流式对话。
func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.activateHITLForConversation(prep.ConversationID, req.Hitl)
if h.hitlManager != nil {
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
}
var progressBuf strings.Builder
progressCallbackRaw := func(eventType, message string, data interface{}) {
progressBuf.WriteString(eventType)
progressBuf.WriteByte('\n')
}
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
defer cancelWithCause(nil)
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
defer timeoutCancel()
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, progressCallbackRaw)
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
})
if h.config == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器配置未加载"})
return
}
curHist := prep.History
curMsg := prep.FinalMessage
var result *multiagent.RunResult
var runErr error
var transientRunAttempts int
var emptyResponseAttempts int
for {
result, runErr = multiagent.RunEinoSingleChatModelAgent(
taskCtx,
h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
prep.ConversationID,
h.conversationProjectID(prep.ConversationID),
curMsg,
curHist,
prep.RoleTools,
progressCallback,
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
)
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
)
if exhaustedEmpty {
runErr = nil
break
}
if handledEmpty {
continue
}
if runErr == nil {
break
}
if handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
); handled {
continue
} else if fatalErr != nil {
runErr = fatalErr
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
}
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
return
}
if prep.AssistantMessageID != "" {
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
}
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
}
c.JSON(http.StatusOK, gin.H{
"response": result.Response,
"conversationId": prep.ConversationID,
"mcpExecutionIds": result.MCPExecutionIDs,
"assistantMessageId": prep.AssistantMessageID,
"agentMode": "eino_single",
})
}
+485
View File
@@ -0,0 +1,485 @@
package handler
import (
"fmt"
"net/http"
"os"
"sync"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// ExternalMCPHandler 外部MCP处理器
type ExternalMCPHandler struct {
manager *mcp.ExternalMCPManager
config *config.Config
configPath string
logger *zap.Logger
audit *audit.Service
mu sync.RWMutex
}
// SetAudit wires platform audit logging.
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewExternalMCPHandler 创建外部MCP处理器
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
return &ExternalMCPHandler{
manager: manager,
config: cfg,
configPath: configPath,
logger: logger,
}
}
// GetExternalMCPs 获取所有外部MCP配置
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
// 获取所有外部MCP的工具数量
toolCounts := h.manager.GetToolCounts()
// 转换为响应格式
result := make(map[string]ExternalMCPResponse)
for name, cfg := range configs {
client, exists := h.manager.GetClient(name)
status := "disconnected"
if exists {
status = client.GetStatus()
} else if h.isEnabled(cfg) {
status = "disconnected"
} else {
status = "disabled"
}
toolCount := toolCounts[name]
errorMsg := externalMCPStatusError(h.manager, name, status)
result[name] = ExternalMCPResponse{
Config: cfg,
Status: status,
ToolCount: toolCount,
Error: errorMsg,
}
}
c.JSON(http.StatusOK, gin.H{
"servers": result,
"stats": h.manager.GetStats(),
})
}
// GetExternalMCP 获取单个外部MCP配置
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
cfg, exists := configs[name]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
return
}
client, clientExists := h.manager.GetClient(name)
status := "disconnected"
if clientExists {
status = client.GetStatus()
} else if h.isEnabled(cfg) {
status = "disconnected"
} else {
status = "disabled"
}
// 获取工具数量
toolCount := 0
if clientExists && client.IsConnected() {
if count, err := h.manager.GetToolCount(name); err == nil {
toolCount = count
}
}
c.JSON(http.StatusOK, ExternalMCPResponse{
Config: cfg,
Status: status,
ToolCount: toolCount,
Error: externalMCPStatusError(h.manager, name, status),
})
}
// externalMCPStatusError 在 error/disconnected 状态下返回最近错误(含断连原因)。
func externalMCPStatusError(manager *mcp.ExternalMCPManager, name, status string) string {
if status != "error" && status != "disconnected" {
return ""
}
return manager.GetError(name)
}
// AddOrUpdateExternalMCP 添加或更新外部MCP配置
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
var req AddOrUpdateExternalMCPRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
name := c.Param("name")
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
return
}
// 验证配置
if err := h.validateConfig(req.Config); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.mu.Lock()
defer h.mu.Unlock()
// 添加或更新配置
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
return
}
// 更新内存中的配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
cfg := req.Config
// 官方 disabled 字段 → ExternalMCPEnable 取反
if cfg.Disabled {
cfg.ExternalMCPEnable = false
} else if !cfg.ExternalMCPEnable {
// 用户未显式设置 external_mcp_enable,官方配置默认就是启用的
cfg.ExternalMCPEnable = true
}
// 展开 ${VAR} 环境变量
config.ExpandConfigEnv(&cfg)
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "external_mcp",
Action: "upsert",
Result: "success",
ResourceType: "external_mcp",
ResourceID: name,
Message: "更新外部 MCP 配置",
})
}
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
}
// DeleteExternalMCP 删除外部MCP配置
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 移除配置
if err := h.manager.RemoveConfig(name); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
return
}
// 从内存配置中删除
if h.config.ExternalMCP.Servers != nil {
delete(h.config.ExternalMCP.Servers, name)
}
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "external_mcp",
Action: "delete",
Result: "success",
ResourceType: "external_mcp",
ResourceID: name,
Message: "删除外部 MCP 配置",
})
}
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
}
// StartExternalMCP 启动外部MCP
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 更新配置为启用
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = true
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
// 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行)
h.logger.Info("开始启动外部MCP", zap.String("name", name))
if err := h.manager.StartClient(name); err != nil {
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
"status": "error",
})
return
}
// 获取客户端状态(应该是connecting)
client, exists := h.manager.GetClient(name)
status := "connecting"
if exists {
status = client.GetStatus()
}
// 立即返回,不等待连接完成
// 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态
c.JSON(http.StatusOK, gin.H{
"message": "外部MCP启动请求已提交,正在后台连接中",
"status": status,
})
}
// StopExternalMCP 停止外部MCP
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 停止客户端
if err := h.manager.StopClient(name); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 更新配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = false
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP已停止", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
}
// GetExternalMCPStats 获取统计信息
func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) {
stats := h.manager.GetStats()
c.JSON(http.StatusOK, stats)
}
// validateConfig 验证配置(同时支持官方 type 字段和旧版 transport 字段)
func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error {
transport := cfg.GetTransportType()
if transport == "" {
return fmt.Errorf("需要指定 commandstdio模式)或 url + typehttp/sse模式)")
}
switch transport {
case "http":
if cfg.URL == "" {
return fmt.Errorf("HTTP模式需要 url")
}
case "stdio":
if cfg.Command == "" {
return fmt.Errorf("stdio模式需要 command")
}
case "sse":
if cfg.URL == "" {
return fmt.Errorf("SSE模式需要 url")
}
default:
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
}
return nil
}
// isEnabled 检查是否启用
func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool {
return cfg.ExternalMCPEnable
}
// saveConfig 保存配置到文件
func (h *ExternalMCPHandler) saveConfig() error {
data, err := os.ReadFile(h.configPath)
if err != nil {
return fmt.Errorf("读取配置文件失败: %w", err)
}
if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil {
h.logger.Warn("创建配置备份失败", zap.Error(err))
}
root, err := loadYAMLDocument(h.configPath)
if err != nil {
return fmt.Errorf("解析配置文件失败: %w", err)
}
updateExternalMCPConfig(root, h.config.ExternalMCP)
if err := writeYAMLDocument(h.configPath, root); err != nil {
return fmt.Errorf("保存配置文件失败: %w", err)
}
h.logger.Info("配置已保存", zap.String("path", h.configPath))
return nil
}
// updateExternalMCPConfig 更新外部MCP配置
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig) {
root := doc.Content[0]
externalMCPNode := ensureMap(root, "external_mcp")
serversNode := ensureMap(externalMCPNode, "servers")
// 清空现有服务器配置
serversNode.Content = nil
// 添加新的服务器配置
for name, serverCfg := range cfg.Servers {
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
// type(官方 MCP 传输类型)
effectiveType := serverCfg.GetTransportType()
if effectiveType != "" && effectiveType != "stdio" {
// stdio 可省略(有 command 时自动推断)
setStringInMap(serverNode, "type", effectiveType)
}
if serverCfg.Command != "" {
setStringInMap(serverNode, "command", serverCfg.Command)
}
if len(serverCfg.Args) > 0 {
setStringArrayInMap(serverNode, "args", serverCfg.Args)
}
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
envNode := ensureMap(serverNode, "env")
for envKey, envValue := range serverCfg.Env {
setStringInMap(envNode, envKey, envValue)
}
}
if serverCfg.URL != "" {
setStringInMap(serverNode, "url", serverCfg.URL)
}
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
headersNode := ensureMap(serverNode, "headers")
for k, v := range serverCfg.Headers {
setStringInMap(headersNode, k, v)
}
}
if serverCfg.Description != "" {
setStringInMap(serverNode, "description", serverCfg.Description)
}
if serverCfg.Timeout > 0 {
setIntInMap(serverNode, "timeout", serverCfg.Timeout)
}
// 官方标准字段
if serverCfg.Disabled {
setBoolInMap(serverNode, "disabled", true)
}
if len(serverCfg.AutoApprove) > 0 {
setStringArrayInMap(serverNode, "autoApprove", serverCfg.AutoApprove)
}
// SDK 高级配置
if serverCfg.MaxRetries > 0 {
setIntInMap(serverNode, "max_retries", serverCfg.MaxRetries)
}
if serverCfg.TerminateDuration > 0 {
setIntInMap(serverNode, "terminate_duration", serverCfg.TerminateDuration)
}
if serverCfg.KeepAlive > 0 {
setIntInMap(serverNode, "keep_alive", serverCfg.KeepAlive)
}
setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable)
if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 {
toolEnabledNode := ensureMap(serverNode, "tool_enabled")
for toolName, enabled := range serverCfg.ToolEnabled {
setBoolInMap(toolEnabledNode, toolName, enabled)
}
}
}
}
// setStringArrayInMap 设置字符串数组
func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) {
_, valueNode := ensureKeyValue(mapNode, key)
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
for _, v := range values {
itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v}
valueNode.Content = append(valueNode.Content, itemNode)
}
}
// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求
type AddOrUpdateExternalMCPRequest struct {
Config config.ExternalMCPServerConfig `json:"config"`
}
// ExternalMCPResponse 外部MCP响应
type ExternalMCPResponse struct {
Config config.ExternalMCPServerConfig `json:"config"`
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
ToolCount int `json:"tool_count"` // 工具数量
Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在)
}
+518
View File
@@ -0,0 +1,518 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建临时配置文件
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
if err != nil {
panic(err)
}
tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n")
tmpFile.Close()
configPath := tmpFile.Name()
logger := zap.NewNop()
manager := mcp.NewExternalMCPManager(logger)
cfg := &config.Config{
ExternalMCP: config.ExternalMCPConfig{
Servers: make(map[string]config.ExternalMCPServerConfig),
},
}
handler := NewExternalMCPHandler(manager, cfg, configPath, logger)
api := router.Group("/api")
api.GET("/external-mcp", handler.GetExternalMCPs)
api.GET("/external-mcp/stats", handler.GetExternalMCPStats)
api.GET("/external-mcp/:name", handler.GetExternalMCP)
api.PUT("/external-mcp/:name", handler.AddOrUpdateExternalMCP)
api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP)
api.POST("/external-mcp/:name/start", handler.StartExternalMCP)
api.POST("/external-mcp/:name/stop", handler.StopExternalMCP)
return router, handler, configPath
}
func cleanupTestConfig(configPath string) {
os.Remove(configPath)
os.Remove(configPath + ".backup")
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加stdio模式的配置(官方格式:有 command 时 type 可省略)
configJSON := `{
"command": "python3",
"args": ["/path/to/script.py", "--server", "http://example.com"],
"description": "Test stdio MCP",
"timeout": 300,
"external_mcp_enable": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Command != "python3" {
t.Errorf("期望command为python3,实际%s", response.Config.Command)
}
if len(response.Config.Args) != 3 {
t.Errorf("期望args长度为3,实际%d", len(response.Config.Args))
}
if response.Config.Description != "Test stdio MCP" {
t.Errorf("期望description为'Test stdio MCP',实际%s", response.Config.Description)
}
if response.Config.Timeout != 300 {
t.Errorf("期望timeout为300,实际%d", response.Config.Timeout)
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加HTTP模式的配置(使用官方 type 字段)
configJSON := `{
"type": "http",
"url": "http://127.0.0.1:8081/mcp",
"external_mcp_enable": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Type != "http" {
t.Errorf("期望type为http,实际%s", response.Config.Type)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
testCases := []struct {
name string
configJSON string
expectedErr string
}{
{
name: "缺少command和url",
configJSON: `{"external_mcp_enable": true}`,
expectedErr: "需要指定 commandstdio模式)或 url + typehttp/sse模式)",
},
{
name: "stdio模式缺少command",
configJSON: `{"args": ["test"], "external_mcp_enable": true}`,
expectedErr: "stdio模式需要command",
},
{
name: "http模式缺少url",
configJSON: `{"type": "http", "external_mcp_enable": true}`,
expectedErr: "HTTP模式需要 url",
},
{
name: "无效的type",
configJSON: `{"type": "invalid", "external_mcp_enable": true}`,
expectedErr: "不支持的传输模式",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
errorMsg := response["error"].(string)
// 对于stdio模式缺少command的情况,错误信息可能略有不同
if tc.name == "stdio模式缺少command" {
if !strings.Contains(errorMsg, "stdio") && !strings.Contains(errorMsg, "command") {
t.Errorf("期望错误信息包含'stdio'或'command',实际'%s'", errorMsg)
}
} else if !strings.Contains(errorMsg, tc.expectedErr) {
t.Errorf("期望错误信息包含'%s',实际'%s'", tc.expectedErr, errorMsg)
}
})
}
}
func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加一个配置
configObj := config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: true,
}
handler.manager.AddOrUpdateConfig("test-delete", configObj)
// 删除配置
req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已删除
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusNotFound {
t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String())
}
}
func TestExternalMCPStatusError(t *testing.T) {
manager := mcp.NewExternalMCPManager(zap.NewNop())
if got := externalMCPStatusError(manager, "x", "connected"); got != "" {
t.Fatalf("connected status should not return error, got %q", got)
}
if got := externalMCPStatusError(manager, "x", "connecting"); got != "" {
t.Fatalf("connecting status should not return error, got %q", got)
}
}
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: false,
})
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
servers := response["servers"].(map[string]interface{})
if len(servers) != 2 {
t.Errorf("期望2个服务器,实际%d", len(servers))
}
if _, ok := servers["test1"]; !ok {
t.Error("期望包含test1")
}
if _, ok := servers["test2"]; !ok {
t.Error("期望包含test2")
}
stats := response["stats"].(map[string]interface{})
if int(stats["total"].(float64)) != 2 {
t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64)))
}
}
func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
})
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if int(stats["total"].(float64)) != 3 {
t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64)))
}
if int(stats["enabled"].(float64)) != 2 {
t.Errorf("期望启用数为2,实际%d", int(stats["enabled"].(float64)))
}
if int(stats["disabled"].(float64)) != 1 {
t.Errorf("期望停用数为1,实际%d", int(stats["disabled"].(float64)))
}
}
func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3",
})
// 测试启动(可能会失败,因为没有真实的服务器)
req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 启动可能会失败,但应该返回合理的状态码
if w.Code != http.StatusOK {
// 如果启动失败,应该是400或500
if w.Code != http.StatusBadRequest && w.Code != http.StatusInternalServerError {
t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String())
}
}
// 测试停止
req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
}
func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
router, _, _ := setupTestRouter()
req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 删除不存在的配置可能返回200(幂等操作)或404,都是合理的
if w.Code != http.StatusNotFound && w.Code != http.StatusOK {
t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 空名称应该返回404或400
if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest {
t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
router, _, _ := setupTestRouter()
// 发送无效的JSON
body := []byte(`{"config": invalid json}`)
req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加配置
config1 := config.ExternalMCPServerConfig{
Command: "python3",
ExternalMCPEnable: true,
}
handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置
config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: config2,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已更新
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
if response.Config.Command != "" {
t.Errorf("期望command为空,实际%s", response.Config.Command)
}
}
+467
View File
@@ -0,0 +1,467 @@
package handler
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"
"cyberstrike-ai/internal/config"
openaiClient "cyberstrike-ai/internal/openai"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type FofaHandler struct {
cfg *config.Config
logger *zap.Logger
client *http.Client
openAIClient *openaiClient.Client
}
func NewFofaHandler(cfg *config.Config, logger *zap.Logger) *FofaHandler {
// LLM 请求通常比 FOFA 查询更慢一点,单独给一个更宽松的超时。
llmHTTPClient := &http.Client{Timeout: 2 * time.Minute}
var llmCfg *config.OpenAIConfig
if cfg != nil {
llmCfg = &cfg.OpenAI
}
return &FofaHandler{
cfg: cfg,
logger: logger,
client: &http.Client{Timeout: 30 * time.Second},
openAIClient: openaiClient.NewClient(llmCfg, llmHTTPClient, logger),
}
}
type fofaSearchRequest struct {
Query string `json:"query" binding:"required"`
Size int `json:"size,omitempty"`
Page int `json:"page,omitempty"`
Fields string `json:"fields,omitempty"`
Full bool `json:"full,omitempty"`
}
type fofaParseRequest struct {
Text string `json:"text" binding:"required"`
}
type fofaParseResponse struct {
Query string `json:"query"`
Explanation string `json:"explanation,omitempty"`
Warnings []string `json:"warnings,omitempty"`
}
type fofaAPIResponse struct {
Error bool `json:"error"`
ErrMsg string `json:"errmsg"`
Size int `json:"size"`
Page int `json:"page"`
Total int `json:"total"`
Mode string `json:"mode"`
Query string `json:"query"`
Results [][]interface{} `json:"results"`
}
type fofaSearchResponse struct {
Query string `json:"query"`
Size int `json:"size"`
Page int `json:"page"`
Total int `json:"total"`
Fields []string `json:"fields"`
ResultsCount int `json:"results_count"`
Results []map[string]interface{} `json:"results"`
}
func (h *FofaHandler) resolveCredentials() (email, apiKey string) {
// 优先环境变量(便于容器部署),其次配置文件
email = strings.TrimSpace(os.Getenv("FOFA_EMAIL"))
apiKey = strings.TrimSpace(os.Getenv("FOFA_API_KEY"))
if email != "" && apiKey != "" {
return email, apiKey
}
if h.cfg != nil {
if email == "" {
email = strings.TrimSpace(h.cfg.FOFA.Email)
}
if apiKey == "" {
apiKey = strings.TrimSpace(h.cfg.FOFA.APIKey)
}
}
return email, apiKey
}
func (h *FofaHandler) resolveBaseURL() string {
if h.cfg != nil {
if v := strings.TrimSpace(h.cfg.FOFA.BaseURL); v != "" {
return v
}
}
return "https://fofa.info/api/v1/search/all"
}
// ParseNaturalLanguage 将自然语言解析为 FOFA 查询语法(仅生成,不执行查询)
func (h *FofaHandler) ParseNaturalLanguage(c *gin.Context) {
var req fofaParseRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
req.Text = strings.TrimSpace(req.Text)
if req.Text == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "text 不能为空"})
return
}
if h.cfg == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "系统配置未初始化"})
return
}
if strings.TrimSpace(h.cfg.OpenAI.APIKey) == "" || strings.TrimSpace(h.cfg.OpenAI.Model) == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "未配置 AI 模型:请在系统设置中填写 openai.api_key 与 openai.model(支持 OpenAI 兼容 API,如 DeepSeek",
"need": []string{"openai.api_key", "openai.model"},
})
return
}
if h.openAIClient == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "AI 客户端未初始化"})
return
}
systemPrompt := strings.TrimSpace(`
你是“FOFA 查询语法生成器”。任务:把用户输入的自然语言搜索意图,转换成 FOFA 查询语法。
输出要求(非常重要):
1) 只输出 JSON(不要 markdown、不要代码块、不要额外解释文本)
2) JSON 结构必须是:
{
"query": "stringFOFA查询语法(可直接粘贴到 FOFA 或本系统查询框)",
"explanation": "string,可选,解释你如何映射字段/逻辑",
"warnings": ["string"...] 可选,列出歧义/风险/需要人工确认的点
}
3) 如果用户输入本身已经是 FOFA 查询语法(或非常接近 FOFA 语法的表达式),应当“原样返回”为 query:
- 不要擅自改写字段名、操作符、括号结构
- 不要改写任何字符串值(尤其是地理位置类值),不要做缩写/同义词替换/翻译/音译
查询语法要点(来自 FOFA 语法参考):
- 逻辑连接符:&&(与)、||(或),必要时用 () 包住子表达式以确认优先级(括号优先级最高)
- 当同一层级同时出现 && 与 ||(混用)时,用 () 明确优先级(避免歧义)
- 比较/匹配:
- = 匹配;当字段="" 时,可查询“不存在该字段”或“值为空”的情况
- == 完全匹配;当字段=="" 时,可查询“字段存在且值为空”的情况
- != 不匹配;当字段!="" 时,可查询“值不为空”的情况
- *= 模糊匹配;可使用 * 或 ? 进行搜索
- 直接输入关键词(不带字段)会在标题、HTML内容、HTTP头、URL字段中搜索;但当意图明确时优先用字段表达(更可控、更准确)
字段示例速查(来自用户提供的案例,可直接套用/拼接):
- 高级搜索操作符示例:
- title="beijing" (= 匹配)
- title=="" (== 完全匹配,字段存在且值为空)
- title="" (= 匹配,可能表示字段不存在或值为空)
- title!="" (!= 不匹配,可用于值不为空)
- title*="*Home*" (*= 模糊匹配,用 * 或 ?)
- (app="Apache" || app="Nginx") && country="CN" (混用 && / || 时用括号)
- 基础类(General):
- ip="1.1.1.1"
- ip="220.181.111.1/24"
- ip="2600:9000:202a:2600:18:4ab7:f600:93a1"
- port="6379"
- domain="qq.com"
- host=".fofa.info"
- os="centos"
- server="Microsoft-IIS/10"
- asn="19551"
- org="LLC Baxet"
- is_domain=true / is_domain=false
- is_ipv6=true / is_ipv6=false
- 标记类(Special Label):
- app="Microsoft-Exchange"
- fid="sSXXGNUO2FefBTcCLIT/2Q=="
- product="NGINX"
- product="Roundcube-Webmail" && product.version="1.6.10"
- category="服务"
- type="service" / type="subdomain"
- cloud_name="Aliyundun"
- is_cloud=true / is_cloud=false
- is_fraud=true / is_fraud=false
- is_honeypot=true / is_honeypot=false
- 协议类(type=service):
- protocol="quic"
- banner="users"
- banner_hash="7330105010150477363"
- banner_fid="zRpqmn0FXQRjZpH8MjMX55zpMy9SgsW8"
- base_protocol="udp" / base_protocol="tcp"
- 网站类(type=subdomain):
- title="beijing"
- header="elastic"
- header_hash="1258854265"
- body="网络空间测绘"
- body_hash="-2090962452"
- js_name="js/jquery.js"
- js_md5="82ac3f14327a8b7ba49baa208d4eaa15"
- cname="customers.spektrix.com"
- cname_domain="siteforce.com"
- icon_hash="-247388890"
- status_code="402"
- icp="京ICP证030173号"
- sdk_hash="Are3qNnP2Eqn7q5kAoUO3l+w3mgVIytO"
- 地理位置(Location):
- country="CN" 或 country="中国"
- region="Zhejiang" 或 region="浙江"(仅支持中国地区中文)
- city="Hangzhou"
- 证书类(Certificate):
- cert="baidu"
- cert.subject="Oracle Corporation"
- cert.issuer="DigiCert"
- cert.subject.org="Oracle Corporation"
- cert.subject.cn="baidu.com"
- cert.issuer.org="cPanel, Inc."
- cert.issuer.cn="Synology Inc. CA"
- cert.domain="huawei.com"
- cert.is_equal=true / cert.is_equal=false
- cert.is_valid=true / cert.is_valid=false
- cert.is_match=true / cert.is_match=false
- cert.is_expired=true / cert.is_expired=false
- jarm="2ad2ad0002ad2ad22c2ad2ad2ad2ad2eac92ec34bcc0cf7520e97547f83e81"
- tls.version="TLS 1.3"
- tls.ja3s="15af977ce25de452b96affa2addb1036"
- cert.sn="356078156165546797850343536942784588840297"
- cert.not_after.after="2025-03-01" / cert.not_after.before="2025-03-01"
- cert.not_before.after="2025-03-01" / cert.not_before.before="2025-03-01"
- 时间类(Last update time):
- after="2023-01-01"
- before="2023-12-01"
- after="2023-01-01" && before="2023-12-01"
- 独立IP语法(需配合 ip_filter / ip_exclude):
- ip_filter(banner="SSH-2.0-OpenSSH_6.7p2") && ip_filter(icon_hash="-1057022626")
- ip_filter(banner="SSH-2.0-OpenSSH_6.7p2" && asn="3462") && ip_exclude(title="EdgeOS")
- port_size="6" / port_size_gt="6" / port_size_lt="12"
- ip_ports="80,161"
- ip_country="CN"
- ip_region="Zhejiang"
- ip_city="Hangzhou"
- ip_after="2021-03-18"
- ip_before="2019-09-09"
生成约束与注意事项:
- 字符串值一律用英文双引号包裹,例如 title="登录"、country="CN"
- 字符串值保持字面一致:不要缩写(例如 city="beijing" 不要变成 city="BJ"),不要用别名(例如 Beijing/Peking),不要擅自翻译/音译/改写大小写
- 地理位置字段(country/region/city)更倾向于“按用户给定值输出”;不确定合法取值时,不要猜测,把备选写进 warnings
- 不要捏造不存在的 FOFA 字段;不确定时把不确定点写进 warnings,并输出一个保守的 query
- 当用户描述里有“多个与/或条件”,优先加 () 明确优先级,例如:(app="Apache" || app="Nginx") && country="CN"
- 当用户缺少关键条件导致范围过大或歧义(如地点/协议/端口/服务类型未说明),允许 query 为空字符串,并在 warnings 里明确需要补充的信息
`)
userPrompt := fmt.Sprintf("自然语言意图:%s", req.Text)
requestBody := map[string]interface{}{
"model": h.cfg.OpenAI.Model,
"messages": []map[string]interface{}{
{"role": "system", "content": systemPrompt},
{"role": "user", "content": userPrompt},
},
"temperature": 0.1,
"max_completion_tokens": 12000,
}
// OpenAI 返回结构:只需要 choices[0].message.content
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
ctx, cancel := context.WithTimeout(c.Request.Context(), 90*time.Second)
defer cancel()
if err := h.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
var apiErr *openaiClient.APIError
if errors.As(err, &apiErr) {
h.logger.Warn("FOFA自然语言解析:LLM返回错误", zap.Int("status", apiErr.StatusCode))
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败(上游返回非 200),请检查模型配置或稍后重试"})
return
}
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败: " + err.Error()})
return
}
if len(apiResponse.Choices) == 0 {
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 未返回有效结果"})
return
}
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
// 兼容模型偶尔返回 ```json ... ``` 的情况
content = strings.TrimPrefix(content, "```json")
content = strings.TrimPrefix(content, "```")
content = strings.TrimSuffix(content, "```")
content = strings.TrimSpace(content)
var parsed fofaParseResponse
if err := json.Unmarshal([]byte(content), &parsed); err != nil {
// 直接回传一部分原文,方便排查,但避免太大
snippet := content
if len(snippet) > 1200 {
snippet = snippet[:1200]
}
c.JSON(http.StatusBadGateway, gin.H{
"error": "AI 返回内容无法解析为 JSON,请稍后重试或换个描述方式",
"snippet": snippet,
})
return
}
parsed.Query = strings.TrimSpace(parsed.Query)
if parsed.Query == "" {
// query 允许为空(表示需求不明确),但前端需要明确提示
if len(parsed.Warnings) == 0 {
parsed.Warnings = []string{"需求信息不足,未能生成可用的 FOFA 查询语法,请补充关键条件(如国家/端口/产品/域名等)。"}
}
}
c.JSON(http.StatusOK, parsed)
}
// Search FOFA 查询(后端代理,避免前端暴露 key)
func (h *FofaHandler) Search(c *gin.Context) {
var req fofaSearchRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
req.Query = strings.TrimSpace(req.Query)
if req.Query == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "query 不能为空"})
return
}
if req.Size <= 0 {
req.Size = 100
}
if req.Page <= 0 {
req.Page = 1
}
// FOFA 接口 size 上限和账户权限相关,这里只做一个合理的保护
if req.Size > 10000 {
req.Size = 10000
}
if req.Fields == "" {
req.Fields = "host,ip,port,domain,title,protocol,country,province,city,server"
}
email, apiKey := h.resolveCredentials()
if email == "" || apiKey == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "FOFA 未配置:请在系统设置中填写 FOFA Email/API Key,或设置环境变量 FOFA_EMAIL/FOFA_API_KEY",
"need": []string{"fofa.email", "fofa.api_key"},
"env_key": []string{"FOFA_EMAIL", "FOFA_API_KEY"},
})
return
}
baseURL := h.resolveBaseURL()
qb64 := base64.StdEncoding.EncodeToString([]byte(req.Query))
u, err := url.Parse(baseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "FOFA base_url 无效: " + err.Error()})
return
}
params := u.Query()
params.Set("email", email)
params.Set("key", apiKey)
params.Set("qbase64", qb64)
params.Set("size", fmt.Sprintf("%d", req.Size))
params.Set("page", fmt.Sprintf("%d", req.Page))
params.Set("fields", strings.TrimSpace(req.Fields))
if req.Full {
params.Set("full", "true")
} else {
// 明确传 false,便于排查
params.Set("full", "false")
}
u.RawQuery = params.Encode()
httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, u.String(), nil)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()})
return
}
resp, err := h.client.Do(httpReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "请求 FOFA 失败: " + err.Error()})
return
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("FOFA 返回非 2xx: %d", resp.StatusCode)})
return
}
var apiResp fofaAPIResponse
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "解析 FOFA 响应失败: " + err.Error()})
return
}
if apiResp.Error {
msg := strings.TrimSpace(apiResp.ErrMsg)
if msg == "" {
msg = "FOFA 返回错误"
}
c.JSON(http.StatusBadGateway, gin.H{"error": msg})
return
}
fields := splitAndCleanCSV(req.Fields)
results := make([]map[string]interface{}, 0, len(apiResp.Results))
for _, row := range apiResp.Results {
item := make(map[string]interface{}, len(fields))
for i, f := range fields {
if i < len(row) {
item[f] = row[i]
} else {
item[f] = nil
}
}
results = append(results, item)
}
c.JSON(http.StatusOK, fofaSearchResponse{
Query: req.Query,
Size: apiResp.Size,
Page: apiResp.Page,
Total: apiResp.Total,
Fields: fields,
ResultsCount: len(results),
Results: results,
})
}
func splitAndCleanCSV(s string) []string {
parts := strings.Split(s, ",")
out := make([]string, 0, len(parts))
seen := make(map[string]struct{}, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
if _, ok := seen[v]; ok {
continue
}
seen[v] = struct{}{}
out = append(out, v)
}
return out
}
+320
View File
@@ -0,0 +1,320 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// GroupHandler 分组处理器
type GroupHandler struct {
db *database.DB
logger *zap.Logger
}
// NewGroupHandler 创建新的分组处理器
func NewGroupHandler(db *database.DB, logger *zap.Logger) *GroupHandler {
return &GroupHandler{
db: db,
logger: logger,
}
}
// CreateGroupRequest 创建分组请求
type CreateGroupRequest struct {
Name string `json:"name"`
Icon string `json:"icon"`
}
// CreateGroup 创建分组
func (h *GroupHandler) CreateGroup(c *gin.Context) {
var req CreateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"})
return
}
group, err := h.db.CreateGroup(req.Name, req.Icon)
if err != nil {
h.logger.Error("创建分组失败", zap.Error(err))
// 如果是名称重复错误,返回400状态码
if err.Error() == "分组名称已存在" {
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, group)
}
// ListGroups 列出所有分组
func (h *GroupHandler) ListGroups(c *gin.Context) {
groups, err := h.db.ListGroups()
if err != nil {
h.logger.Error("获取分组列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, groups)
}
// GetGroup 获取分组
func (h *GroupHandler) GetGroup(c *gin.Context) {
id := c.Param("id")
group, err := h.db.GetGroup(id)
if err != nil {
h.logger.Error("获取分组失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "分组不存在"})
return
}
c.JSON(http.StatusOK, group)
}
// UpdateGroupRequest 更新分组请求
type UpdateGroupRequest struct {
Name string `json:"name"`
Icon string `json:"icon"`
}
// UpdateGroup 更新分组
func (h *GroupHandler) UpdateGroup(c *gin.Context) {
id := c.Param("id")
var req UpdateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"})
return
}
if err := h.db.UpdateGroup(id, req.Name, req.Icon); err != nil {
h.logger.Error("更新分组失败", zap.Error(err))
// 如果是名称重复错误,返回400状态码
if err.Error() == "分组名称已存在" {
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
group, err := h.db.GetGroup(id)
if err != nil {
h.logger.Error("获取更新后的分组失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, group)
}
// DeleteGroup 删除分组
func (h *GroupHandler) DeleteGroup(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteGroup(id); err != nil {
h.logger.Error("删除分组失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// AddConversationToGroupRequest 添加对话到分组请求
type AddConversationToGroupRequest struct {
ConversationID string `json:"conversationId"`
GroupID string `json:"groupId"`
}
// AddConversationToGroup 将对话添加到分组
func (h *GroupHandler) AddConversationToGroup(c *gin.Context) {
var req AddConversationToGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.AddConversationToGroup(req.ConversationID, req.GroupID); err != nil {
h.logger.Error("添加对话到分组失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "添加成功"})
}
// RemoveConversationFromGroup 从分组中移除对话
func (h *GroupHandler) RemoveConversationFromGroup(c *gin.Context) {
conversationID := c.Param("conversationId")
groupID := c.Param("id")
if err := h.db.RemoveConversationFromGroup(conversationID, groupID); err != nil {
h.logger.Error("从分组中移除对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "移除成功"})
}
// GroupConversation 分组对话响应结构
type GroupConversation struct {
ID string `json:"id"`
Title string `json:"title"`
Pinned bool `json:"pinned"`
GroupPinned bool `json:"groupPinned"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// GetGroupConversations 获取分组中的所有对话
func (h *GroupHandler) GetGroupConversations(c *gin.Context) {
groupID := c.Param("id")
searchQuery := c.Query("search") // 获取搜索参数
var conversations []*database.Conversation
var err error
// 如果有搜索关键词,使用搜索方法;否则使用普通方法
if searchQuery != "" {
conversations, err = h.db.SearchConversationsByGroup(groupID, searchQuery)
} else {
conversations, err = h.db.GetConversationsByGroup(groupID)
}
if err != nil {
h.logger.Error("获取分组对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 获取每个对话在分组中的置顶状态
groupConvs := make([]GroupConversation, 0, len(conversations))
for _, conv := range conversations {
// 查询分组内置顶状态
var groupPinned int
err := h.db.QueryRow(
"SELECT COALESCE(pinned, 0) FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?",
conv.ID, groupID,
).Scan(&groupPinned)
if err != nil {
h.logger.Warn("查询分组内置顶状态失败", zap.String("conversationId", conv.ID), zap.Error(err))
groupPinned = 0
}
groupConvs = append(groupConvs, GroupConversation{
ID: conv.ID,
Title: conv.Title,
Pinned: conv.Pinned,
GroupPinned: groupPinned != 0,
CreatedAt: conv.CreatedAt,
UpdatedAt: conv.UpdatedAt,
})
}
c.JSON(http.StatusOK, groupConvs)
}
// GetAllMappings 批量获取所有分组映射(消除前端 N+1 请求)
func (h *GroupHandler) GetAllMappings(c *gin.Context) {
mappings, err := h.db.GetAllGroupMappings()
if err != nil {
h.logger.Error("获取分组映射失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, mappings)
}
// UpdateConversationPinnedRequest 更新对话置顶状态请求
type UpdateConversationPinnedRequest struct {
Pinned bool `json:"pinned"`
}
// UpdateConversationPinned 更新对话置顶状态
func (h *GroupHandler) UpdateConversationPinned(c *gin.Context) {
conversationID := c.Param("id")
var req UpdateConversationPinnedRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.UpdateConversationPinned(conversationID, req.Pinned); err != nil {
h.logger.Error("更新对话置顶状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
}
// UpdateGroupPinnedRequest 更新分组置顶状态请求
type UpdateGroupPinnedRequest struct {
Pinned bool `json:"pinned"`
}
// UpdateGroupPinned 更新分组置顶状态
func (h *GroupHandler) UpdateGroupPinned(c *gin.Context) {
groupID := c.Param("id")
var req UpdateGroupPinnedRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.UpdateGroupPinned(groupID, req.Pinned); err != nil {
h.logger.Error("更新分组置顶状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
}
// UpdateConversationPinnedInGroupRequest 更新分组对话置顶状态请求
type UpdateConversationPinnedInGroupRequest struct {
Pinned bool `json:"pinned"`
}
// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态
func (h *GroupHandler) UpdateConversationPinnedInGroup(c *gin.Context) {
groupID := c.Param("id")
conversationID := c.Param("conversationId")
var req UpdateConversationPinnedInGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.UpdateConversationPinnedInGroup(conversationID, groupID, req.Pinned); err != nil {
h.logger.Error("更新分组对话置顶状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
}
+792
View File
@@ -0,0 +1,792 @@
package handler
import (
"context"
"database/sql"
"encoding/json"
"errors"
"math"
"net/http"
"strconv"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/multiagent"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
type hitlRuntimeConfig struct {
Enabled bool
Mode string
SensitiveTools map[string]struct{}
Timeout time.Duration
}
type hitlDecision struct {
Decision string
Comment string
EditedArguments map[string]interface{}
}
type pendingInterrupt struct {
ConversationID string
InterruptID string
Mode string
ToolName string
ToolCallID string
decideCh chan hitlDecision
}
type HITLManager struct {
db *database.DB
logger *zap.Logger
mu sync.RWMutex
runtime map[string]hitlRuntimeConfig
pending map[string]*pendingInterrupt
}
func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager {
return &HITLManager{
db: db,
logger: logger,
runtime: make(map[string]hitlRuntimeConfig),
pending: make(map[string]*pendingInterrupt),
}
}
func (m *HITLManager) EnsureSchema() error {
if _, err := m.db.Exec(`
CREATE TABLE IF NOT EXISTS hitl_interrupts (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
message_id TEXT,
mode TEXT NOT NULL,
tool_name TEXT NOT NULL,
tool_call_id TEXT,
payload TEXT,
status TEXT NOT NULL,
decision TEXT,
decision_comment TEXT,
created_at DATETIME NOT NULL,
decided_at DATETIME
);`); err != nil {
return err
}
_, err := m.db.Exec(`
CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
conversation_id TEXT PRIMARY KEY,
enabled INTEGER NOT NULL DEFAULT 0,
mode TEXT NOT NULL DEFAULT 'off',
sensitive_tools TEXT NOT NULL DEFAULT '[]',
timeout_seconds INTEGER NOT NULL DEFAULT 0,
updated_at DATETIME NOT NULL
);`)
if err != nil {
return err
}
// On startup, cancel all orphaned pending interrupts from previous process.
// Their in-memory channels are gone, so they can never be resolved.
res, err := m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
decision_comment='process restarted', decided_at=CURRENT_TIMESTAMP WHERE status='pending'`)
if err != nil {
m.logger.Warn("failed to cancel orphaned HITL interrupts", zap.Error(err))
} else if n, _ := res.RowsAffected(); n > 0 {
m.logger.Info("cancelled orphaned HITL interrupts from previous process", zap.Int64("count", n))
}
return nil
}
func normalizeHitlMode(mode string) string {
v := strings.ToLower(strings.TrimSpace(mode))
if v == "" {
return "approval"
}
switch v {
case "off":
return "off"
case "feedback", "followup":
return "approval"
case "approval", "review_edit":
return v
default:
return "approval"
}
}
func (m *HITLManager) ActivateConversation(conversationID string, req *HITLRequest) {
if req == nil || !req.Enabled {
m.DeactivateConversation(conversationID)
return
}
tools := make(map[string]struct{})
for _, t := range req.SensitiveTools {
n := strings.ToLower(strings.TrimSpace(t))
if n != "" {
tools[n] = struct{}{}
}
}
// timeout <= 0 means wait forever (no timeout).
timeout := time.Duration(0)
if req.TimeoutSeconds > 0 {
timeout = time.Duration(req.TimeoutSeconds) * time.Second
}
m.mu.Lock()
m.runtime[conversationID] = hitlRuntimeConfig{
Enabled: true,
Mode: normalizeHitlMode(req.Mode),
SensitiveTools: tools,
Timeout: timeout,
}
m.mu.Unlock()
}
func (m *HITLManager) DeactivateConversation(conversationID string) {
m.mu.Lock()
delete(m.runtime, conversationID)
m.mu.Unlock()
}
// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空)。
func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
if h == nil || h.config == nil {
return nil
}
raw := h.config.Hitl.ToolWhitelist
if len(raw) == 0 {
return nil
}
seen := make(map[string]struct{})
out := make([]string, 0, len(raw))
for _, t := range raw {
n := strings.ToLower(strings.TrimSpace(t))
if n == "" {
continue
}
if _, ok := seen[n]; ok {
continue
}
seen[n] = struct{}{}
out = append(out, strings.TrimSpace(t))
}
return out
}
// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单合并(并集),仅用于运行时 Activate;不写入数据库。
func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest {
gw := h.hitlConfigGlobalToolWhitelist()
if len(gw) == 0 {
return req
}
if req == nil {
return nil
}
seen := make(map[string]struct{})
union := make([]string, 0, len(gw)+len(req.SensitiveTools))
for _, t := range gw {
n := strings.ToLower(strings.TrimSpace(t))
if n == "" {
continue
}
if _, ok := seen[n]; ok {
continue
}
seen[n] = struct{}{}
union = append(union, strings.TrimSpace(t))
}
for _, t := range req.SensitiveTools {
n := strings.ToLower(strings.TrimSpace(t))
if n == "" {
continue
}
if _, ok := seen[n]; ok {
continue
}
seen[n] = struct{}{}
union = append(union, strings.TrimSpace(t))
}
out := *req
out.SensitiveTools = union
return &out
}
func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRuntimeConfig, bool) {
m.mu.RLock()
cfg, ok := m.runtime[conversationID]
m.mu.RUnlock()
if !ok || !cfg.Enabled {
return hitlRuntimeConfig{}, false
}
// 语义:SensitiveTools 现在作为“白名单(免审批工具)”
// 空白名单 => 全部工具都需要审批
if len(cfg.SensitiveTools) == 0 {
return cfg, true
}
_, inWhitelist := cfg.SensitiveTools[strings.ToLower(strings.TrimSpace(toolName))]
return cfg, !inWhitelist
}
// NeedsToolApproval 与 Agent 工具层 shouldInterrupt 语义一致:仅当该会话已开启人机协同且工具不在免审批白名单时为 true。
func (m *HITLManager) NeedsToolApproval(conversationID, toolName string) bool {
if m == nil {
return false
}
_, need := m.shouldInterrupt(conversationID, toolName)
return need
}
func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) {
now := time.Now()
id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "")
if _, err := m.db.Exec(`INSERT INTO hitl_interrupts
(id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)`,
id, conversationID, assistantMessageID, mode, toolName, toolCallID, payload, now); err != nil {
return nil, err
}
// 刷新页面后侧栏依赖 DB 配置;若仅内存 Activate 未落库,会导致「有待审批却显示关闭」
_ = m.ensureConversationHITLModePersisted(conversationID, mode)
p := &pendingInterrupt{
ConversationID: conversationID,
InterruptID: id,
Mode: normalizeHitlMode(mode),
ToolName: toolName,
ToolCallID: toolCallID,
decideCh: make(chan hitlDecision, 1),
}
m.mu.Lock()
m.pending[id] = p
m.mu.Unlock()
return p, nil
}
// ensureConversationHITLModePersisted 在产生待审批时把 mode 写入 hitl_conversation_configs,避免刷新后 GET 配置仍为关闭。
func (m *HITLManager) ensureConversationHITLModePersisted(conversationID, interruptMode string) error {
if strings.TrimSpace(conversationID) == "" {
return nil
}
nm := normalizeHitlMode(interruptMode)
if nm == "off" {
return nil
}
cfg, err := m.LoadConversationConfig(conversationID)
if err != nil {
return err
}
if cfg.Enabled && normalizeHitlMode(cfg.Mode) == nm {
return nil
}
cfg.Enabled = true
cfg.Mode = nm
if cfg.TimeoutSeconds < 0 {
cfg.TimeoutSeconds = 0
}
return m.SaveConversationConfig(conversationID, cfg)
}
// PendingHITLInterruptMode 返回该会话最新一条 pending 中断的协同模式(用于 GET 配置时与库内「关闭」状态对齐)。
func (m *HITLManager) PendingHITLInterruptMode(conversationID string) (string, bool) {
if strings.TrimSpace(conversationID) == "" {
return "", false
}
var mode string
err := m.db.QueryRow(`SELECT mode FROM hitl_interrupts WHERE conversation_id = ? AND status = 'pending' ORDER BY created_at DESC LIMIT 1`, conversationID).
Scan(&mode)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", false
}
return "", false
}
mode = strings.TrimSpace(mode)
if mode == "" {
return "", false
}
return mode, true
}
func hitlStoredConfigEffective(cfg *HITLRequest) bool {
if cfg == nil {
return false
}
if cfg.Enabled {
return true
}
return normalizeHitlMode(cfg.Mode) != "off"
}
func (m *HITLManager) ResolveInterrupt(interruptID, decision, comment string, editedArguments map[string]interface{}) error {
decision = strings.ToLower(strings.TrimSpace(decision))
if decision != "approve" && decision != "reject" {
return errors.New("decision must be approve/reject")
}
m.mu.RLock()
p, ok := m.pending[interruptID]
m.mu.RUnlock()
if !ok {
return errors.New("interrupt not found or already resolved")
}
d := hitlDecision{
Decision: decision,
Comment: strings.TrimSpace(comment),
EditedArguments: editedArguments,
}
select {
case p.decideCh <- d:
return nil
default:
return errors.New("interrupt already resolved or decision channel busy")
}
}
func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLRequest) error {
if strings.TrimSpace(conversationID) == "" {
return errors.New("conversationId is required")
}
if req == nil {
req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 0}
}
mode := normalizeHitlMode(req.Mode)
if !req.Enabled {
mode = "off"
}
tools, _ := json.Marshal(req.SensitiveTools)
timeout := req.TimeoutSeconds
if timeout < 0 {
timeout = 0
}
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(conversation_id) DO UPDATE SET
enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now())
return err
}
func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) {
var enabledInt int
var mode, toolsJSON string
var timeout int
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
Scan(&enabledInt, &mode, &toolsJSON, &timeout)
if errors.Is(err, sql.ErrNoRows) {
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
}
if err != nil {
return nil, err
}
if timeout < 0 {
timeout = 0
}
tools := make([]string, 0)
_ = json.Unmarshal([]byte(toolsJSON), &tools)
return &HITLRequest{
Enabled: enabledInt == 1,
Mode: mode,
SensitiveTools: tools,
TimeoutSeconds: timeout,
}, nil
}
func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, timeout time.Duration) (hitlDecision, error) {
defer func() {
m.mu.Lock()
delete(m.pending, p.InterruptID)
m.mu.Unlock()
}()
var timeoutCh <-chan time.Time
if timeout > 0 {
timer := time.NewTimer(timeout)
defer timer.Stop()
timeoutCh = timer.C
}
select {
case d := <-p.decideCh:
// 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments
if p.Mode != "review_edit" && len(d.EditedArguments) > 0 {
d.EditedArguments = nil
}
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
d.Decision, d.Comment, time.Now(), p.InterruptID)
return d, nil
case <-timeoutCh:
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
time.Now(), p.InterruptID)
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
case <-ctx.Done():
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`,
time.Now(), p.InterruptID)
return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err()
}
}
func (h *AgentHandler) activateHITLForConversation(conversationID string, req *HITLRequest) {
if h.hitlManager == nil {
return
}
if req == nil {
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
if err == nil {
req = cfg
}
}
h.hitlManager.ActivateConversation(conversationID, h.hitlRequestWithMergedConfigWhitelist(req))
}
func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID, toolName, toolCallID string, payload map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) (*hitlDecision, error) {
cfg, need := h.hitlManager.shouldInterrupt(conversationID, toolName)
if !need {
return nil, nil
}
payloadRaw, _ := json.Marshal(payload)
p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw))
if err != nil {
h.logger.Warn("创建 HITL 中断失败", zap.Error(err))
return nil, err
}
if sendEventFunc != nil {
sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{
"conversationId": conversationID,
"interruptId": p.InterruptID,
"mode": cfg.Mode,
"toolName": toolName,
"toolCallId": toolCallID,
"payload": payload,
})
}
d, waitErr := h.hitlManager.waitDecision(runCtx, p, cfg.Timeout)
if waitErr != nil {
if cancelRun != nil && (errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded)) {
cause := context.Cause(runCtx)
switch {
case errors.Is(cause, ErrTaskCancelled):
cancelRun(ErrTaskCancelled)
case cause != nil:
cancelRun(cause)
case errors.Is(waitErr, context.DeadlineExceeded):
cancelRun(context.DeadlineExceeded)
default:
cancelRun(ErrTaskCancelled)
}
}
return nil, waitErr
}
if d.Decision == "reject" {
if sendEventFunc != nil {
sendEventFunc("hitl_rejected", "人工拒绝本次工具调用,模型将基于反馈继续迭代", map[string]interface{}{
"conversationId": conversationID,
"interruptId": p.InterruptID,
"toolName": toolName,
"comment": d.Comment,
})
}
return &d, nil
}
if sendEventFunc != nil {
sendEventFunc("hitl_resumed", "人工确认通过,继续执行", map[string]interface{}{
"conversationId": conversationID,
"interruptId": p.InterruptID,
"toolName": toolName,
"comment": d.Comment,
"editedArgs": d.EditedArguments,
})
}
return &d, nil
}
func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, data map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) {
if h.hitlManager == nil {
return
}
toolName, _ := data["toolName"].(string)
toolCallID, _ := data["toolCallId"].(string)
d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, toolCallID, data, sendEventFunc)
if err != nil || d == nil {
return
}
if len(d.EditedArguments) > 0 {
if argsObj, ok := data["argumentsObj"].(map[string]interface{}); ok {
for k := range argsObj {
delete(argsObj, k)
}
for k, v := range d.EditedArguments {
argsObj[k] = v
}
if b, mErr := json.Marshal(argsObj); mErr == nil {
data["arguments"] = string(b)
}
}
}
}
func (h *AgentHandler) ListHITLPending(c *gin.Context) {
conversationID := strings.TrimSpace(c.Query("conversationId"))
status := strings.TrimSpace(c.Query("status"))
if status == "" {
status = "pending"
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
if page < 1 {
page = 1
}
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
offset := (page - 1) * pageSize
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1`
args := []interface{}{}
if conversationID != "" {
q += " AND conversation_id = ?"
args = append(args, conversationID)
}
if status != "all" {
q += " AND status = ?"
args = append(args, status)
}
q += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, pageSize, offset)
rows, err := h.db.Query(q, args...)
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
defer rows.Close()
items := make([]map[string]interface{}, 0)
for rows.Next() {
var id, cid, mode, toolName, toolCallID, payload, rowStatus string
var messageID sql.NullString
var decision, comment sql.NullString
var createdAt time.Time
var decidedAt sql.NullTime
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil {
continue
}
msgID := ""
if messageID.Valid {
msgID = messageID.String
}
items = append(items, map[string]interface{}{
"id": id,
"conversationId": cid,
"messageId": msgID,
"mode": mode,
"toolName": toolName,
"toolCallId": toolCallID,
"payload": payload,
"status": rowStatus,
"decision": decision.String,
"comment": comment.String,
"createdAt": createdAt,
"decidedAt": func() interface{} {
if decidedAt.Valid {
return decidedAt.Time
}
return nil
}(),
})
}
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize})
}
type hitlDecisionReq struct {
InterruptID string `json:"interruptId" binding:"required"`
Decision string `json:"decision" binding:"required"`
Comment string `json:"comment,omitempty"`
EditedArguments map[string]interface{} `json:"editedArguments,omitempty"`
}
func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
var req hitlDecisionReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
if h.hitlManager == nil {
c.JSON(500, gin.H{"error": "hitl manager unavailable"})
return
}
if err := h.hitlManager.ResolveInterrupt(req.InterruptID, req.Decision, req.Comment, req.EditedArguments); err != nil {
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
"decision": req.Decision,
})
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
func (h *AgentHandler) DismissHITLInterrupt(c *gin.Context) {
var req struct {
InterruptID string `json:"interruptId" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
if h.hitlManager == nil {
c.JSON(500, gin.H{"error": "hitl manager unavailable"})
return
}
res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP
WHERE id=? AND status='pending'`, req.InterruptID)
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
n, _ := res.RowsAffected()
if n == 0 {
c.JSON(404, gin.H{"error": "interrupt not found or already resolved"})
return
}
// Also drain from in-memory map if present
h.hitlManager.mu.Lock()
if p, ok := h.hitlManager.pending[req.InterruptID]; ok {
delete(h.hitlManager.pending, req.InterruptID)
select {
case p.decideCh <- hitlDecision{Decision: "reject", Comment: "dismissed by user"}:
default:
}
}
h.hitlManager.mu.Unlock()
c.JSON(http.StatusOK, gin.H{"ok": true})
}
func (h *AgentHandler) interceptHITLForEinoTool(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{}), toolName, arguments string) (string, error) {
payload := map[string]interface{}{
"toolName": toolName,
"arguments": arguments,
"source": "eino_middleware",
"toolCallId": "",
}
var argsObj map[string]interface{}
if strings.TrimSpace(arguments) != "" {
_ = json.Unmarshal([]byte(arguments), &argsObj)
if argsObj != nil {
payload["argumentsObj"] = argsObj
}
}
d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, "", payload, sendEventFunc)
if err != nil || d == nil {
return arguments, err
}
if d.Decision == "reject" {
return arguments, multiagent.NewHumanRejectError(d.Comment)
}
if len(d.EditedArguments) > 0 {
edited, mErr := json.Marshal(d.EditedArguments)
if mErr == nil {
return string(edited), nil
}
}
return arguments, nil
}
type hitlConfigReq struct {
ConversationID string `json:"conversationId" binding:"required"`
HITLRequest
}
func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) {
conversationID := strings.TrimSpace(c.Param("conversationId"))
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
return
}
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !hitlStoredConfigEffective(cfg) {
if pendMode, ok := h.hitlManager.PendingHITLInterruptMode(conversationID); ok {
cfg2 := *cfg
cfg2.Enabled = true
cfg2.Mode = normalizeHitlMode(pendMode)
if cfg2.TimeoutSeconds < 0 {
cfg2.TimeoutSeconds = 0
}
cfg = &cfg2
}
}
c.JSON(http.StatusOK, gin.H{
"conversationId": conversationID,
"hitl": cfg,
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
})
}
func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) {
var req hitlConfigReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Mode = normalizeHitlMode(req.Mode)
if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.hitlWhitelistSaver != nil && len(req.SensitiveTools) > 0 {
if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil {
h.logger.Warn("HITL 会话配置已保存,但合并工具白名单到 config.yaml 失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": "会话配置已保存,但写入 config.yaml 失败: " + err.Error(),
})
return
}
}
h.hitlManager.ActivateConversation(req.ConversationID, h.hitlRequestWithMergedConfigWhitelist(&req.HITLRequest))
c.JSON(http.StatusOK, gin.H{"ok": true})
}
type mergeHitlGlobalWhitelistReq struct {
SensitiveTools []string `json:"sensitiveTools"`
}
// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。
func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) {
if h.hitlWhitelistSaver == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"})
return
}
var req mergeHitlGlobalWhitelistReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if len(req.SensitiveTools) == 0 {
c.JSON(http.StatusOK, gin.H{
"ok": true,
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
"hitlGlobalWhitelistMerged": false,
})
return
}
if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil {
h.logger.Warn("合并 HITL 工具白名单到 config.yaml 失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"ok": true,
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
"hitlGlobalWhitelistMerged": true,
})
}
func boolToInt(v bool) int {
if v {
return 1
}
return 0
}
+530
View File
@@ -0,0 +1,530 @@
package handler
import (
"context"
"fmt"
"net/http"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/knowledge"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// KnowledgeHandler 知识库处理器
type KnowledgeHandler struct {
manager *knowledge.Manager
retriever *knowledge.Retriever
indexer *knowledge.Indexer
db *database.DB
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewKnowledgeHandler 创建新的知识库处理器
func NewKnowledgeHandler(
manager *knowledge.Manager,
retriever *knowledge.Retriever,
indexer *knowledge.Indexer,
db *database.DB,
logger *zap.Logger,
) *KnowledgeHandler {
return &KnowledgeHandler{
manager: manager,
retriever: retriever,
indexer: indexer,
db: db,
logger: logger,
}
}
// GetCategories 获取所有分类
func (h *KnowledgeHandler) GetCategories(c *gin.Context) {
categories, err := h.manager.GetCategories()
if err != nil {
h.logger.Error("获取分类失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"categories": categories})
}
// GetItems 获取知识项列表(支持按分类分页和关键字搜索,默认不返回完整内容)
func (h *KnowledgeHandler) GetItems(c *gin.Context) {
category := c.Query("category")
searchKeyword := c.Query("search") // 搜索关键字
// 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索)
if searchKeyword != "" {
items, err := h.manager.SearchItemsByKeyword(searchKeyword, category)
if err != nil {
h.logger.Error("搜索知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 按分类分组结果
groupedByCategory := make(map[string][]*knowledge.KnowledgeItemSummary)
for _, item := range items {
cat := item.Category
if cat == "" {
cat = "未分类"
}
groupedByCategory[cat] = append(groupedByCategory[cat], item)
}
// 转换为 CategoryWithItems 格式
categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory))
for cat, catItems := range groupedByCategory {
categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{
Category: cat,
ItemCount: len(catItems),
Items: catItems,
})
}
// 按分类名称排序
for i := 0; i < len(categoriesWithItems)-1; i++ {
for j := i + 1; j < len(categoriesWithItems); j++ {
if categoriesWithItems[i].Category > categoriesWithItems[j].Category {
categoriesWithItems[i], categoriesWithItems[j] = categoriesWithItems[j], categoriesWithItems[i]
}
}
}
c.JSON(http.StatusOK, gin.H{
"categories": categoriesWithItems,
"total": len(categoriesWithItems),
"search": searchKeyword,
"is_search": true,
})
return
}
// 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容)
categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页
// 分页参数
limit := 50 // 默认每页 50 条(分类分页时为分类数,项分页时为项数)
offset := 0
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 {
limit = parsed
}
}
if offsetStr := c.Query("offset"); offsetStr != "" {
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
offset = parsed
}
}
// 如果指定了 category 参数,且使用分类分页模式,则只返回该分类
if category != "" && categoryPageMode {
// 单分类模式:返回该分类的所有知识项(不分页)
items, total, err := h.manager.GetItemsSummary(category, 0, 0)
if err != nil {
h.logger.Error("获取知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 包装成分类结构
categoriesWithItems := []*knowledge.CategoryWithItems{
{
Category: category,
ItemCount: total,
Items: items,
},
}
c.JSON(http.StatusOK, gin.H{
"categories": categoriesWithItems,
"total": 1, // 只有一个分类
"limit": limit,
"offset": offset,
})
return
}
if categoryPageMode {
// 按分类分页模式(默认)
// limit 表示每页分类数,推荐 5-10 个分类
if limit <= 0 || limit > 100 {
limit = 10 // 默认每页 10 个分类
}
categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset)
if err != nil {
h.logger.Error("获取分类知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"categories": categoriesWithItems,
"total": totalCategories,
"limit": limit,
"offset": offset,
})
return
}
// 按项分页模式(向后兼容)
// 是否包含完整内容(默认 false,只返回摘要)
includeContent := c.Query("includeContent") == "true"
if includeContent {
// 返回完整内容(向后兼容)
items, err := h.manager.GetItemsWithOptions(category, limit, offset, true)
if err != nil {
h.logger.Error("获取知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 获取总数
total, err := h.manager.GetItemsCount(category)
if err != nil {
h.logger.Warn("获取知识项总数失败", zap.Error(err))
total = len(items)
}
c.JSON(http.StatusOK, gin.H{
"items": items,
"total": total,
"limit": limit,
"offset": offset,
})
} else {
// 返回摘要(不包含完整内容,推荐方式)
items, total, err := h.manager.GetItemsSummary(category, limit, offset)
if err != nil {
h.logger.Error("获取知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"items": items,
"total": total,
"limit": limit,
"offset": offset,
})
}
}
// GetItem 获取单个知识项
func (h *KnowledgeHandler) GetItem(c *gin.Context) {
id := c.Param("id")
item, err := h.manager.GetItem(id)
if err != nil {
h.logger.Error("获取知识项失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, item)
}
// CreateItem 创建知识项
func (h *KnowledgeHandler) CreateItem(c *gin.Context) {
var req struct {
Category string `json:"category" binding:"required"`
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
item, err := h.manager.CreateItem(req.Category, req.Title, req.Content)
if err != nil {
h.logger.Error("创建知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 异步索引
go func() {
ctx := context.Background()
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
}
}()
c.JSON(http.StatusOK, item)
}
// UpdateItem 更新知识项
func (h *KnowledgeHandler) UpdateItem(c *gin.Context) {
id := c.Param("id")
var req struct {
Category string `json:"category" binding:"required"`
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content)
if err != nil {
h.logger.Error("更新知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 异步重新索引
go func() {
ctx := context.Background()
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
}
}()
c.JSON(http.StatusOK, item)
}
// DeleteItem 删除知识项
func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DeleteItem(id); err != nil {
h.logger.Error("删除知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// RebuildIndex 重建索引
func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
// 异步重建索引
go func() {
ctx := context.Background()
if err := h.indexer.RebuildIndex(ctx); err != nil {
h.logger.Error("重建索引失败", zap.Error(err))
}
}()
if h.audit != nil {
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
}
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
}
// ScanKnowledgeBase 扫描知识库
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
itemsToIndex, err := h.manager.ScanKnowledgeBase()
if err != nil {
h.logger.Error("扫描知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if len(itemsToIndex) == 0 {
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"})
return
}
// 异步索引新添加或更新的项(增量索引)
go func() {
ctx := context.Background()
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
failedCount := 0
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
for i, itemID := range itemsToIndex {
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
h.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemsToIndex)),
zap.Error(err),
)
}
// 如果连续失败 2 次,立即停止增量索引
if consecutiveFailures >= 2 {
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率
if (i+1)%10 == 0 || i+1 == len(itemsToIndex) {
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount))
}
}
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
}()
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
"items_to_index": len(itemsToIndex),
})
}
// GetRetrievalLogs 获取检索日志
func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) {
conversationID := c.Query("conversationId")
messageID := c.Query("messageId")
limit := 50 // 默认 50 条
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
limit = parsed
}
}
logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit)
if err != nil {
h.logger.Error("获取检索日志失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"logs": logs})
}
// DeleteRetrievalLog 删除检索日志
func (h *KnowledgeHandler) DeleteRetrievalLog(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DeleteRetrievalLog(id); err != nil {
h.logger.Error("删除检索日志失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// GetIndexStatus 获取索引状态
func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
status, err := h.manager.GetIndexStatus()
if err != nil {
h.logger.Error("获取索引状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 获取索引器的错误信息
if h.indexer != nil {
lastError, lastErrorTime := h.indexer.GetLastError()
if lastError != "" {
// 如果错误是最近发生的(5 分钟内),则返回错误信息
if time.Since(lastErrorTime) < 5*time.Minute {
status["last_error"] = lastError
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
}
}
// 获取重建索引状态
isRebuilding, totalItems, current, failed, lastItemID, lastChunks, startTime := h.indexer.GetRebuildStatus()
if isRebuilding {
status["is_rebuilding"] = true
status["rebuild_total"] = totalItems
status["rebuild_current"] = current
status["rebuild_failed"] = failed
status["rebuild_start_time"] = startTime.Format(time.RFC3339)
if lastItemID != "" {
status["rebuild_last_item_id"] = lastItemID
}
if lastChunks > 0 {
status["rebuild_last_chunks"] = lastChunks
}
// 重建中时,is_complete 为 false
status["is_complete"] = false
// 计算重建进度百分比
if totalItems > 0 {
status["progress_percent"] = float64(current) / float64(totalItems) * 100
}
}
}
c.JSON(http.StatusOK, status)
}
// Search 搜索知识库(用于 API 调用,Agent 内部使用 Retriever
func (h *KnowledgeHandler) Search(c *gin.Context) {
var req knowledge.SearchRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Retriever.Search 经 Eino VectorEinoRetriever,与 MCP 工具链一致。
results, err := h.retriever.Search(c.Request.Context(), &req)
if err != nil {
h.logger.Error("搜索知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"results": results})
}
// GetStats 获取知识库统计信息
func (h *KnowledgeHandler) GetStats(c *gin.Context) {
totalCategories, totalItems, err := h.manager.GetStats()
if err != nil {
h.logger.Error("获取知识库统计信息失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"enabled": true,
"total_categories": totalCategories,
"total_items": totalItems,
})
}
// 辅助函数:解析整数
func parseInt(s string) (int, error) {
var result int
_, err := fmt.Sscanf(s, "%d", &result)
return result, err
}
+333
View File
@@ -0,0 +1,333 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"github.com/gin-gonic/gin"
)
var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.md$`)
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
type MarkdownAgentsHandler struct {
dir string
audit *audit.Service
}
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
}
// SetAudit wires platform audit logging.
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
h.audit = s
}
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
filename = strings.TrimSpace(filename)
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
return "", fmt.Errorf("非法文件名")
}
clean := filepath.Clean(filename)
if clean != filename || strings.Contains(clean, "..") {
return "", fmt.Errorf("非法文件名")
}
return filepath.Join(h.dir, clean), nil
}
// existingOtherOrchestrator 若目录中已有同槽位的其他主代理文件,返回其文件名;writingBasename 为当前正在写入的文件名时不冲突。
func existingOtherOrchestrator(dir, writingBasename string) (other string, err error) {
load, err := agents.LoadMarkdownAgentsDir(dir)
if err != nil {
return "", err
}
wb := filepath.Base(strings.TrimSpace(writingBasename))
switch agents.OrchestratorMarkdownKind(wb) {
case "plan_execute":
if load.OrchestratorPlanExecute != nil && !strings.EqualFold(load.OrchestratorPlanExecute.Filename, wb) {
return load.OrchestratorPlanExecute.Filename, nil
}
case "supervisor":
if load.OrchestratorSupervisor != nil && !strings.EqualFold(load.OrchestratorSupervisor.Filename, wb) {
return load.OrchestratorSupervisor.Filename, nil
}
case "deep":
if load.Orchestrator != nil && !strings.EqualFold(load.Orchestrator.Filename, wb) {
return load.Orchestrator.Filename, nil
}
default:
if load.Orchestrator != nil && !strings.EqualFold(load.Orchestrator.Filename, wb) {
return load.Orchestrator.Filename, nil
}
}
return "", nil
}
// ListMarkdownAgents GET /api/multi-agent/markdown-agents
func (h *MarkdownAgentsHandler) ListMarkdownAgents(c *gin.Context) {
if h.dir == "" {
c.JSON(http.StatusOK, gin.H{"agents": []any{}, "dir": "", "error": "未配置 agents 目录"})
return
}
files, err := agents.LoadMarkdownAgentFiles(h.dir)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
out := make([]gin.H, 0, len(files))
for _, fa := range files {
sub := fa.Config
out = append(out, gin.H{
"filename": fa.Filename,
"id": sub.ID,
"name": sub.Name,
"description": sub.Description,
"is_orchestrator": fa.IsOrchestrator,
"kind": sub.Kind,
})
}
c.JSON(http.StatusOK, gin.H{"agents": out, "dir": h.dir})
}
// GetMarkdownAgent GET /api/multi-agent/markdown-agents/:filename
func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
filename := c.Param("filename")
path, err := h.safeJoin(filename)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
b, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sub, err := agents.ParseMarkdownSubAgent(filename, string(b))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
isOrch := agents.IsOrchestratorLikeMarkdown(filename, sub.Kind)
c.JSON(http.StatusOK, gin.H{
"filename": filename,
"raw": string(b),
"id": sub.ID,
"name": sub.Name,
"description": sub.Description,
"tools": sub.RoleTools,
"instruction": sub.Instruction,
"bind_role": sub.BindRole,
"max_iterations": sub.MaxIterations,
"kind": sub.Kind,
"is_orchestrator": isOrch,
})
}
type markdownAgentBody struct {
Filename string `json:"filename"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Tools []string `json:"tools"`
Instruction string `json:"instruction"`
BindRole string `json:"bind_role"`
MaxIterations int `json:"max_iterations"`
Kind string `json:"kind"`
Raw string `json:"raw"`
}
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents
func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
if h.dir == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "未配置 agents 目录"})
return
}
var body markdownAgentBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
filename := strings.TrimSpace(body.Filename)
if filename == "" {
if strings.EqualFold(strings.TrimSpace(body.Kind), "orchestrator") {
filename = agents.OrchestratorMarkdownFilename
} else {
base := agents.SlugID(body.Name)
if base == "" {
base = "agent"
}
filename = base + ".md"
}
}
path, err := h.safeJoin(filename)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err == nil {
c.JSON(http.StatusConflict, gin.H{"error": "文件已存在"})
return
}
sub := config.MultiAgentSubConfig{
ID: strings.TrimSpace(body.ID),
Name: strings.TrimSpace(body.Name),
Description: strings.TrimSpace(body.Description),
Instruction: strings.TrimSpace(body.Instruction),
RoleTools: body.Tools,
BindRole: strings.TrimSpace(body.BindRole),
MaxIterations: body.MaxIterations,
Kind: strings.TrimSpace(body.Kind),
}
base := filepath.Base(path)
if (strings.EqualFold(base, agents.OrchestratorMarkdownFilename) ||
strings.EqualFold(base, agents.OrchestratorPlanExecuteMarkdownFilename) ||
strings.EqualFold(base, agents.OrchestratorSupervisorMarkdownFilename)) && sub.Kind == "" {
sub.Kind = "orchestrator"
}
if sub.ID == "" {
sub.ID = agents.SlugID(sub.Name)
}
if sub.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"})
return
}
var out []byte
if strings.TrimSpace(body.Raw) != "" {
out = []byte(body.Raw)
} else {
out, err = agents.BuildMarkdownFile(sub)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if want := agents.WantsMarkdownOrchestrator(filepath.Base(path), body.Kind, string(out)); want {
other, oerr := existingOtherOrchestrator(h.dir, filepath.Base(path))
if oerr != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()})
return
}
if other != "" {
c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)})
return
}
}
if err := os.MkdirAll(h.dir, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.WriteFile(path, out, 0644); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
}
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
}
// UpdateMarkdownAgent PUT /api/multi-agent/markdown-agents/:filename
func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
filename := c.Param("filename")
path, err := h.safeJoin(filename)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var body markdownAgentBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
sub := config.MultiAgentSubConfig{
ID: strings.TrimSpace(body.ID),
Name: strings.TrimSpace(body.Name),
Description: strings.TrimSpace(body.Description),
Instruction: strings.TrimSpace(body.Instruction),
RoleTools: body.Tools,
BindRole: strings.TrimSpace(body.BindRole),
MaxIterations: body.MaxIterations,
Kind: strings.TrimSpace(body.Kind),
}
if (strings.EqualFold(filename, agents.OrchestratorMarkdownFilename) ||
strings.EqualFold(filename, agents.OrchestratorPlanExecuteMarkdownFilename) ||
strings.EqualFold(filename, agents.OrchestratorSupervisorMarkdownFilename)) && sub.Kind == "" {
sub.Kind = "orchestrator"
}
if sub.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"})
return
}
if sub.ID == "" {
sub.ID = agents.SlugID(sub.Name)
}
var out []byte
if strings.TrimSpace(body.Raw) != "" {
out = []byte(body.Raw)
} else {
out, err = agents.BuildMarkdownFile(sub)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if want := agents.WantsMarkdownOrchestrator(filename, body.Kind, string(out)); want {
other, oerr := existingOtherOrchestrator(h.dir, filename)
if oerr != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()})
return
}
if other != "" {
c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)})
return
}
}
if err := os.WriteFile(path, out, 0644); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
}
// DeleteMarkdownAgent DELETE /api/multi-agent/markdown-agents/:filename
func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
filename := c.Param("filename")
path, err := h.safeJoin(filename)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
}
+618
View File
@@ -0,0 +1,618 @@
package handler
import (
"encoding/json"
"errors"
"io"
"net/http"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// MonitorHandler 监控处理器
type MonitorHandler struct {
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager
executor *security.Executor
db *database.DB
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *MonitorHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewMonitorHandler 创建新的监控处理器
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, db *database.DB, logger *zap.Logger) *MonitorHandler {
return &MonitorHandler{
mcpServer: mcpServer,
externalMCPMgr: nil, // 将在创建后设置
executor: executor,
db: db,
logger: logger,
}
}
// SetExternalMCPManager 设置外部MCP管理器
func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
h.externalMCPMgr = mgr
}
// MonitorResponse 监控响应
type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"`
Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"`
}
// Monitor 获取监控信息
func (h *MonitorHandler) Monitor(c *gin.Context) {
// 解析分页参数
page := 1
pageSize := 20
if pageStr := c.Query("page"); pageStr != "" {
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
page = p
}
}
if pageSizeStr := c.Query("page_size"); pageSizeStr != "" {
if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 {
pageSize = ps
}
}
// 解析状态筛选参数
status := c.Query("status")
// 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool
toolName := normalizeToolNameFilter(c.Query("tool"))
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
stats := h.loadStats()
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
c.JSON(http.StatusOK, MonitorResponse{
Executions: executions,
Stats: stats,
Timestamp: time.Now(),
Total: total,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
})
}
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
return executions
}
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
if h.db == nil {
allExecutions := h.mcpServer.GetAllExecutions()
// 如果指定了状态筛选或工具筛选,先进行筛选
if status != "" || toolName != "" {
filtered := make([]*mcp.ToolExecution, 0)
for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索)
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
if matchStatus && matchTool {
filtered = append(filtered, exec)
}
}
allExecutions = filtered
}
total := len(allExecutions)
offset := (page - 1) * pageSize
end := offset + pageSize
if end > total {
end = total
}
if offset >= total {
return []*mcp.ToolExecution{}, total
}
return allExecutions[offset:end], total
}
offset := (page - 1) * pageSize
executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status, toolName)
if err != nil {
h.logger.Warn("从数据库加载执行记录失败,回退到内存数据", zap.Error(err))
allExecutions := h.mcpServer.GetAllExecutions()
// 如果指定了状态筛选或工具筛选,先进行筛选
if status != "" || toolName != "" {
filtered := make([]*mcp.ToolExecution, 0)
for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索)
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
if matchStatus && matchTool {
filtered = append(filtered, exec)
}
}
allExecutions = filtered
}
total := len(allExecutions)
offset := (page - 1) * pageSize
end := offset + pageSize
if end > total {
end = total
}
if offset >= total {
return []*mcp.ToolExecution{}, total
}
return allExecutions[offset:end], total
}
// 获取总数(考虑状态筛选和工具筛选)
total, err := h.db.CountToolExecutions(status, toolName)
if err != nil {
h.logger.Warn("获取执行记录总数失败", zap.Error(err))
// 回退:使用已加载的记录数估算
total = offset + len(executions)
if len(executions) == pageSize {
total = offset + len(executions) + 1
}
}
return executions, total
}
func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
// 合并内部MCP服务器和外部MCP管理器的统计信息
stats := make(map[string]*mcp.ToolStats)
// 加载内部MCP服务器的统计信息
if h.db == nil {
internalStats := h.mcpServer.GetStats()
for k, v := range internalStats {
stats[k] = v
}
} else {
dbStats, err := h.db.LoadToolStats()
if err != nil {
h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err))
internalStats := h.mcpServer.GetStats()
for k, v := range internalStats {
stats[k] = v
}
} else {
for k, v := range dbStats {
stats[k] = v
}
}
}
// 合并外部MCP管理器的统计信息
if h.externalMCPMgr != nil {
externalStats := h.externalMCPMgr.GetToolStats()
for k, v := range externalStats {
// 如果已存在,合并统计信息
if existing, exists := stats[k]; exists {
existing.TotalCalls += v.TotalCalls
existing.SuccessCalls += v.SuccessCalls
existing.FailedCalls += v.FailedCalls
// 使用最新的调用时间
if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) {
existing.LastCallTime = v.LastCallTime
}
} else {
stats[k] = v
}
}
}
return stats
}
// GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id")
// 先从内部MCP服务器查找
exec, exists := h.mcpServer.GetExecution(id)
if exists {
c.JSON(http.StatusOK, exec)
return
}
// 如果找不到,尝试从外部MCP管理器查找
if h.externalMCPMgr != nil {
exec, exists = h.externalMCPMgr.GetExecution(id)
if exists {
c.JSON(http.StatusOK, exec)
return
}
}
// 如果都找不到,尝试从数据库查找(如果使用数据库存储)
if h.db != nil {
exec, err := h.db.GetToolExecution(id)
if err == nil && exec != nil {
c.JSON(http.StatusOK, exec)
return
}
}
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
}
// CancelExecution 手动取消进行中的 MCP 工具调用(仅取消该次 tools/call 的上下文,不停止整条 Agent / 迭代任务)
// 请求体可选 JSON{ "note": "用户说明" },将与工具已返回输出合并交给模型(含「用户终止说明」标题块,与命令行原文区分)。
func (h *MonitorHandler) CancelExecution(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"})
return
}
note := ""
dec := json.NewDecoder(c.Request.Body)
var body struct {
Note string `json:"note"`
}
if err := dec.Decode(&body); err != nil && !errors.Is(err, io.EOF) {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体须为 JSON,例如 {\"note\":\"说明\"},可为空对象"})
return
}
note = strings.TrimSpace(body.Note)
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
return
}
if h.externalMCPMgr != nil && h.externalMCPMgr.CancelToolExecutionWithNote(id, note) {
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "external"), zap.Bool("hasNote", note != ""))
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
return
}
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
}
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
var req struct {
IDs []string `json:"ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
result := make(map[string]string, len(req.IDs))
for _, id := range req.IDs {
// 先从内部MCP服务器查找
if exec, exists := h.mcpServer.GetExecution(id); exists {
result[id] = exec.ToolName
continue
}
// 再从外部MCP管理器查找
if h.externalMCPMgr != nil {
if exec, exists := h.externalMCPMgr.GetExecution(id); exists {
result[id] = exec.ToolName
continue
}
}
// 最后从数据库查找
if h.db != nil {
if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil {
result[id] = exec.ToolName
}
}
}
c.JSON(http.StatusOK, result)
}
// GetStats 获取统计信息
func (h *MonitorHandler) GetStats(c *gin.Context) {
stats := h.loadStats()
c.JSON(http.StatusOK, stats)
}
// CallsTimelinePoint 调用趋势数据点
type CallsTimelinePoint struct {
T time.Time `json:"t"`
Total int `json:"total"`
Failed int `json:"failed"`
}
// CallsTimelineSummary 调用趋势汇总
type CallsTimelineSummary struct {
TotalCalls int `json:"totalCalls"`
Peak int `json:"peak"`
}
// CallsTimelineResponse 调用趋势响应
type CallsTimelineResponse struct {
Range string `json:"range"`
Points []CallsTimelinePoint `json:"points"`
Summary CallsTimelineSummary `json:"summary"`
}
type callsTimelineConfig struct {
rangeKey string
duration time.Duration
bucketSize time.Duration
dailyBuckets bool
}
func parseCallsTimelineRange(raw string) (callsTimelineConfig, bool) {
switch strings.TrimSpace(raw) {
case "24h":
return callsTimelineConfig{rangeKey: "24h", duration: 24 * time.Hour, bucketSize: time.Hour, dailyBuckets: false}, true
case "30d":
return callsTimelineConfig{rangeKey: "30d", duration: 30 * 24 * time.Hour, bucketSize: 24 * time.Hour, dailyBuckets: true}, true
default:
return callsTimelineConfig{rangeKey: "7d", duration: 7 * 24 * time.Hour, bucketSize: time.Hour, dailyBuckets: false}, true
}
}
func truncateToBucket(t time.Time, bucketSize time.Duration, dailyBuckets bool) time.Time {
if dailyBuckets {
y, m, d := t.Date()
return time.Date(y, m, d, 0, 0, 0, 0, t.Location())
}
return t.Truncate(bucketSize)
}
func buildCallsTimelinePoints(cfg callsTimelineConfig, buckets map[time.Time]struct{ total, failed int }) []CallsTimelinePoint {
now := time.Now()
start := truncateToBucket(now.Add(-cfg.duration), cfg.bucketSize, cfg.dailyBuckets)
end := truncateToBucket(now, cfg.bucketSize, cfg.dailyBuckets)
points := make([]CallsTimelinePoint, 0)
for current := start; !current.After(end); current = current.Add(cfg.bucketSize) {
val := buckets[current]
points = append(points, CallsTimelinePoint{
T: current,
Total: val.total,
Failed: val.failed,
})
}
return points
}
func (h *MonitorHandler) loadCallsTimeline(cfg callsTimelineConfig) []CallsTimelinePoint {
since := time.Now().Add(-cfg.duration)
bucketMap := make(map[time.Time]struct{ total, failed int })
if h.db != nil {
dbBuckets, err := h.db.LoadCallsTimeline(since, cfg.dailyBuckets)
if err != nil {
h.logger.Warn("从数据库加载调用趋势失败,回退到内存数据", zap.Error(err))
} else {
for _, b := range dbBuckets {
key := truncateToBucket(b.BucketTime, cfg.bucketSize, cfg.dailyBuckets)
entry := bucketMap[key]
entry.total += b.Total
entry.failed += b.Failed
bucketMap[key] = entry
}
return buildCallsTimelinePoints(cfg, bucketMap)
}
}
for _, exec := range h.mcpServer.GetAllExecutions() {
if exec == nil || exec.StartTime.Before(since) {
continue
}
key := truncateToBucket(exec.StartTime, cfg.bucketSize, cfg.dailyBuckets)
entry := bucketMap[key]
entry.total++
if exec.Status == "failed" || exec.Status == "cancelled" {
entry.failed++
}
bucketMap[key] = entry
}
return buildCallsTimelinePoints(cfg, bucketMap)
}
// GetCallsTimeline 获取 MCP 工具调用趋势
func (h *MonitorHandler) GetCallsTimeline(c *gin.Context) {
cfg, _ := parseCallsTimelineRange(c.Query("range"))
points := h.loadCallsTimeline(cfg)
summary := CallsTimelineSummary{}
for _, p := range points {
summary.TotalCalls += p.Total
if p.Total > summary.Peak {
summary.Peak = p.Total
}
}
c.JSON(http.StatusOK, CallsTimelineResponse{
Range: cfg.rangeKey,
Points: points,
Summary: summary,
})
}
// DeleteExecution 删除执行记录
func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"})
return
}
// 如果使用数据库,先获取执行记录信息,然后删除并更新统计
if h.db != nil {
// 先获取执行记录信息(用于更新统计)
exec, err := h.db.GetToolExecution(id)
if err != nil {
// 如果找不到记录,可能已经被删除,直接返回成功
h.logger.Warn("执行记录不存在,可能已被删除", zap.String("executionId", id), zap.Error(err))
c.JSON(http.StatusOK, gin.H{"message": "执行记录不存在或已被删除"})
return
}
// 删除执行记录
err = h.db.DeleteToolExecution(id)
if err != nil {
h.logger.Error("删除执行记录失败", zap.Error(err), zap.String("executionId", id))
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除执行记录失败: " + err.Error()})
return
}
// 更新统计信息(减少相应的计数)
totalCalls := 1
successCalls := 0
failedCalls := 0
if exec.Status == "failed" || exec.Status == "cancelled" {
failedCalls = 1
} else if exec.Status == "completed" {
successCalls = 1
}
if exec.ToolName != "" {
if err := h.db.DecreaseToolStats(exec.ToolName, totalCalls, successCalls, failedCalls); err != nil {
h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", exec.ToolName))
// 不返回错误,因为记录已经删除成功
}
}
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
if h.audit != nil {
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
"tool_name": exec.ToolName,
})
}
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
return
}
// 如果不使用数据库,尝试从内存中删除(内部MCP服务器)
// 注意:内存中的记录可能已经被清理,所以这里只记录日志
h.logger.Info("尝试删除内存中的执行记录", zap.String("executionId", id))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
}
// DeleteExecutions 批量删除执行记录
func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
var request struct {
IDs []string `json:"ids"`
}
if err := c.ShouldBindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
if len(request.IDs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID列表不能为空"})
return
}
// 如果使用数据库,先获取执行记录信息,然后删除并更新统计
if h.db != nil {
// 先获取执行记录信息(用于更新统计)
executions, err := h.db.GetToolExecutionsByIds(request.IDs)
if err != nil {
h.logger.Error("获取执行记录失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取执行记录失败: " + err.Error()})
return
}
// 按工具名称分组统计需要减少的数量
toolStats := make(map[string]struct {
totalCalls int
successCalls int
failedCalls int
})
for _, exec := range executions {
if exec.ToolName == "" {
continue
}
stats := toolStats[exec.ToolName]
stats.totalCalls++
if exec.Status == "failed" || exec.Status == "cancelled" {
stats.failedCalls++
} else if exec.Status == "completed" {
stats.successCalls++
}
toolStats[exec.ToolName] = stats
}
// 批量删除执行记录
err = h.db.DeleteToolExecutions(request.IDs)
if err != nil {
h.logger.Error("批量删除执行记录失败", zap.Error(err), zap.Int("count", len(request.IDs)))
c.JSON(http.StatusInternalServerError, gin.H{"error": "批量删除执行记录失败: " + err.Error()})
return
}
// 更新统计信息(减少相应的计数)
for toolName, stats := range toolStats {
if err := h.db.DecreaseToolStats(toolName, stats.totalCalls, stats.successCalls, stats.failedCalls); err != nil {
h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", toolName))
// 不返回错误,因为记录已经删除成功
}
}
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
if h.audit != nil {
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
"count": len(request.IDs),
})
}
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
return
}
// 如果不使用数据库,尝试从内存中删除(内部MCP服务器)
// 注意:内存中的记录可能已经被清理,所以这里只记录日志
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
}
// normalizeToolNameFilter 将模型侧 mcp__tool 转为内部存储用的 mcp::tool。
func normalizeToolNameFilter(name string) string {
name = strings.TrimSpace(name)
if name == "" {
return name
}
if strings.Contains(name, "::") {
return name
}
if idx := strings.Index(name, "__"); idx > 0 {
return name[:idx] + "::" + name[idx+2:]
}
return name
}
func toolNameFilterMatches(storedName, filter string) bool {
filter = strings.TrimSpace(filter)
if filter == "" {
return true
}
storedLower := strings.ToLower(storedName)
filterLower := strings.ToLower(filter)
if strings.Contains(storedLower, filterLower) {
return true
}
normFilter := strings.ToLower(normalizeToolNameFilter(filter))
if normFilter != filterLower && strings.Contains(storedLower, normFilter) {
return true
}
return strings.Contains(strings.ReplaceAll(storedLower, "::", "__"), filterLower)
}
+609
View File
@@ -0,0 +1,609 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/multiagent"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。
func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
c.Header("Content-Type", "text/event-stream; charset=utf-8")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
if h.config == nil || !h.config.MultiAgent.Enabled {
ev := StreamEvent{Type: "error", Message: "多代理未启用,请在设置或 config.yaml 中开启 multi_agent.enabled"}
b, _ := json.Marshal(ev)
fmt.Fprintf(c.Writer, "data: %s\n\n", b)
done := StreamEvent{Type: "done", Message: ""}
db, _ := json.Marshal(done)
fmt.Fprintf(c.Writer, "data: %s\n\n", db)
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
return
}
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
event := StreamEvent{Type: "error", Message: "请求参数错误: " + err.Error()}
b, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", b)
done := StreamEvent{Type: "done", Message: ""}
db, _ := json.Marshal(done)
fmt.Fprintf(c.Writer, "data: %s\n\n", db)
c.Writer.Flush()
return
}
c.Header("X-Accel-Buffering", "no")
// 用于在 sendEvent 中判断是否为用户主动停止导致的取消。
// 注意:baseCtx 会在后面创建;该变量用于闭包提前捕获引用。
var baseCtx context.Context
clientDisconnected := false
// 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。
var sseWriteMu sync.Mutex
var ssePublishConversationID string
sendEvent := func(eventType, message string, data interface{}) {
// 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
if eventType == "error" && baseCtx != nil {
cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
return
}
}
ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, errMarshal := json.Marshal(ev)
if errMarshal != nil {
b = []byte(`{"type":"error","message":"marshal failed"}`)
}
sseLine := make([]byte, 0, len(b)+8)
sseLine = append(sseLine, []byte("data: ")...)
sseLine = append(sseLine, b...)
sseLine = append(sseLine, '\n', '\n')
if ssePublishConversationID != "" && h.taskEventBus != nil {
h.taskEventBus.Publish(ssePublishConversationID, sseLine)
}
if clientDisconnected {
return
}
select {
case <-c.Request.Context().Done():
clientDisconnected = true
return
default:
}
sseWriteMu.Lock()
_, err := c.Writer.Write(sseLine)
if err != nil {
sseWriteMu.Unlock()
clientDisconnected = true
return
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
c.Writer.Flush()
}
sseWriteMu.Unlock()
}
h.logger.Info("收到 Eino DeepAgent 流式请求",
zap.String("conversationId", req.ConversationID),
)
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
if err != nil {
sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil)
return
}
ssePublishConversationID = prep.ConversationID
if prep.CreatedNew {
sendEvent("conversation", "会话已创建", map[string]interface{}{
"conversationId": prep.ConversationID,
})
}
conversationID := prep.ConversationID
assistantMessageID := prep.AssistantMessageID
h.activateHITLForConversation(conversationID, req.Hitl)
if h.hitlManager != nil {
defer h.hitlManager.DeactivateConversation(conversationID)
}
if prep.UserMessageID != "" {
sendEvent("message_saved", "", map[string]interface{}{
"conversationId": conversationID,
"userMessageId": prep.UserMessageID,
})
}
var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History
roleTools := prep.RoleTools
orch := strings.TrimSpace(req.Orchestration)
taskStatus := "completed"
// 仅在成功 StartTask 后再 FinishTask;避免「任务已存在」分支 return 时误删正在运行的同会话任务。
taskOwned := false
defer func() {
if taskOwned {
h.tasks.FinishTask(conversationID, taskStatus)
}
}()
sendEvent("progress", "正在启动 Eino 多代理...", map[string]interface{}{
"conversationId": conversationID,
})
stopKeepalive := make(chan struct{})
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
defer close(stopKeepalive)
var result *multiagent.RunResult
var runErr error
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
var errorMsg string
if errors.Is(err, ErrTaskAlreadyRunning) {
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
sendEvent("error", errorMsg, map[string]interface{}{
"conversationId": conversationID,
"errorType": "task_already_running",
})
} else {
errorMsg = "❌ 无法启动任务: " + err.Error()
sendEvent("error", errorMsg, nil)
}
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
}
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
taskOwned = true
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
var emptyResponseAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int
for {
segmentMainIterationMax := 0
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
progressCallback := func(eventType, message string, data interface{}) {
if eventType == "iteration" {
if m, ok := data.(map[string]interface{}); ok {
if scope, _ := m["einoScope"].(string); scope == "main" {
raw := 0
switch v := m["iteration"].(type) {
case int:
raw = v
case int32:
raw = int(v)
case int64:
raw = int(v)
case float64:
raw = int(v)
case float32:
raw = int(v)
}
if raw > 0 {
if raw > segmentMainIterationMax {
segmentMainIterationMax = raw
}
m["iteration"] = raw + mainIterationOffset
}
}
}
}
rawProgressCallback(eventType, message, data)
}
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
})
result, runErr = multiagent.RunDeepAgent(
taskCtxLoop,
h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
conversationID,
h.conversationProjectID(conversationID),
curFinalMessage,
curHistory,
roleTools,
progressCallback,
h.agentsMarkdownDir,
orch,
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(conversationID),
)
if result != nil && len(result.MCPExecutionIDs) > 0 {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
}
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if exhaustedEmpty {
runErr = nil
transientRunAttempts = 0
timeoutCancel()
break
}
if handledEmpty {
mainIterationOffset += segmentMainIterationMax
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
emptyResponseAttempts = 0
timeoutCancel()
break
}
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
note := h.tasks.TakeInterruptContinueNote(conversationID)
icSummary := interruptContinueTimelineSummary(note)
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
"conversationId": conversationID,
"rawReason": strings.TrimSpace(note),
"emptyReason": strings.TrimSpace(note) == "",
"kind": "no_active_mcp_tool",
})
inject := formatInterruptContinueUserMessage(note)
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
curHistory = hist
}
curFinalMessage = inject
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
"conversationId": conversationID,
"source": "interrupt_continue",
})
mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
cancelMsg := "任务已被用户取消,后续操作已停止。"
if assistantMessageID != "" {
if result != nil {
if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil {
h.logger.Warn("合并取消前的部分回复失败", zap.Error(err))
}
}
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
}
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
}
sendEvent("cancelled", cancelMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) {
taskStatus = "timeout"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
timeoutMsg := "任务执行超时,已自动终止。"
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
}
sendEvent("error", timeoutMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
"errorType": "timeout",
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
taskStatus = "failed"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
errMsg := "执行失败: " + runErr.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
}
sendEvent("error", errMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
timeoutCancel()
if assistantMessageID != "" {
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
}
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
effectiveOrch := config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration)
if o := strings.TrimSpace(req.Orchestration); o != "" {
effectiveOrch = config.NormalizeMultiAgentOrchestration(o)
}
sendEvent("response", result.Response, map[string]interface{}{
"mcpExecutionIds": cumulativeMCPExecutionIDs,
"conversationId": conversationID,
"messageId": assistantMessageID,
"agentMode": "eino_" + effectiveOrch,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
}
// MultiAgentLoop Eino DeepAgent 非流式对话(需 multi_agent.enabled)。
func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
if h.config == nil || !h.config.MultiAgent.Enabled {
c.JSON(http.StatusNotFound, gin.H{"error": "多代理未启用,请在 config.yaml 中设置 multi_agent.enabled: true"})
return
}
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
if err != nil {
status, msg := multiAgentHTTPErrorStatus(err)
c.JSON(status, gin.H{"error": msg})
return
}
h.activateHITLForConversation(prep.ConversationID, req.Hitl)
if h.hitlManager != nil {
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
}
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
defer cancelWithCause(nil)
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
defer timeoutCancel()
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil)
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
})
curHist := prep.History
curMsg := prep.FinalMessage
var result *multiagent.RunResult
var runErr error
var transientRunAttempts int
var emptyResponseAttempts int
for {
result, runErr = multiagent.RunDeepAgent(
taskCtx,
h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
prep.ConversationID,
h.conversationProjectID(prep.ConversationID),
curMsg,
curHist,
prep.RoleTools,
progressCallback,
h.agentsMarkdownDir,
strings.TrimSpace(req.Orchestration),
chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID),
)
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
)
if exhaustedEmpty {
runErr = nil
break
}
if handledEmpty {
continue
}
if runErr == nil {
break
}
if handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
); handled {
continue
} else if fatalErr != nil {
runErr = fatalErr
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
}
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
errMsg := "执行失败: " + runErr.Error()
if prep.AssistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID)
}
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg})
return
}
if prep.AssistantMessageID != "" {
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
}
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
c.JSON(http.StatusOK, ChatResponse{
Response: result.Response,
MCPExecutionIDs: result.MCPExecutionIDs,
ConversationID: prep.ConversationID,
Time: time.Now(),
})
}
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
if h == nil || result == nil {
return
}
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
return
}
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。
func mergeMCPExecutionIDLists(dst []string, more []string) []string {
seen := make(map[string]struct{}, len(dst)+len(more))
out := make([]string, 0, len(dst)+len(more))
add := func(ids []string) {
for _, id := range ids {
id = strings.TrimSpace(id)
if id == "" {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
}
add(dst)
add(more)
return out
}
// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。
func interruptContinueTimelineSummary(note string) string {
note = strings.TrimSpace(note)
if note == "" {
return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。"
}
return "用户中断说明(原文):\n\n" + note
}
// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。
func formatInterruptContinueUserMessage(note string) string {
var b strings.Builder
b.WriteString("【用户补充 / 中断后继续】\n")
if s := strings.TrimSpace(note); s != "" {
b.WriteString(s)
b.WriteString("\n\n")
}
b.WriteString("【请在本轮落实】\n")
b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n")
b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n")
b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n")
return strings.TrimSpace(b.String())
}
func multiAgentHTTPErrorStatus(err error) (int, string) {
msg := err.Error()
switch {
case strings.Contains(msg, "对话不存在"):
return http.StatusNotFound, msg
case strings.Contains(msg, "未找到该 WebShell"):
return http.StatusBadRequest, msg
case strings.Contains(msg, "附件最多"):
return http.StatusBadRequest, msg
case strings.Contains(msg, "保存用户消息失败"), strings.Contains(msg, "创建对话失败"):
return http.StatusInternalServerError, msg
case strings.Contains(msg, "保存上传文件失败"):
return http.StatusInternalServerError, msg
default:
return http.StatusBadRequest, msg
}
}
+152
View File
@@ -0,0 +1,152 @@
package handler
import (
"fmt"
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// multiAgentPrepared 多代理请求在调用 Eino 前的会话与消息准备结果。
type multiAgentPrepared struct {
ConversationID string
CreatedNew bool
History []agent.ChatMessage
FinalMessage string
RoleTools []string
AssistantMessageID string
UserMessageID string
}
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
if len(req.Attachments) > maxAttachments {
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
}
conversationID := strings.TrimSpace(req.ConversationID)
createdNew := false
if conversationID == "" {
title := safeTruncateString(req.Message, 50)
var conv *database.Conversation
var err error
meta := audit.ConversationCreateMetaFromGin(c, source)
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
if strings.TrimSpace(req.WebShellConnectionID) != "" {
meta.Source = source + "_webshell"
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
} else {
conv, err = h.db.CreateConversation(title, meta)
}
if err != nil {
return nil, fmt.Errorf("创建对话失败: %w", err)
}
conversationID = conv.ID
createdNew = true
} else {
if _, err := h.db.GetConversation(conversationID); err != nil {
return nil, fmt.Errorf("对话不存在")
}
}
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil {
agentHistoryMessages = []agent.ChatMessage{}
} else {
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
}
}
finalMessage := req.Message
var roleTools []string
if req.WebShellConnectionID != "" {
conn, errConn := h.db.GetWebshellConnection(strings.TrimSpace(req.WebShellConnectionID))
if errConn != nil || conn == nil {
h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn))
return nil, fmt.Errorf("未找到该 WebShell 连接")
}
webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, req.Message)
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + webshellContext
h.logger.Info("WebShell + 角色: 应用角色提示词(多代理)", zap.String("role", req.Role))
} else {
finalMessage = webshellContext
}
} else {
finalMessage = webshellContext
}
roleTools = []string{
builtin.ToolWebshellExec,
builtin.ToolWebshellFileList,
builtin.ToolWebshellFileRead,
builtin.ToolWebshellFileWrite,
builtin.ToolRecordVulnerability,
builtin.ToolListVulnerabilities,
builtin.ToolGetVulnerability,
builtin.ToolUpsertProjectFact,
builtin.ToolGetProjectFact,
builtin.ToolListProjectFacts,
builtin.ToolSearchProjectFacts,
builtin.ToolDeprecateProjectFact,
builtin.ToolRestoreProjectFact,
builtin.ToolListKnowledgeRiskTypes,
builtin.ToolSearchKnowledgeBase,
}
} else if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
}
roleTools = role.Tools
}
}
var savedPaths []string
if len(req.Attachments) > 0 {
var aerr error
savedPaths, aerr = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger)
if aerr != nil {
return nil, fmt.Errorf("保存上传文件失败: %w", aerr)
}
}
finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths)
userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths)
userMsgRow, uerr := h.db.AddMessage(conversationID, "user", userContent, nil)
if uerr != nil {
h.logger.Error("保存用户消息失败", zap.Error(uerr))
return nil, fmt.Errorf("保存用户消息失败: %w", uerr)
}
userMessageID := ""
if userMsgRow != nil {
userMessageID = userMsgRow.ID
}
assistantMsg, aerr := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
var assistantMessageID string
if aerr != nil {
h.logger.Warn("创建助手消息占位失败", zap.Error(aerr))
} else if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
return &multiAgentPrepared{
ConversationID: conversationID,
CreatedNew: createdNew,
History: agentHistoryMessages,
FinalMessage: finalMessage,
RoleTools: roleTools,
AssistantMessageID: assistantMessageID,
UserMessageID: userMessageID,
}, nil
}
+699
View File
@@ -0,0 +1,699 @@
package handler
import (
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// NotificationHandler 聚合通知(Phase 2:服务端统一计算)
type NotificationHandler struct {
db *database.DB
agentHandler *AgentHandler
logger *zap.Logger
}
const notificationReadMaxRows = 150
// NotificationSummaryItem 通知项
type NotificationSummaryItem struct {
ID string `json:"id"`
Level string `json:"level"` // p0/p1/p2
Type string `json:"type"`
Title string `json:"title"`
Desc string `json:"desc"`
Ts string `json:"ts"` // RFC3339
Count int `json:"count,omitempty"`
Actionable bool `json:"actionable"`
Read bool `json:"read"`
// 以下字段用于前端深链跳转(通知即入口)
ConversationID string `json:"conversationId,omitempty"`
VulnerabilityID string `json:"vulnerabilityId,omitempty"`
ExecutionID string `json:"executionId,omitempty"`
InterruptID string `json:"interruptId,omitempty"`
SessionID string `json:"sessionId,omitempty"` // C2 会话(如新会话上线)
}
// NotificationSummaryResponse 聚合响应
type NotificationSummaryResponse struct {
SinceMs int64 `json:"sinceMs"`
GeneratedAt string `json:"generatedAt"`
P0Count int `json:"p0Count"`
UnreadCount int `json:"unreadCount"`
Counts map[string]int `json:"counts"`
Items []NotificationSummaryItem `json:"items"`
}
func NewNotificationHandler(db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *NotificationHandler {
return &NotificationHandler{
db: db,
agentHandler: agentHandler,
logger: logger,
}
}
func parseSinceMs(raw string) int64 {
v := strings.TrimSpace(raw)
if v == "" {
return 0
}
if ms, err := strconv.ParseInt(v, 10, 64); err == nil && ms > 0 {
return ms
}
if t, err := time.Parse(time.RFC3339, v); err == nil {
return t.UnixMilli()
}
return 0
}
func unixSecToRFC3339(sec int64) string {
if sec <= 0 {
return time.Now().UTC().Format(time.RFC3339)
}
return time.Unix(sec, 0).UTC().Format(time.RFC3339)
}
func normalizedSinceSec(sinceMs int64) int64 {
sec := sinceMs / 1000
// SQLite 默认时间精度到秒;给 1s 回看窗口,避免“同秒内新增”被漏算。
if sec > 0 {
return sec - 1
}
return 0
}
func normalizeSinceMs(raw int64) int64 {
if raw > 0 {
return raw
}
// 默认仅看最近 24 小时,避免首次打开拉全量历史噪音。
return time.Now().Add(-24 * time.Hour).UnixMilli()
}
func levelBySeverity(sev string) string {
switch strings.ToLower(strings.TrimSpace(sev)) {
case "critical", "high":
return "p0"
case "medium":
return "p1"
default:
return "p2"
}
}
func requestWantsEnglish(c *gin.Context) bool {
if c == nil {
return false
}
lang := strings.ToLower(strings.TrimSpace(c.Query("lang")))
if lang == "" {
lang = strings.ToLower(strings.TrimSpace(c.GetHeader("Accept-Language")))
}
return strings.HasPrefix(lang, "en")
}
func i18nText(english bool, zh string, en string) string {
if english {
return en
}
return zh
}
func (h *NotificationHandler) loadPendingHITLItems(limit int, english bool) ([]NotificationSummaryItem, error) {
rows, err := h.db.Query(`
SELECT
id,
conversation_id,
tool_name,
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM hitl_interrupts
WHERE status = 'pending'
ORDER BY created_at DESC
LIMIT ?
`, limit)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
for rows.Next() {
var id, conversationID, toolName string
var createdSec int64
if err := rows.Scan(&id, &conversationID, &toolName, &createdSec); err != nil {
continue
}
desc := i18nText(english, "会话 "+conversationID+" 的审批中断待处理", "Conversation "+conversationID+" has pending HITL approval")
if strings.TrimSpace(toolName) != "" {
desc = i18nText(english, "工具 "+toolName+" 等待审批", "Tool "+toolName+" is waiting for approval")
}
items = append(items, NotificationSummaryItem{
ID: "hitl:" + id,
Level: "p0",
Type: "hitl_pending",
Title: i18nText(english, "HITL 待审批", "HITL Pending Approval"),
Desc: desc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: true,
Read: false,
ConversationID: conversationID,
InterruptID: id,
})
}
return items, nil
}
func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, map[string]int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT
id,
title,
severity,
conversation_id,
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM vulnerabilities
WHERE CAST(strftime('%s', created_at) AS INTEGER) > ?
ORDER BY created_at DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, nil, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
counts := map[string]int{
"newCriticalVulns": 0,
"newHighVulns": 0,
"newMediumVulns": 0,
"newLowVulns": 0,
"newInfoVulns": 0,
}
for rows.Next() {
var id, title, severity, conversationID string
var createdSec int64
if err := rows.Scan(&id, &title, &severity, &conversationID, &createdSec); err != nil {
continue
}
switch strings.ToLower(strings.TrimSpace(severity)) {
case "critical":
counts["newCriticalVulns"]++
case "high":
counts["newHighVulns"]++
case "medium":
counts["newMediumVulns"]++
case "low":
counts["newLowVulns"]++
default:
counts["newInfoVulns"]++
}
sevUpper := strings.ToUpper(strings.TrimSpace(severity))
if sevUpper == "" {
sevUpper = "INFO"
}
finalTitle := i18nText(english, "新漏洞("+sevUpper+"", "New Vulnerability ("+sevUpper+")")
finalDesc := strings.TrimSpace(title)
if finalDesc == "" {
finalDesc = i18nText(english, "(无标题)", "(Untitled)")
}
items = append(items, NotificationSummaryItem{
ID: "vuln:" + id,
Level: levelBySeverity(severity),
Type: "vulnerability_created",
Title: finalTitle,
Desc: finalDesc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: false,
Read: false,
ConversationID: conversationID,
VulnerabilityID: id,
})
}
return items, counts, nil
}
// loadC2SessionOnlineEvents 新会话上线(c2_eventssession + critical,与 Manager.IngestCheckIn 一致)
func (h *NotificationHandler) loadC2SessionOnlineEvents(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT id, message, COALESCE(session_id, ''),
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM c2_events
WHERE category = 'session' AND level = 'critical'
AND CAST(strftime('%s', created_at) AS INTEGER) > ?
ORDER BY created_at DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, 0, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
for rows.Next() {
var id, message, sessionID string
var createdSec int64
if err := rows.Scan(&id, &message, &sessionID, &createdSec); err != nil {
continue
}
desc := strings.TrimSpace(message)
if len(desc) > 220 {
desc = desc[:200] + "…"
}
if desc == "" {
desc = i18nText(english, "新会话已建立", "A new session was created")
}
items = append(items, NotificationSummaryItem{
ID: "c2evt:" + id,
Level: "p0",
Type: "c2_session_online",
Title: i18nText(english, "C2 新会话上线", "C2 new session online"),
Desc: desc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: false,
Read: false,
SessionID: sessionID,
})
}
return items, len(items), rows.Err()
}
func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT
id,
tool_name,
COALESCE(CAST(strftime('%s', start_time) AS INTEGER), 0)
FROM tool_executions
WHERE status = 'failed'
AND CAST(strftime('%s', start_time) AS INTEGER) > ?
ORDER BY start_time DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, 0, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
count := 0
for rows.Next() {
var id, toolName string
var startSec int64
if err := rows.Scan(&id, &toolName, &startSec); err != nil {
continue
}
count++
if strings.TrimSpace(toolName) == "" {
toolName = i18nText(english, "未知工具", "unknown")
}
items = append(items, NotificationSummaryItem{
ID: "exec_failed:" + id,
Level: "p0",
Type: "task_failed",
Title: i18nText(english, "任务执行失败", "Task Execution Failed"),
Desc: i18nText(english, "工具 "+toolName+" 执行失败", "Tool "+toolName+" execution failed"),
Ts: unixSecToRFC3339(startSec),
Count: 1,
Actionable: false,
Read: false,
ExecutionID: id,
})
}
return items, count, nil
}
func (h *NotificationHandler) summarizeLongRunningTasks(threshold time.Duration, english bool) ([]NotificationSummaryItem, int) {
if h.agentHandler == nil || h.agentHandler.tasks == nil {
return nil, 0
}
tasks := h.agentHandler.tasks.GetActiveTasks()
now := time.Now()
items := make([]NotificationSummaryItem, 0, len(tasks))
for _, t := range tasks {
if t == nil {
continue
}
if now.Sub(t.StartedAt) >= threshold {
items = append(items, NotificationSummaryItem{
ID: "task_long:" + t.ConversationID,
Level: "p1",
Type: "long_running_tasks",
Title: i18nText(english, "长时间运行任务", "Long Running Task"),
Desc: i18nText(english, "会话 "+t.ConversationID+" 运行超过 15 分钟", "Conversation "+t.ConversationID+" has been running over 15 minutes"),
Ts: t.StartedAt.UTC().Format(time.RFC3339),
Count: 1,
Actionable: true,
Read: false,
ConversationID: t.ConversationID,
})
}
}
return items, len(items)
}
func (h *NotificationHandler) summarizeCompletedTasksSince(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int) {
if h.agentHandler == nil || h.agentHandler.tasks == nil {
return nil, 0
}
since := time.UnixMilli(sinceMs)
completed := h.agentHandler.tasks.GetCompletedTasks()
items := make([]NotificationSummaryItem, 0, limit)
for _, t := range completed {
if t == nil {
continue
}
if t.CompletedAt.After(since) {
items = append(items, NotificationSummaryItem{
ID: "task_completed:" + t.ConversationID + ":" + strconv.FormatInt(t.CompletedAt.Unix(), 10),
Level: "p2",
Type: "task_completed",
Title: i18nText(english, "任务完成", "Task Completed"),
Desc: i18nText(english, "会话 "+t.ConversationID+" 已完成", "Conversation "+t.ConversationID+" completed"),
Ts: t.CompletedAt.UTC().Format(time.RFC3339),
Count: 1,
Actionable: false,
Read: false,
ConversationID: t.ConversationID,
})
if len(items) >= limit {
break
}
}
}
return items, len(items)
}
func buildPlaceholders(n int) string {
if n <= 0 {
return ""
}
out := make([]string, 0, n)
for i := 0; i < n; i++ {
out = append(out, "?")
}
return strings.Join(out, ",")
}
func (h *NotificationHandler) readStatesByIDs(ids []string) (map[string]bool, error) {
result := make(map[string]bool, len(ids))
if len(ids) == 0 {
return result, nil
}
holders := buildPlaceholders(len(ids))
query := "SELECT event_id FROM notification_reads WHERE event_id IN (" + holders + ")"
args := make([]interface{}, 0, len(ids))
for _, id := range ids {
args = append(args, id)
}
rows, err := h.db.Query(query, args...)
if err != nil {
return result, err
}
defer rows.Close()
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
continue
}
result[id] = true
}
return result, nil
}
func (h *NotificationHandler) applyReadStates(items []NotificationSummaryItem) ([]NotificationSummaryItem, error) {
markableIDs := make([]string, 0, len(items))
for _, item := range items {
if item.Actionable {
continue
}
markableIDs = append(markableIDs, item.ID)
}
readMap, err := h.readStatesByIDs(markableIDs)
if err != nil {
return items, err
}
for i := range items {
if items[i].Actionable {
items[i].Read = false
continue
}
items[i].Read = readMap[items[i].ID]
}
return items, nil
}
func filterVisibleItems(items []NotificationSummaryItem) []NotificationSummaryItem {
out := make([]NotificationSummaryItem, 0, len(items))
for _, item := range items {
if item.Actionable || !item.Read {
out = append(out, item)
}
}
return out
}
func countP0(items []NotificationSummaryItem) int {
total := 0
for _, item := range items {
if item.Level == "p0" {
if item.Count > 0 {
total += item.Count
} else {
total++
}
}
}
return total
}
func countUnread(items []NotificationSummaryItem) int {
total := 0
for _, item := range items {
if item.Actionable || !item.Read {
if item.Count > 0 {
total += item.Count
} else {
total++
}
}
}
return total
}
func createNotificationReadTableIfNeeded(db *database.DB) error {
if db == nil {
return fmt.Errorf("db is nil")
}
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS notification_reads (
event_id TEXT PRIMARY KEY,
read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`)
if err != nil {
return err
}
_, idxErr := db.Exec(`CREATE INDEX IF NOT EXISTS idx_notification_reads_read_at ON notification_reads(read_at DESC);`)
return idxErr
}
func pruneNotificationReads(db *database.DB, maxRows int) error {
if db == nil {
return fmt.Errorf("db is nil")
}
if maxRows <= 0 {
return nil
}
_, err := db.Exec(`
DELETE FROM notification_reads
WHERE event_id NOT IN (
SELECT event_id
FROM notification_reads
ORDER BY read_at DESC, rowid DESC
LIMIT ?
)
`, maxRows)
return err
}
type markReadRequest struct {
EventIDs []string `json:"eventIds"`
}
func normalizeMarkableEventID(id string) (string, bool) {
v := strings.TrimSpace(id)
if v == "" {
return "", false
}
// 仅允许“可读后隐藏”的信息类事件;Actionable 事件不参与 read 标记。
allowedPrefixes := []string{
"vuln:",
"exec_failed:",
"task_completed:",
"c2evt:",
}
for _, prefix := range allowedPrefixes {
if strings.HasPrefix(v, prefix) {
return v, true
}
}
return "", false
}
// MarkRead 按事件 ID 标记已读
func (h *NotificationHandler) MarkRead(c *gin.Context) {
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare notification read table"})
return
}
var req markReadRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
if len(req.EventIDs) == 0 {
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": 0})
return
}
tx, err := h.db.Begin()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to begin transaction"})
return
}
defer func() {
_ = tx.Rollback()
}()
stmt, err := tx.Prepare(`
INSERT INTO notification_reads(event_id, read_at)
VALUES(?, CURRENT_TIMESTAMP)
ON CONFLICT(event_id) DO UPDATE SET read_at = CURRENT_TIMESTAMP
`)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare statement"})
return
}
defer stmt.Close()
marked := 0
for _, raw := range req.EventIDs {
id, ok := normalizeMarkableEventID(raw)
if !ok {
continue
}
if _, err := stmt.Exec(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mark read"})
return
}
marked++
}
if err := tx.Commit(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to commit read marks"})
return
}
if err := pruneNotificationReads(h.db, notificationReadMaxRows); err != nil {
h.logger.Warn("裁剪通知已读记录失败", zap.Error(err))
}
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": marked})
}
// GetSummary 返回通知聚合视图(用于头部铃铛)
func (h *NotificationHandler) GetSummary(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
h.logger.Warn("初始化通知已读表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initialize notification read table"})
return
}
english := requestWantsEnglish(c)
sinceMs := normalizeSinceMs(parseSinceMs(c.Query("since")))
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("limit", "50")))
if limit <= 0 {
limit = 50
}
if limit > 200 {
limit = 200
}
hitlItems, err := h.loadPendingHITLItems(limit, english)
if err != nil {
h.logger.Warn("加载 HITL 通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize hitl notifications"})
return
}
vulnItems, vulnCounts, err := h.loadVulnerabilityItems(sinceMs, limit, english)
if err != nil {
h.logger.Warn("加载漏洞通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize vulnerabilities"})
return
}
c2OnlineItems, c2OnlineCount, err := h.loadC2SessionOnlineEvents(sinceMs, limit, english)
if err != nil {
h.logger.Warn("加载 C2 会话上线通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize c2 session events"})
return
}
longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english)
completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english)
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(c2OnlineItems)+len(longRunningItems)+len(completedItems))
items = append(items, hitlItems...)
items = append(items, vulnItems...)
items = append(items, c2OnlineItems...)
items = append(items, longRunningItems...)
items = append(items, completedItems...)
items, err = h.applyReadStates(items)
if err != nil {
h.logger.Warn("加载通知已读状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load notification read states"})
return
}
items = filterVisibleItems(items)
sort.Slice(items, func(i, j int) bool {
ti, errI := time.Parse(time.RFC3339, items[i].Ts)
tj, errJ := time.Parse(time.RFC3339, items[j].Ts)
if errI != nil || errJ != nil {
return i < j
}
return ti.After(tj)
})
p0Count := countP0(items)
unreadCount := countUnread(items)
c.JSON(http.StatusOK, NotificationSummaryResponse{
SinceMs: sinceMs,
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
P0Count: p0Count,
UnreadCount: unreadCount,
Counts: map[string]int{
"hitlPending": len(hitlItems),
"newCriticalVulns": vulnCounts["newCriticalVulns"],
"newHighVulns": vulnCounts["newHighVulns"],
"newMediumVulns": vulnCounts["newMediumVulns"],
"newLowVulns": vulnCounts["newLowVulns"],
"newInfoVulns": vulnCounts["newInfoVulns"],
"failedExecutions": 0,
"longRunningTasks": longRunningCount,
"completedTasks": completedCount,
"c2SessionOnline": c2OnlineCount,
},
Items: items,
})
}
File diff suppressed because it is too large Load Diff
+174
View File
@@ -0,0 +1,174 @@
package handler
// apiDocI18n 为 OpenAPI 文档提供 x-i18n-* 扩展键,供前端 apiDocs 国际化使用。
// 前端通过 apiDocs.tags.* / apiDocs.summary.* / apiDocs.response.* 翻译。
var apiDocI18nTagToKey = map[string]string{
"认证": "auth", "对话管理": "conversationManagement", "对话交互": "conversationInteraction",
"批量任务": "batchTasks", "对话分组": "conversationGroups", "漏洞管理": "vulnerabilityManagement",
"角色管理": "roleManagement", "Skills管理": "skillsManagement", "监控": "monitoring",
"配置管理": "configManagement", "外部MCP管理": "externalMCPManagement", "攻击链": "attackChain",
"知识库": "knowledgeBase", "MCP": "mcp",
"FOFA信息收集": "fofaRecon", "终端": "terminal", "WebShell管理": "webshellManagement",
"对话附件": "chatUploads", "机器人集成": "robotIntegration", "多代理Markdown": "markdownAgents",
}
var apiDocI18nSummaryToKey = map[string]string{
"用户登录": "login", "用户登出": "logout", "修改密码": "changePassword", "验证Token": "validateToken",
"创建对话": "createConversation", "列出对话": "listConversations", "查看对话详情": "getConversationDetail",
"更新对话": "updateConversation", "删除对话": "deleteConversation", "获取对话结果": "getConversationResult",
"发送消息并获取AI回复(非流式)": "sendMessageNonStream", "发送消息并获取AI回复(流式)": "sendMessageStream",
"取消任务": "cancelTask", "列出运行中的任务": "listRunningTasks", "列出已完成的任务": "listCompletedTasks",
"创建批量任务队列": "createBatchQueue", "列出批量任务队列": "listBatchQueues", "获取批量任务队列": "getBatchQueue",
"删除批量任务队列": "deleteBatchQueue", "启动批量任务队列": "startBatchQueue", "暂停批量任务队列": "pauseBatchQueue",
"添加任务到队列": "addTaskToQueue", "SQL注入扫描": "sqlInjectionScan", "端口扫描": "portScan",
"更新批量任务": "updateBatchTask", "删除批量任务": "deleteBatchTask",
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
"从分组移除对话": "removeConversationFromGroup",
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
"获取Skill统计": "getSkillStats", "清空Skill统计": "clearSkillStats", "获取Skill": "getSkill",
"更新Skill": "updateSkill", "删除Skill": "deleteSkill", "获取绑定角色": "getBoundRoles",
"获取监控信息": "getMonitorInfo", "获取执行记录": "getExecutionRecords", "删除执行记录": "deleteExecutionRecord",
"批量删除执行记录": "batchDeleteExecutionRecords", "获取统计信息": "getStats",
"获取配置": "getConfig", "更新配置": "updateConfig", "获取工具配置": "getToolConfig", "应用配置": "applyConfig",
"列出外部MCP": "listExternalMCP", "获取外部MCP统计": "getExternalMCPStats", "获取外部MCP": "getExternalMCP",
"添加或更新外部MCP": "addOrUpdateExternalMCP", "stdio模式配置": "stdioModeConfig", "SSE模式配置": "sseModeConfig",
"删除外部MCP": "deleteExternalMCP", "启动外部MCP": "startExternalMCP", "停止外部MCP": "stopExternalMCP",
"获取攻击链": "getAttackChain", "重新生成攻击链": "regenerateAttackChain",
"设置对话置顶": "pinConversation", "设置分组置顶": "pinGroup", "设置分组中对话的置顶": "pinGroupConversation",
"获取分类": "getCategories", "列出知识项": "listKnowledgeItems", "创建知识项": "createKnowledgeItem",
"获取知识项": "getKnowledgeItem", "更新知识项": "updateKnowledgeItem", "删除知识项": "deleteKnowledgeItem",
"获取索引状态": "getIndexStatus", "重建索引": "rebuildIndex", "扫描知识库": "scanKnowledgeBase",
"搜索知识库": "searchKnowledgeBase", "基础搜索": "basicSearch", "按风险类型搜索": "searchByRiskType",
"获取检索日志": "getRetrievalLogs", "删除检索日志": "deleteRetrievalLog",
"MCP端点": "mcpEndpoint", "列出所有工具": "listAllTools", "调用工具": "invokeTool", "初始化连接": "initConnection",
"成功响应": "successResponse", "错误响应": "errorResponse",
// 新增缺失端点
"删除对话轮次": "deleteConversationTurn", "获取消息过程详情": "getMessageProcessDetails",
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
"获取所有分组映射": "getAllGroupMappings",
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
"测试OpenAI API连接": "testOpenAI",
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
"获取AI对话历史": "getWebshellAIHistory", "列出AI对话": "listWebshellAIConversations",
"执行WebShell命令": "webshellExec", "WebShell文件操作": "webshellFileOp",
"列出附件": "listChatUploads", "上传附件": "uploadChatFile", "删除附件": "deleteChatUpload",
"下载附件": "downloadChatUpload", "获取附件文本内容": "getChatUploadContent",
"写入附件文本内容": "putChatUploadContent", "创建附件目录": "mkdirChatUpload", "重命名附件": "renameChatUpload",
"企业微信回调验证": "wecomCallbackVerify", "企业微信消息回调": "wecomCallbackMessage",
"钉钉消息回调": "dingtalkCallback", "飞书消息回调": "larkCallback", "测试机器人消息处理": "testRobot",
"列出Markdown代理": "listMarkdownAgents", "创建Markdown代理": "createMarkdownAgent",
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
"批量获取工具名称": "batchGetToolNames",
"获取知识库统计": "getKnowledgeStats",
}
var apiDocI18nResponseDescToKey = map[string]string{
"获取成功": "getSuccess", "未授权": "unauthorized", "未授权,需要有效的Token": "unauthorizedToken",
"创建成功": "createSuccess", "请求参数错误": "badRequest", "对话不存在": "conversationNotFound",
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
"删除成功": "deleteSuccess", "队列不存在": "queueNotFound", "启动成功": "startSuccess",
"暂停成功": "pauseSuccess", "添加成功": "addSuccess",
"任务不存在": "taskNotFound", "对话或分组不存在": "conversationOrGroupNotFound",
"取消请求已提交": "cancelSubmitted", "未找到正在执行的任务": "noRunningTask",
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events": "streamResponse",
// 新增缺失端点响应
"参数错误或删除失败": "badRequestOrDeleteFailed",
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
"文件下载": "fileDownload", "文件不存在": "fileNotFound", "写入成功": "writeSuccess",
"重命名成功": "renameSuccess", "验证成功,返回解密后的echostr": "wecomVerifySuccess",
"处理成功": "processSuccess", "代理不存在": "agentNotFound", "保存成功": "saveSuccess",
"操作结果": "operationResult", "执行结果": "executionResult", "连接不存在": "connectionNotFound",
}
// enrichSpecWithI18nKeys 在 spec 的每个 operation 上写入 x-i18n-tags、x-i18n-summary
// 在每个 response 上写入 x-i18n-description,供前端按 key 做国际化。
func enrichSpecWithI18nKeys(spec map[string]interface{}) {
paths, _ := spec["paths"].(map[string]interface{})
if paths == nil {
return
}
for _, pathItem := range paths {
pm, _ := pathItem.(map[string]interface{})
if pm == nil {
continue
}
for _, method := range []string{"get", "post", "put", "delete", "patch"} {
opVal, ok := pm[method]
if !ok {
continue
}
op, _ := opVal.(map[string]interface{})
if op == nil {
continue
}
// x-i18n-tags: 与 tags 一一对应的 i18n 键数组(spec 中 tags 为 []string
switch tags := op["tags"].(type) {
case []string:
if len(tags) > 0 {
keys := make([]string, 0, len(tags))
for _, s := range tags {
if k := apiDocI18nTagToKey[s]; k != "" {
keys = append(keys, k)
} else {
keys = append(keys, s)
}
}
op["x-i18n-tags"] = keys
}
case []interface{}:
if len(tags) > 0 {
keys := make([]interface{}, 0, len(tags))
for _, t := range tags {
if s, ok := t.(string); ok {
if k := apiDocI18nTagToKey[s]; k != "" {
keys = append(keys, k)
} else {
keys = append(keys, s)
}
}
}
if len(keys) > 0 {
op["x-i18n-tags"] = keys
}
}
}
// x-i18n-summary
if summary, _ := op["summary"].(string); summary != "" {
if k := apiDocI18nSummaryToKey[summary]; k != "" {
op["x-i18n-summary"] = k
}
}
// responses -> 每个 status -> x-i18n-description
if respMap, _ := op["responses"].(map[string]interface{}); respMap != nil {
for _, rv := range respMap {
if r, _ := rv.(map[string]interface{}); r != nil {
if desc, _ := r["description"].(string); desc != "" {
if k := apiDocI18nResponseDescToKey[desc]; k != "" {
r["x-i18n-description"] = k
}
}
}
}
}
}
}
}
+410
View File
@@ -0,0 +1,410 @@
package handler
import (
"net/http"
"strconv"
"strings"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/project"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const maxProjectDescriptionRunes = 4000
func clampProjectDescription(s string) string {
r := []rune(s)
if len(r) <= maxProjectDescriptionRunes {
return s
}
return string(r[:maxProjectDescriptionRunes])
}
// ProjectHandler 项目管理处理器。
type ProjectHandler struct {
db *database.DB
logger *zap.Logger
}
// NewProjectHandler 创建项目管理处理器。
func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler {
return &ProjectHandler{db: db, logger: logger}
}
type createProjectRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
ScopeJSON string `json:"scope_json"`
Status string `json:"status"`
}
// updateProjectRequest 部分更新:字段省略表示不修改;传 null 或 "" 可清空字符串字段。
type updateProjectRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
ScopeJSON *string `json:"scope_json"`
Status *string `json:"status"`
Pinned *bool `json:"pinned"`
}
// CreateProject POST /api/projects
func (h *ProjectHandler) CreateProject(c *gin.Context) {
var req createProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
p := &database.Project{
Name: strings.TrimSpace(req.Name),
Description: clampProjectDescription(req.Description),
ScopeJSON: req.ScopeJSON,
Status: strings.TrimSpace(req.Status),
}
created, err := h.db.CreateProject(p)
if err != nil {
h.logger.Error("创建项目失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, created)
}
// GetDashboardSummary GET /api/projects/dashboard-summary
func (h *ProjectHandler) GetDashboardSummary(c *gin.Context) {
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("fact_limit", "5")))
if limit <= 0 {
limit = 5
}
if limit > 50 {
limit = 50
}
summary, err := h.db.GetProjectDashboardSummary(limit)
if err != nil {
h.logger.Error("获取项目仪表盘摘要失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if summary.RecentFacts == nil {
summary.RecentFacts = []database.ProjectDashboardFact{}
}
c.JSON(http.StatusOK, summary)
}
// ListProjects GET /api/projects
func (h *ProjectHandler) ListProjects(c *gin.Context) {
status := c.Query("status")
search := c.Query("search")
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50"))
offset, _ := strconv.Atoi(c.Query("offset"))
if limit <= 0 {
limit = 50
}
if limit > 500 {
limit = 500
}
list, err := h.db.ListProjects(status, search, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.Project{}
}
total, err := h.db.CountProjects(status, search)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"projects": list,
"total": total,
"limit": limit,
"offset": offset,
})
}
// GetProjectStats GET /api/projects/:id/stats
func (h *ProjectHandler) GetProjectStats(c *gin.Context) {
stats, err := project.GetProjectStats(h.db, c.Param("id"))
if err != nil {
if strings.Contains(err.Error(), "不存在") {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, stats)
}
// ListProjectConversations GET /api/projects/:id/conversations
func (h *ProjectHandler) ListProjectConversations(c *gin.Context) {
projectID := c.Param("id")
if _, err := h.db.GetProject(projectID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
offset, _ := strconv.Atoi(c.Query("offset"))
list, err := h.db.ListConversationsByProjectID(projectID, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.Conversation{}
}
total, _ := h.db.CountConversationsByProjectID(projectID)
c.JSON(http.StatusOK, gin.H{
"conversations": list,
"total": total,
"limit": limit,
"offset": offset,
})
}
// GetProject GET /api/projects/:id
func (h *ProjectHandler) GetProject(c *gin.Context) {
p, err := h.db.GetProject(c.Param("id"))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
c.JSON(http.StatusOK, p)
}
// UpdateProject PUT /api/projects/:id
func (h *ProjectHandler) UpdateProject(c *gin.Context) {
id := c.Param("id")
p, err := h.db.GetProject(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
return
}
var req updateProjectRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name != nil {
if s := strings.TrimSpace(*req.Name); s != "" {
p.Name = s
}
}
if req.Description != nil {
p.Description = clampProjectDescription(*req.Description)
}
if req.ScopeJSON != nil {
p.ScopeJSON = *req.ScopeJSON
}
if req.Status != nil {
if s := strings.TrimSpace(*req.Status); s != "" {
p.Status = s
}
}
if req.Pinned != nil {
p.Pinned = *req.Pinned
}
if err := h.db.UpdateProject(p); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, p)
}
// DeleteProject DELETE /api/projects/:id
func (h *ProjectHandler) DeleteProject(c *gin.Context) {
if err := h.db.DeleteProject(c.Param("id")); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
type upsertFactRequest struct {
FactKey string `json:"fact_key" binding:"required"`
Category string `json:"category"`
Summary string `json:"summary" binding:"required"`
Body string `json:"body"`
Confidence string `json:"confidence"`
Pinned bool `json:"pinned"`
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
}
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
type updateFactRequest struct {
FactKey *string `json:"fact_key"`
Category *string `json:"category"`
Summary *string `json:"summary"`
Body *string `json:"body"`
Confidence *string `json:"confidence"`
Pinned *bool `json:"pinned"`
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
ClearBody bool `json:"clear_body"`
}
// ListFacts GET /api/projects/:id/facts fact_key 查询参数可获取单条详情)
func (h *ProjectHandler) ListFacts(c *gin.Context) {
projectID := c.Param("id")
if key := strings.TrimSpace(c.Query("fact_key")); key != "" {
f, err := h.db.GetProjectFactByKey(projectID, key)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, f)
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
offset, _ := strconv.Atoi(c.Query("offset"))
filter := database.ProjectFactListFilter{
Category: c.Query("category"),
Confidence: c.Query("confidence"),
Search: c.Query("search"),
RelatedVulnerabilityID: c.Query("related_vulnerability_id"),
}
if c.Query("exclude_deprecated") == "1" || c.Query("exclude_deprecated") == "true" {
filter.ExcludeDeprecated = true
}
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []*database.ProjectFact{}
}
if sparseOnly := c.Query("sparse_only"); sparseOnly == "1" || sparseOnly == "true" {
filtered := make([]*database.ProjectFact, 0, len(list))
for _, f := range list {
if project.IsSparseFactBody(f.Category, f.FactKey, f.Body) {
filtered = append(filtered, f)
}
}
list = filtered
}
c.JSON(http.StatusOK, list)
}
// CreateFact POST /api/projects/:id/facts
func (h *ProjectHandler) CreateFact(c *gin.Context) {
var req upsertFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
f := &database.ProjectFact{
ProjectID: c.Param("id"),
FactKey: req.FactKey,
Category: req.Category,
Summary: req.Summary,
Body: req.Body,
Confidence: req.Confidence,
Pinned: req.Pinned,
RelatedVulnerabilityID: req.RelatedVulnerabilityID,
}
created, err := h.db.UpsertProjectFact(f)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, created)
}
// UpdateFact PUT /api/projects/:id/facts/:factId
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
var req updateFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.FactKey != nil {
if k := strings.TrimSpace(*req.FactKey); k != "" {
existing.FactKey = k
}
}
if req.Category != nil && strings.TrimSpace(*req.Category) != "" {
existing.Category = *req.Category
}
if req.Summary != nil && strings.TrimSpace(*req.Summary) != "" {
existing.Summary = *req.Summary
}
if req.ClearBody {
existing.Body = ""
} else if req.Body != nil {
existing.Body = *req.Body
}
if req.Confidence != nil && strings.TrimSpace(*req.Confidence) != "" {
existing.Confidence = *req.Confidence
}
if req.Pinned != nil {
existing.Pinned = *req.Pinned
}
if req.RelatedVulnerabilityID != nil {
existing.RelatedVulnerabilityID = *req.RelatedVulnerabilityID
}
updated, err := h.db.UpsertProjectFact(existing)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, updated)
}
// DeleteFact DELETE /api/projects/:id/facts/:factId
func (h *ProjectHandler) DeleteFact(c *gin.Context) {
existing, err := h.db.GetProjectFact(c.Param("factId"))
if err != nil || existing.ProjectID != c.Param("id") {
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
return
}
if err := h.db.DeleteProjectFact(existing.ID); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
type deprecateFactRequest struct {
FactKey string `json:"fact_key" binding:"required"`
}
// DeprecateFact POST /api/projects/:id/facts/deprecate
func (h *ProjectHandler) DeprecateFact(c *gin.Context) {
var req deprecateFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
type restoreFactRequest struct {
FactKey string `json:"fact_key" binding:"required"`
Confidence string `json:"confidence"` // 可选:confirmed | tentative,默认 tentative
}
// RestoreFact POST /api/projects/:id/facts/restore
func (h *ProjectHandler) RestoreFact(c *gin.Context) {
var req restoreFactRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.RestoreProjectFact(c.Param("id"), req.FactKey, req.Confidence); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
+48
View File
@@ -0,0 +1,48 @@
package handler
import (
"strings"
"cyberstrike-ai/internal/project"
"go.uber.org/zap"
)
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
if h == nil || h.db == nil || h.config == nil {
return ""
}
if !h.config.Project.Enabled {
return ""
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
projectID, err := h.db.GetConversationProjectID(conversationID)
if err != nil || projectID == "" {
return ""
}
block, err := project.BuildProjectBlackboardBlock(h.db, projectID, h.config.Project)
if err != nil {
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
return ""
}
return strings.TrimSpace(block)
}
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
func (h *AgentHandler) conversationProjectID(conversationID string) string {
if h == nil || h.db == nil {
return ""
}
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
projectID, err := h.db.GetConversationProjectID(conversationID)
if err != nil {
return ""
}
return strings.TrimSpace(projectID)
}
+18
View File
@@ -0,0 +1,18 @@
package handler
import (
"strings"
"cyberstrike-ai/internal/config"
)
// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。
func effectiveProjectID(cfg *config.Config, explicit string) string {
if pid := strings.TrimSpace(explicit); pid != "" {
return pid
}
if cfg != nil {
return strings.TrimSpace(cfg.Project.DefaultProjectID)
}
return ""
}
File diff suppressed because it is too large Load Diff
+469
View File
@@ -0,0 +1,469 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// RoleHandler 角色处理器
type RoleHandler struct {
config *config.Config
configPath string
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *RoleHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewRoleHandler 创建新的角色处理器
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
return &RoleHandler{
config: cfg,
configPath: configPath,
logger: logger,
}
}
// GetRoles 获取所有角色
func (h *RoleHandler) GetRoles(c *gin.Context) {
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
for key, role := range h.config.Roles {
// 确保角色的key与name一致
if role.Name == "" {
role.Name = key
}
roles = append(roles, role)
}
c.JSON(http.StatusOK, gin.H{
"roles": roles,
})
}
// GetRole 获取单个角色
func (h *RoleHandler) GetRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
role, exists := h.config.Roles[roleName]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 确保角色的name与key一致
if role.Name == "" {
role.Name = roleName
}
c.JSON(http.StatusOK, gin.H{
"role": role,
})
}
// UpdateRole 更新角色
func (h *RoleHandler) UpdateRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 确保角色名称与请求中的name一致
if req.Name == "" {
req.Name = roleName
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 删除所有与角色name相同但key不同的旧角色(避免重复)
// 使用角色name作为key,确保唯一性
finalKey := req.Name
keysToDelete := make([]string, 0)
for key := range h.config.Roles {
// 如果key与最终的key不同,但name相同,则标记为删除
if key != finalKey {
role := h.config.Roles[key]
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = key
}
if role.Name == req.Name {
keysToDelete = append(keysToDelete, key)
}
}
}
// 删除旧的角色
for _, key := range keysToDelete {
delete(h.config.Roles, key)
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
}
// 如果当前更新的key与最终key不同,也需要删除旧的
if roleName != finalKey {
delete(h.config.Roles, roleName)
}
// 如果角色名称改变,需要删除旧文件
if roleName != finalKey {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 删除旧的角色文件
oldSafeFileName := sanitizeFileName(roleName)
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
if _, err := os.Stat(oldRoleFileYaml); err == nil {
if err := os.Remove(oldRoleFileYaml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
}
}
if _, err := os.Stat(oldRoleFileYml); err == nil {
if err := os.Remove(oldRoleFileYml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
}
}
}
// 使用角色name作为key来保存(确保唯一性)
h.config.Roles[finalKey] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
}
c.JSON(http.StatusOK, gin.H{
"message": "角色已更新",
"role": req,
})
}
// CreateRole 创建新角色
func (h *RoleHandler) CreateRole(c *gin.Context) {
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 检查角色是否已存在
if _, exists := h.config.Roles[req.Name]; exists {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
return
}
// 创建角色(默认启用)
if !req.Enabled {
req.Enabled = true
}
h.config.Roles[req.Name] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("创建角色", zap.String("roleName", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
}
c.JSON(http.StatusOK, gin.H{
"message": "角色已创建",
"role": req,
})
}
// DeleteRole 删除角色
func (h *RoleHandler) DeleteRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
if _, exists := h.config.Roles[roleName]; !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 不允许删除"默认"角色
if roleName == "默认" {
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
return
}
delete(h.config.Roles, roleName)
// 删除对应的角色文件
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 尝试删除角色文件(.yaml 和 .yml)
safeFileName := sanitizeFileName(roleName)
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
// 删除 .yaml 文件(如果存在)
if _, err := os.Stat(roleFileYaml); err == nil {
if err := os.Remove(roleFileYaml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
}
}
// 删除 .yml 文件(如果存在)
if _, err := os.Stat(roleFileYml); err == nil {
if err := os.Remove(roleFileYml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
}
}
h.logger.Info("删除角色", zap.String("roleName", roleName))
if h.audit != nil {
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
}
c.JSON(http.StatusOK, gin.H{
"message": "角色已删除",
})
}
// saveConfig 保存配置到目录中的文件
func (h *RoleHandler) saveConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeFileName 将角色名称转换为安全的文件名
func sanitizeFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// updateRolesConfig 更新角色配置
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
root := doc.Content[0]
rolesNode := ensureMap(root, "roles")
// 清空现有角色
if rolesNode.Kind == yaml.MappingNode {
rolesNode.Content = nil
}
// 添加新角色(使用name作为key,确保唯一性)
if cfg.Roles != nil {
// 先建立一个以name为key的map,去重(保留最后一个)
rolesByName := make(map[string]config.RoleConfig)
for roleKey, role := range cfg.Roles {
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = roleKey
}
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
rolesByName[role.Name] = role
}
// 将去重后的角色写入YAML
for roleName, role := range rolesByName {
roleNode := ensureMap(rolesNode, roleName)
setStringInMap(roleNode, "name", role.Name)
setStringInMap(roleNode, "description", role.Description)
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
if role.Icon != "" {
setStringInMap(roleNode, "icon", role.Icon)
}
setBoolInMap(roleNode, "enabled", role.Enabled)
// 添加工具列表(优先使用tools字段)
if len(role.Tools) > 0 {
toolsNode := ensureArray(roleNode, "tools")
toolsNode.Content = nil
for _, toolKey := range role.Tools {
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
toolsNode.Content = append(toolsNode.Content, toolNode)
}
} else if len(role.MCPs) > 0 {
// 向后兼容:如果没有tools但有mcps,保存mcps
mcpsNode := ensureArray(roleNode, "mcps")
mcpsNode.Content = nil
for _, mcpName := range role.MCPs {
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
}
}
}
}
}
// ensureArray 确保数组中存在指定key的数组节点
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
_, valueNode := ensureKeyValue(parent, key)
if valueNode.Kind != yaml.SequenceNode {
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
}
return valueNode
}
+710
View File
@@ -0,0 +1,710 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/skillpackage"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// SkillsHandler Skills处理器(磁盘 + Eino 规范;运行时由 Eino ADK skill 中间件加载)
type SkillsHandler struct {
config *config.Config
configPath string
logger *zap.Logger
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *SkillsHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewSkillsHandler 创建新的Skills处理器
func NewSkillsHandler(cfg *config.Config, configPath string, logger *zap.Logger) *SkillsHandler {
return &SkillsHandler{
config: cfg,
configPath: configPath,
logger: logger,
}
}
func (h *SkillsHandler) skillsRootAbs() string {
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
return skillsDir
}
// SetDB 设置数据库连接(用于获取调用统计)
func (h *SkillsHandler) SetDB(db *database.DB) {
h.db = db
}
// GetSkills 获取所有skills列表(支持分页和搜索)
func (h *SkillsHandler) GetSkills(c *gin.Context) {
allSummaries, err := skillpackage.ListSkillSummaries(h.skillsRootAbs())
if err != nil {
h.logger.Error("获取skills列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
searchKeyword := strings.TrimSpace(c.Query("search"))
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
for _, s := range allSummaries {
skillInfo := map[string]interface{}{
"id": s.ID,
"name": s.Name,
"dir_name": s.DirName,
"description": s.Description,
"version": s.Version,
"path": s.Path,
"tags": s.Tags,
"triggers": s.Triggers,
"script_count": s.ScriptCount,
"file_count": s.FileCount,
"progressive": s.Progressive,
"file_size": s.FileSize,
"mod_time": s.ModTime,
}
allSkillsInfo = append(allSkillsInfo, skillInfo)
}
filteredSkillsInfo := allSkillsInfo
if searchKeyword != "" {
keywordLower := strings.ToLower(searchKeyword)
filteredSkillsInfo = make([]map[string]interface{}, 0)
for _, skillInfo := range allSkillsInfo {
id := strings.ToLower(fmt.Sprintf("%v", skillInfo["id"]))
name := strings.ToLower(fmt.Sprintf("%v", skillInfo["name"]))
description := strings.ToLower(fmt.Sprintf("%v", skillInfo["description"]))
path := strings.ToLower(fmt.Sprintf("%v", skillInfo["path"]))
version := strings.ToLower(fmt.Sprintf("%v", skillInfo["version"]))
tagsJoined := ""
if tags, ok := skillInfo["tags"].([]string); ok {
tagsJoined = strings.ToLower(strings.Join(tags, " "))
}
trigJoined := ""
if tr, ok := skillInfo["triggers"].([]string); ok {
trigJoined = strings.ToLower(strings.Join(tr, " "))
}
if strings.Contains(id, keywordLower) ||
strings.Contains(name, keywordLower) ||
strings.Contains(description, keywordLower) ||
strings.Contains(path, keywordLower) ||
strings.Contains(version, keywordLower) ||
strings.Contains(tagsJoined, keywordLower) ||
strings.Contains(trigJoined, keywordLower) {
filteredSkillsInfo = append(filteredSkillsInfo, skillInfo)
}
}
}
// 分页参数
limit := 20 // 默认每页20条
offset := 0
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
// 允许更大的limit用于搜索场景,但设置一个合理的上限(10000)
if parsed <= 10000 {
limit = parsed
} else {
limit = 10000
}
}
}
if offsetStr := c.Query("offset"); offsetStr != "" {
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
offset = parsed
}
}
// 计算分页范围
total := len(filteredSkillsInfo)
start := offset
end := offset + limit
if start > total {
start = total
}
if end > total {
end = total
}
// 获取当前页的skill列表
var paginatedSkillsInfo []map[string]interface{}
if start < end {
paginatedSkillsInfo = filteredSkillsInfo[start:end]
} else {
paginatedSkillsInfo = []map[string]interface{}{}
}
c.JSON(http.StatusOK, gin.H{
"skills": paginatedSkillsInfo,
"total": total,
"limit": limit,
"offset": offset,
})
}
// GetSkill 获取单个skill的详细信息
func (h *SkillsHandler) GetSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
resPath := strings.TrimSpace(c.Query("resource_path"))
if resPath == "" {
resPath = strings.TrimSpace(c.Query("skill_script_path"))
}
if resPath != "" {
content, err := skillpackage.ReadScriptText(h.skillsRootAbs(), skillName, resPath, 0)
if err != nil {
h.logger.Warn("读取skill资源失败", zap.String("skill", skillName), zap.String("path", resPath), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"skill": map[string]interface{}{
"id": skillName,
},
"resource": map[string]interface{}{
"path": resPath,
"content": content,
},
})
return
}
depthStr := strings.ToLower(strings.TrimSpace(c.DefaultQuery("depth", "full")))
section := strings.TrimSpace(c.Query("section"))
opt := skillpackage.LoadOptions{Section: section}
switch depthStr {
case "summary":
opt.Depth = "summary"
case "full", "":
opt.Depth = "full"
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "depth 仅支持 summary 或 full"})
return
}
skill, err := skillpackage.LoadSkill(h.skillsRootAbs(), skillName, opt)
if err != nil {
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()})
return
}
skillPath := skill.Path
skillFile := filepath.Join(skillPath, "SKILL.md")
fileInfo, _ := os.Stat(skillFile)
var fileSize int64
var modTime string
if fileInfo != nil {
fileSize = fileInfo.Size()
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
}
c.JSON(http.StatusOK, gin.H{
"skill": map[string]interface{}{
"id": skill.DirName,
"name": skill.Name,
"description": skill.Description,
"content": skill.Content,
"path": skill.Path,
"version": skill.Version,
"tags": skill.Tags,
"scripts": skill.Scripts,
"sections": skill.Sections,
"package_files": skill.PackageFiles,
"file_size": fileSize,
"mod_time": modTime,
"depth": depthStr,
"section": section,
},
})
}
// ListSkillPackageFiles lists all files in a skill directory (Agent Skills layout).
func (h *SkillsHandler) ListSkillPackageFiles(c *gin.Context) {
skillID := c.Param("name")
files, err := skillpackage.ListPackageFiles(h.skillsRootAbs(), skillID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"files": files})
}
// GetSkillPackageFile returns one file by relative path (?path=).
func (h *SkillsHandler) GetSkillPackageFile(c *gin.Context) {
skillID := c.Param("name")
rel := strings.TrimSpace(c.Query("path"))
if rel == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "query path is required"})
return
}
b, err := skillpackage.ReadPackageFile(h.skillsRootAbs(), skillID, rel, 0)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"path": rel, "content": string(b)})
}
// PutSkillPackageFile writes a file inside the skill package.
func (h *SkillsHandler) PutSkillPackageFile(c *gin.Context) {
skillID := c.Param("name")
var req struct {
Path string `json:"path" binding:"required"`
Content string `json:"content"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Path == "SKILL.md" {
if err := skillpackage.ValidateSkillMDPackage([]byte(req.Content), skillID); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
}
if err := skillpackage.WritePackageFile(h.skillsRootAbs(), skillID, req.Path, []byte(req.Content)); err != nil {
h.logger.Error("写入 skill 文件失败", zap.String("skill", skillID), zap.String("path", req.Path), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "saved", "path": req.Path})
}
// GetSkillBoundRoles 获取绑定指定skill的角色列表
func (h *SkillsHandler) GetSkillBoundRoles(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
boundRoles := h.getRolesBoundToSkill(skillName)
c.JSON(http.StatusOK, gin.H{
"skill": skillName,
"bound_roles": boundRoles,
"bound_count": len(boundRoles),
})
}
// getRolesBoundToSkill 预留:角色不再配置 skill 绑定,始终返回空列表。
func (h *SkillsHandler) getRolesBoundToSkill(skillName string) []string {
_ = skillName
return nil
}
// CreateSkill 创建新 skill(标准 Agent Skills:生成 SKILL.md + YAML front matter
func (h *SkillsHandler) CreateSkill(c *gin.Context) {
var req struct {
Name string `json:"name" binding:"required"`
Description string `json:"description" binding:"required"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if !isValidSkillName(req.Name) {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill 目录名须为小写字母、数字、连字符(与 Agent Skills name 一致)"})
return
}
manifest := &skillpackage.SkillManifest{
Name: req.Name,
Description: strings.TrimSpace(req.Description),
}
skillMD, err := skillpackage.BuildSkillMD(manifest, req.Content)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := skillpackage.ValidateSkillMDPackage(skillMD, req.Name); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
skillDir := filepath.Join(h.skillsRootAbs(), req.Name)
if err := os.MkdirAll(skillDir, 0755); err != nil {
h.logger.Error("创建skill目录失败", zap.String("skill", req.Name), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill目录失败: " + err.Error()})
return
}
if _, err := os.Stat(filepath.Join(skillDir, "SKILL.md")); err == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill已存在"})
return
}
if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillMD, 0644); err != nil {
h.logger.Error("创建 SKILL.md 失败", zap.String("skill", req.Name), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建 SKILL.md 失败: " + err.Error()})
return
}
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
}
c.JSON(http.StatusOK, gin.H{
"message": "skill已创建",
"skill": map[string]interface{}{
"name": req.Name,
"path": skillDir,
},
})
}
// UpdateSkill 更新 SKILL.md(保留 front matter 中除 description 外的字段;可选覆盖 description
func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
var req struct {
Description string `json:"description"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
mdPath := filepath.Join(h.skillsRootAbs(), skillName, "SKILL.md")
raw, err := os.ReadFile(mdPath)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()})
return
}
m, _, err := skillpackage.ParseSkillMD(raw)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Description != "" {
m.Description = strings.TrimSpace(req.Description)
}
skillMD, err := skillpackage.BuildSkillMD(m, req.Content)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := skillpackage.ValidateSkillMDPackage(skillMD, skillName); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
skillDir := filepath.Join(h.skillsRootAbs(), skillName)
if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillMD, 0644); err != nil {
h.logger.Error("更新 SKILL.md 失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新 SKILL.md 失败: " + err.Error()})
return
}
h.logger.Info("更新skill成功", zap.String("skill", skillName))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
}
c.JSON(http.StatusOK, gin.H{
"message": "skill已更新",
})
}
// DeleteSkill 删除skill
func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
// 检查是否有角色绑定了该skill,如果有则自动移除绑定
affectedRoles := h.removeSkillFromRoles(skillName)
if len(affectedRoles) > 0 {
h.logger.Info("从角色中移除skill绑定",
zap.String("skill", skillName),
zap.Strings("roles", affectedRoles))
}
skillDir := filepath.Join(h.skillsRootAbs(), skillName)
if err := os.RemoveAll(skillDir); err != nil {
h.logger.Error("删除skill失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除skill失败: " + err.Error()})
return
}
responseMsg := "skill已删除"
if len(affectedRoles) > 0 {
responseMsg = fmt.Sprintf("skill已删除,已自动从 %d 个角色中移除绑定: %s",
len(affectedRoles), strings.Join(affectedRoles, ", "))
}
h.logger.Info("删除skill成功", zap.String("skill", skillName))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
"affected_roles": affectedRoles,
})
}
c.JSON(http.StatusOK, gin.H{
"message": responseMsg,
"affected_roles": affectedRoles,
})
}
// GetSkillStats 获取skills调用统计信息
func (h *SkillsHandler) GetSkillStats(c *gin.Context) {
skillList, err := skillpackage.ListSkillDirNames(h.skillsRootAbs())
if err != nil {
h.logger.Error("获取skills列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
skillsDir := h.skillsRootAbs()
// 从数据库加载调用统计
var skillStatsMap map[string]*database.SkillStats
if h.db != nil {
dbStats, err := h.db.LoadSkillStats()
if err != nil {
h.logger.Warn("从数据库加载Skills统计信息失败", zap.Error(err))
skillStatsMap = make(map[string]*database.SkillStats)
} else {
skillStatsMap = dbStats
}
} else {
skillStatsMap = make(map[string]*database.SkillStats)
}
// 构建统计信息(包含所有skills,即使没有调用记录)
statsList := make([]map[string]interface{}, 0, len(skillList))
totalCalls := 0
totalSuccess := 0
totalFailed := 0
for _, skillName := range skillList {
stat, exists := skillStatsMap[skillName]
if !exists {
stat = &database.SkillStats{
SkillName: skillName,
TotalCalls: 0,
SuccessCalls: 0,
FailedCalls: 0,
}
}
totalCalls += stat.TotalCalls
totalSuccess += stat.SuccessCalls
totalFailed += stat.FailedCalls
lastCallTimeStr := ""
if stat.LastCallTime != nil {
lastCallTimeStr = stat.LastCallTime.Format("2006-01-02 15:04:05")
}
statsList = append(statsList, map[string]interface{}{
"skill_name": stat.SkillName,
"total_calls": stat.TotalCalls,
"success_calls": stat.SuccessCalls,
"failed_calls": stat.FailedCalls,
"last_call_time": lastCallTimeStr,
})
}
c.JSON(http.StatusOK, gin.H{
"total_skills": len(skillList),
"total_calls": totalCalls,
"total_success": totalSuccess,
"total_failed": totalFailed,
"skills_dir": skillsDir,
"stats": statsList,
})
}
// ClearSkillStats 清空所有Skills统计信息
func (h *SkillsHandler) ClearSkillStats(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
return
}
if err := h.db.ClearSkillStats(); err != nil {
h.logger.Error("清空Skills统计信息失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
return
}
h.logger.Info("已清空所有Skills统计信息")
c.JSON(http.StatusOK, gin.H{
"message": "已清空所有Skills统计信息",
})
}
// ClearSkillStatsByName 清空指定skill的统计信息
func (h *SkillsHandler) ClearSkillStatsByName(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
return
}
if err := h.db.ClearSkillStatsByName(skillName); err != nil {
h.logger.Error("清空指定skill统计信息失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
return
}
h.logger.Info("已清空指定skill统计信息", zap.String("skill", skillName))
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("已清空skill '%s' 的统计信息", skillName),
})
}
// removeSkillFromRoles 预留:角色不再存储 skill 绑定,无操作。
func (h *SkillsHandler) removeSkillFromRoles(skillName string) []string {
_ = skillName
return nil
}
// saveRolesConfig 保存角色配置到文件(从SkillsHandler调用)
func (h *SkillsHandler) saveRolesConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeRoleFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeRoleFileName 将角色名称转换为安全的文件名
func sanitizeRoleFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// isValidSkillName 验证 skill 目录名(与 Agent Skills 的 name 字段一致:小写、数字、连字符)
func isValidSkillName(name string) bool {
if name == "" || len(name) > 100 {
return false
}
for _, r := range name {
if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-') {
return false
}
}
return true
}
+58
View File
@@ -0,0 +1,58 @@
package handler
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// sseInterval is how often we write on long SSE streams. Shorter intervals help NATs and
// some proxies that treat connections as idle; 10s is a reasonable balance with traffic.
const sseKeepaliveInterval = 10 * time.Second
// sseKeepalive sends periodic SSE traffic so proxies (e.g. nginx proxy_read_timeout), NATs,
// and load balancers do not close long-running streams. Some intermediaries ignore comment-only
// lines, so we send both a comment and a minimal data frame (type heartbeat) per tick.
//
// writeMu must be the same mutex used by sendEvent for this request: concurrent writes to
// http.ResponseWriter break chunked transfer encoding (browser: net::ERR_INVALID_CHUNKED_ENCODING).
func sseKeepalive(c *gin.Context, stop <-chan struct{}, writeMu *sync.Mutex) {
if writeMu == nil {
return
}
ticker := time.NewTicker(sseKeepaliveInterval)
defer ticker.Stop()
for {
select {
case <-stop:
return
case <-c.Request.Context().Done():
return
case <-ticker.C:
select {
case <-stop:
return
case <-c.Request.Context().Done():
return
default:
}
writeMu.Lock()
if _, err := fmt.Fprintf(c.Writer, ": keepalive\n\n"); err != nil {
writeMu.Unlock()
return
}
// data: frame so strict proxies still see downstream bytes (comments alone may not reset timers)
if _, err := fmt.Fprintf(c.Writer, `data: {"type":"heartbeat"}`+"\n\n"); err != nil {
writeMu.Unlock()
return
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
writeMu.Unlock()
}
}
}
+116
View File
@@ -0,0 +1,116 @@
package handler
import "sync"
// TaskEventBus 将主 SSE 连接上的事件镜像给后订阅的客户端(例如刷新页面后、HITL 审批通过需继续收事件)。
// 每个 payload 为完整 SSE 行: "data: {...}\n\n"
type TaskEventBus struct {
mu sync.RWMutex
subs map[string]map[*taskEventSub]struct{}
}
type taskEventSub struct {
mu sync.Mutex
ch chan []byte
closed bool
}
func (s *taskEventSub) sendNonBlocking(line []byte) bool {
if s == nil {
return false
}
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
return false
}
select {
case s.ch <- line:
return true
default:
return false
}
}
func (s *taskEventSub) closeOnce() {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
if s.closed {
return
}
s.closed = true
close(s.ch)
}
func NewTaskEventBus() *TaskEventBus {
return &TaskEventBus{
subs: make(map[string]map[*taskEventSub]struct{}),
}
}
// Subscribe 注册订阅;cancel 时需调用 Unsubscribe。
func (b *TaskEventBus) Subscribe(conversationID string) (sub *taskEventSub, ch <-chan []byte) {
chBuf := make(chan []byte, 256)
sub = &taskEventSub{ch: chBuf}
b.mu.Lock()
if b.subs[conversationID] == nil {
b.subs[conversationID] = make(map[*taskEventSub]struct{})
}
b.subs[conversationID][sub] = struct{}{}
b.mu.Unlock()
return sub, chBuf
}
func (b *TaskEventBus) Unsubscribe(conversationID string, sub *taskEventSub) {
if sub == nil {
return
}
b.mu.Lock()
m, ok := b.subs[conversationID]
if !ok {
b.mu.Unlock()
return
}
delete(m, sub)
if len(m) == 0 {
delete(b.subs, conversationID)
}
b.mu.Unlock()
sub.closeOnce()
}
// Publish 非阻塞投递;慢消费者丢帧(HITL 场景以最新状态为准,丢帧可接受)。
func (b *TaskEventBus) Publish(conversationID string, line []byte) {
if b == nil || conversationID == "" || len(line) == 0 {
return
}
b.mu.RLock()
m := b.subs[conversationID]
subs := make([]*taskEventSub, 0, len(m))
for s := range m {
subs = append(subs, s)
}
b.mu.RUnlock()
cp := append([]byte(nil), line...)
for _, s := range subs {
s.sendNonBlocking(cp)
}
}
// CloseConversation 任务结束时关闭该会话所有订阅 channel。
func (b *TaskEventBus) CloseConversation(conversationID string) {
if b == nil || conversationID == "" {
return
}
b.mu.Lock()
m := b.subs[conversationID]
delete(b.subs, conversationID)
b.mu.Unlock()
for sub := range m {
sub.closeOnce()
}
}
+407
View File
@@ -0,0 +1,407 @@
package handler
import (
"context"
"errors"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/multiagent"
)
// ErrTaskCancelled 用户取消任务的错误
var ErrTaskCancelled = errors.New("agent task cancelled by user")
// ErrTaskAlreadyRunning 会话已有任务正在执行
var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation")
// shouldPersistEinoAgentTraceAfterRunErrorEino 相关 Run 非成功返回时,是否仍写入 last_react_* 供下轮 loadHistoryFromAgentTrace。
// 当前策略:无论正常结束、异常结束或用户主动停止,都尽量保留最后可用轨迹,
// 以便在同一会话继续时可基于原始上下文续跑,而不是回退到仅消息文本历史。
func shouldPersistEinoAgentTraceAfterRunError(baseCtx context.Context) bool {
return true
}
// AgentTask 描述正在运行的Agent任务
type AgentTask struct {
ConversationID string `json:"conversationId"`
Message string `json:"message,omitempty"`
StartedAt time.Time `json:"startedAt"`
Status string `json:"status"`
CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务
// ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具)
ActiveMCPExecutionID string `json:"-"`
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
InterruptContinueNote string `json:"-"`
cancel func(error)
}
// RegisterRunningTool 实现 mcp.ToolRunRegistry:工具开始时登记本会话当前 executionId。
func (m *AgentTaskManager) RegisterRunningTool(conversationID, executionID string) {
conversationID = strings.TrimSpace(conversationID)
executionID = strings.TrimSpace(executionID)
if conversationID == "" || executionID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.ActiveMCPExecutionID = executionID
}
}
// UnregisterRunningTool 工具结束时清除登记(仅当 id 仍匹配时清除,避免并发串单)。
func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID string) {
conversationID = strings.TrimSpace(conversationID)
executionID = strings.TrimSpace(executionID)
if conversationID == "" || executionID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
if t.ActiveMCPExecutionID == executionID {
t.ActiveMCPExecutionID = ""
}
}
}
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.InterruptContinueNote = note
}
}
// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。
func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
n := t.InterruptContinueNote
t.InterruptContinueNote = ""
return n
}
return ""
}
// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。
func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" || cancel == nil {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.cancel = func(err error) {
cancel(err)
}
}
}
// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。
func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
m.mu.RLock()
defer m.mu.RUnlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
return strings.TrimSpace(t.ActiveMCPExecutionID)
}
return ""
}
// CompletedTask 已完成的任务(用于历史记录)
type CompletedTask struct {
ConversationID string `json:"conversationId"`
Message string `json:"message,omitempty"`
StartedAt time.Time `json:"startedAt"`
CompletedAt time.Time `json:"completedAt"`
Status string `json:"status"`
}
// AgentTaskManager 管理正在运行的Agent任务
type AgentTaskManager struct {
mu sync.RWMutex
tasks map[string]*AgentTask
completedTasks []*CompletedTask // 最近完成的任务历史
maxHistorySize int // 最大历史记录数
historyRetention time.Duration // 历史记录保留时间
eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅
}
const (
// cancellingStuckThreshold 处于「取消中」超过此时长则强制从运行列表移除。正常取消会在当前步骤内返回,
// 超过则视为卡住,尽快释放会话。常见做法多为 30–60s 内释放。
cancellingStuckThreshold = 45 * time.Second
// cancellingStuckThresholdLegacy 未记录 CancellingAt 时用 StartedAt 判断的兜底时长
cancellingStuckThresholdLegacy = 2 * time.Minute
cleanupInterval = 15 * time.Second // 与上面阈值配合,最长约 60s 内移除
)
// NewAgentTaskManager 创建任务管理器
func NewAgentTaskManager() *AgentTaskManager {
m := &AgentTaskManager{
tasks: make(map[string]*AgentTask),
completedTasks: make([]*CompletedTask, 0),
maxHistorySize: 50, // 最多保留50条历史记录
historyRetention: 24 * time.Hour, // 保留24小时
}
go m.runStuckCancellingCleanup()
return m
}
// SetTaskEventBus 设置任务事件总线(与 AgentHandler 共用同一实例)。
func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) {
m.mu.Lock()
defer m.mu.Unlock()
m.eventBus = b
}
// GetTask 返回运行中任务(无则 nil)。
func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask {
m.mu.RLock()
defer m.mu.RUnlock()
return m.tasks[conversationID]
}
// runStuckCancellingCleanup 定期将长时间处于「取消中」的任务强制结束,避免卡住无法发新消息
func (m *AgentTaskManager) runStuckCancellingCleanup() {
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for range ticker.C {
m.cleanupStuckCancelling()
}
}
func (m *AgentTaskManager) cleanupStuckCancelling() {
m.mu.Lock()
var toFinish []string
now := time.Now()
for id, task := range m.tasks {
if task.Status != "cancelling" {
continue
}
var elapsed time.Duration
if !task.CancellingAt.IsZero() {
elapsed = now.Sub(task.CancellingAt)
if elapsed < cancellingStuckThreshold {
continue
}
} else {
elapsed = now.Sub(task.StartedAt)
if elapsed < cancellingStuckThresholdLegacy {
continue
}
}
toFinish = append(toFinish, id)
}
m.mu.Unlock()
for _, id := range toFinish {
m.FinishTask(id, "cancelled")
}
}
// StartTask 注册并开始一个新的任务
func (m *AgentTaskManager) StartTask(conversationID, message string, cancel context.CancelCauseFunc) (*AgentTask, error) {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.tasks[conversationID]; exists {
return nil, ErrTaskAlreadyRunning
}
task := &AgentTask{
ConversationID: conversationID,
Message: message,
StartedAt: time.Now(),
Status: "running",
cancel: func(err error) {
if cancel != nil {
cancel(err)
}
},
}
m.tasks[conversationID] = task
return task, nil
}
// CancelTask 取消指定会话的任务。若任务已在取消中,仍返回 (true, nil) 以便接口幂等、前端不报错。
func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, error) {
m.mu.Lock()
task, exists := m.tasks[conversationID]
if !exists {
m.mu.Unlock()
return false, nil
}
// 如果已经处于取消流程,视为成功(幂等),避免前端重复点击报「未找到任务」
if task.Status == "cancelling" {
m.mu.Unlock()
return true, nil
}
// ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。
if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) {
task.Status = "running"
} else {
task.Status = "cancelling"
task.CancellingAt = time.Now()
}
if cause != nil && errors.Is(cause, ErrTaskCancelled) {
task.InterruptContinueNote = ""
}
cancel := task.cancel
m.mu.Unlock()
if cause == nil {
cause = ErrTaskCancelled
}
if cancel != nil {
cancel(cause)
}
return true, nil
}
// UpdateTaskStatus 更新任务状态但不删除任务(用于在发送事件前更新状态)
func (m *AgentTaskManager) UpdateTaskStatus(conversationID string, status string) {
m.mu.Lock()
defer m.mu.Unlock()
task, exists := m.tasks[conversationID]
if !exists {
return
}
if status != "" {
task.Status = status
}
}
// FinishTask 完成任务并从管理器中移除
func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) {
m.mu.Lock()
task, exists := m.tasks[conversationID]
if !exists {
m.mu.Unlock()
return
}
if finalStatus != "" {
task.Status = finalStatus
}
// 保存到历史记录
completedTask := &CompletedTask{
ConversationID: task.ConversationID,
Message: task.Message,
StartedAt: task.StartedAt,
CompletedAt: time.Now(),
Status: finalStatus,
}
// 添加到历史记录
m.completedTasks = append(m.completedTasks, completedTask)
// 清理过期和过多的历史记录
m.cleanupHistory()
// 从运行任务中移除
delete(m.tasks, conversationID)
bus := m.eventBus
m.mu.Unlock()
if bus != nil {
bus.CloseConversation(conversationID)
}
}
// cleanupHistory 清理过期的历史记录
func (m *AgentTaskManager) cleanupHistory() {
now := time.Now()
cutoffTime := now.Add(-m.historyRetention)
// 过滤掉过期的记录
validTasks := make([]*CompletedTask, 0, len(m.completedTasks))
for _, task := range m.completedTasks {
if task.CompletedAt.After(cutoffTime) {
validTasks = append(validTasks, task)
}
}
// 如果仍然超过最大数量,只保留最新的
if len(validTasks) > m.maxHistorySize {
// 按完成时间排序,保留最新的
// 由于是追加的,最新的在最后,所以直接取最后N个
start := len(validTasks) - m.maxHistorySize
validTasks = validTasks[start:]
}
m.completedTasks = validTasks
}
// GetActiveTasks 返回所有正在运行的任务
func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]*AgentTask, 0, len(m.tasks))
for _, task := range m.tasks {
result = append(result, &AgentTask{
ConversationID: task.ConversationID,
Message: task.Message,
StartedAt: task.StartedAt,
Status: task.Status,
})
}
return result
}
// GetCompletedTasks 返回最近完成的任务历史
func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask {
m.mu.RLock()
defer m.mu.RUnlock()
// 清理过期记录(只读锁,不影响其他操作)
// 注意:这里不能直接调用cleanupHistory,因为需要写锁
// 所以返回时过滤过期记录
now := time.Now()
cutoffTime := now.Add(-m.historyRetention)
result := make([]*CompletedTask, 0, len(m.completedTasks))
for _, task := range m.completedTasks {
if task.CompletedAt.After(cutoffTime) {
result = append(result, task)
}
}
// 按完成时间倒序排序(最新的在前)
// 由于是追加的,最新的在最后,需要反转
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
result[i], result[j] = result[j], result[i]
}
// 限制返回数量
if len(result) > m.maxHistorySize {
result = result[:m.maxHistorySize]
}
return result
}
+257
View File
@@ -0,0 +1,257 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const (
terminalMaxCommandLen = 4096
terminalMaxOutputLen = 256 * 1024 // 256KB
terminalTimeout = 30 * time.Minute
)
// TerminalHandler 处理系统设置中的终端命令执行
type TerminalHandler struct {
logger *zap.Logger
}
// maskTerminalCommand 对可能包含敏感信息的终端命令做脱敏,避免在日志中直接记录密码等内容
func maskTerminalCommand(cmd string) string {
trimmed := strings.TrimSpace(cmd)
lower := strings.ToLower(trimmed)
if strings.Contains(lower, "sudo") || strings.Contains(lower, "password") {
return "[masked sensitive terminal command]"
}
if len(trimmed) > 256 {
return trimmed[:256] + "..."
}
return trimmed
}
// NewTerminalHandler 创建终端处理器
func NewTerminalHandler(logger *zap.Logger) *TerminalHandler {
return &TerminalHandler{logger: logger}
}
// RunCommandRequest 执行命令请求
type RunCommandRequest struct {
Command string `json:"command"`
Shell string `json:"shell,omitempty"`
Cwd string `json:"cwd,omitempty"`
}
// RunCommandResponse 执行命令响应
type RunCommandResponse struct {
Stdout string `json:"stdout"`
Stderr string `json:"stderr"`
ExitCode int `json:"exit_code"`
Error string `json:"error,omitempty"`
}
// RunCommand 执行终端命令(需登录)
func (h *TerminalHandler) RunCommand(c *gin.Context) {
var req RunCommandRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"})
return
}
cmdStr := strings.TrimSpace(req.Command)
if cmdStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"})
return
}
if len(cmdStr) > terminalMaxCommandLen {
c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"})
return
}
shell := req.Shell
if shell == "" {
if runtime.GOOS == "windows" {
shell = "cmd"
} else {
shell = "sh"
}
}
ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout)
defer cancel()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
} else {
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
// 无 TTY 时设置 COLUMNS/TERM,使 ping 等工具的 usage 排版与真实终端一致
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
}
if req.Cwd != "" {
absCwd, err := filepath.Abs(req.Cwd)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"})
return
}
cur, _ := os.Getwd()
curAbs, _ := filepath.Abs(cur)
rel, err := filepath.Rel(curAbs, absCwd)
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"})
return
}
cmd.Dir = absCwd
}
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
stdoutBytes := stdout.Bytes()
stderrBytes := stderr.Bytes()
// 限制输出长度,防止内存占用过大(复制后截断,避免修改原 buffer)
truncSuffix := []byte("\n...(输出已截断)\n")
if len(stdoutBytes) > terminalMaxOutputLen {
tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix))
n := copy(tmp, stdoutBytes[:terminalMaxOutputLen])
copy(tmp[n:], truncSuffix)
stdoutBytes = tmp
}
if len(stderrBytes) > terminalMaxOutputLen {
tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix))
n := copy(tmp, stderrBytes[:terminalMaxOutputLen])
copy(tmp[n:], truncSuffix)
stderrBytes = tmp
}
exitCode := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
if ctx.Err() == context.DeadlineExceeded {
so := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n")
so = strings.ReplaceAll(so, "\r", "\n")
se := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n")
se = strings.ReplaceAll(se, "\r", "\n")
resp := RunCommandResponse{
Stdout: so,
Stderr: se,
ExitCode: -1,
Error: "命令执行超时(" + terminalTimeout.String() + "",
}
c.JSON(http.StatusOK, resp)
return
}
h.logger.Debug("终端命令执行异常", zap.String("command", maskTerminalCommand(cmdStr)), zap.Error(err))
}
// 统一为 \n,避免前端因 \r 出现错位/对角线排版
stdoutStr := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n")
stdoutStr = strings.ReplaceAll(stdoutStr, "\r", "\n")
stderrStr := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n")
stderrStr = strings.ReplaceAll(stderrStr, "\r", "\n")
resp := RunCommandResponse{
Stdout: stdoutStr,
Stderr: stderrStr,
ExitCode: exitCode,
}
if err != nil && exitCode != 0 {
resp.Error = err.Error()
}
c.JSON(http.StatusOK, resp)
}
// streamEvent SSE 事件
type streamEvent struct {
T string `json:"t"` // "out" | "err" | "exit"
D string `json:"d,omitempty"`
C int `json:"c"` // exit code(不用 omitempty,否则 0 不序列化导致前端显示 [exit undefined]
}
// RunCommandStream 流式执行命令,输出实时推送到前端(SSE)
func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
var req RunCommandRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"})
return
}
cmdStr := strings.TrimSpace(req.Command)
if cmdStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"})
return
}
if len(cmdStr) > terminalMaxCommandLen {
c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"})
return
}
shell := req.Shell
if shell == "" {
if runtime.GOOS == "windows" {
shell = "cmd"
} else {
shell = "sh"
}
}
ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout)
defer cancel()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
} else {
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
}
if req.Cwd != "" {
absCwd, err := filepath.Abs(req.Cwd)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"})
return
}
cur, _ := os.Getwd()
curAbs, _ := filepath.Abs(cur)
rel, err := filepath.Rel(curAbs, absCwd)
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"})
return
}
cmd.Dir = absCwd
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
cancel()
return
}
sendEvent := func(ev streamEvent) {
body, _ := json.Marshal(ev)
c.SSEvent("", string(body))
flusher.Flush()
}
_ = runCommandStreamImpl(cmd, sendEvent, ctx)
}
+47
View File
@@ -0,0 +1,47 @@
//go:build !windows
package handler
import (
"bufio"
"context"
"os/exec"
"strings"
"github.com/creack/pty"
)
const ptyCols = 256
const ptyRows = 40
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return -1
}
defer ptmx.Close()
normalize := func(s string) string {
s = strings.ReplaceAll(s, "\r\n", "\n")
return strings.ReplaceAll(s, "\r", "\n")
}
sc := bufio.NewScanner(ptmx)
for sc.Scan() {
sendEvent(streamEvent{T: "out", D: normalize(sc.Text())})
}
exitCode := 0
if err := cmd.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
}
if ctx.Err() == context.DeadlineExceeded {
exitCode = -1
}
sendEvent(streamEvent{T: "exit", C: exitCode})
return exitCode
}
@@ -0,0 +1,66 @@
//go:build windows
package handler
import (
"bufio"
"context"
"os/exec"
"strings"
"sync"
)
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return -1
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return -1
}
if err := cmd.Start(); err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return -1
}
normalize := func(s string) string {
s = strings.ReplaceAll(s, "\r\n", "\n")
return strings.ReplaceAll(s, "\r", "\n")
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
sc := bufio.NewScanner(stdoutPipe)
for sc.Scan() {
sendEvent(streamEvent{T: "out", D: normalize(sc.Text())})
}
}()
go func() {
defer wg.Done()
sc := bufio.NewScanner(stderrPipe)
for sc.Scan() {
sendEvent(streamEvent{T: "err", D: normalize(sc.Text())})
}
}()
wg.Wait()
exitCode := 0
if err := cmd.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
}
if ctx.Err() == context.DeadlineExceeded {
exitCode = -1
}
sendEvent(streamEvent{T: "exit", C: exitCode})
return exitCode
}
+111
View File
@@ -0,0 +1,111 @@
//go:build !windows
package handler
import (
"encoding/json"
"net/http"
"os"
"os/exec"
"time"
"github.com/creack/pty"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// terminalResize is sent by the frontend when the xterm.js terminal is resized.
type terminalResize struct {
Type string `json:"type"`
Cols uint16 `json:"cols"`
Rows uint16 `json:"rows"`
}
// wsUpgrader 仅用于系统设置中的终端 WebSocket,会复用已有的登录保护(JWT 中间件在上层路由组)
var wsUpgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
// 由于已在 Gin 路由层做了认证,这里放宽 Origin,方便在同一域名下通过 HTTPS/WSS 访问
return true
},
}
// RunCommandWS 提供真正交互式 Shell:基于 WebSocket + PTY 的长会话
// 前端建立 WebSocket 连接后,所有键盘输入都会透传到 Shell,Shell 的输出也会实时写回前端。
func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return
}
defer conn.Close()
// 启动交互式 Shell,这里优先使用 bash,找不到则退回 sh
shell := "bash"
if _, err := exec.LookPath(shell); err != nil {
shell = "sh"
}
cmd := exec.Command(shell)
cmd.Env = append(os.Environ(),
"COLUMNS=80",
"LINES=24",
"TERM=xterm-256color",
)
// Use 80x24 as a safe default; the frontend will send the actual size immediately after connecting.
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: 80, Rows: 24})
if err != nil {
return
}
defer ptmx.Close()
// Shell -> WebSocket:将 PTY 输出实时发给前端
doneChan := make(chan struct{})
go func() {
buf := make([]byte, 4096)
for {
n, err := ptmx.Read(buf)
if n > 0 {
_ = conn.WriteMessage(websocket.BinaryMessage, buf[:n])
}
if err != nil {
break
}
}
close(doneChan)
}()
// WebSocket -> Shell:将前端输入写入 PTY(包括 sudo 密码、Ctrl+C 等)
conn.SetReadLimit(64 * 1024)
_ = conn.SetReadDeadline(time.Now().Add(terminalTimeout))
conn.SetPongHandler(func(string) error {
_ = conn.SetReadDeadline(time.Now().Add(terminalTimeout))
return nil
})
for {
msgType, data, err := conn.ReadMessage()
if err != nil {
_ = cmd.Process.Kill()
break
}
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
continue
}
if len(data) == 0 {
continue
}
// Check if this is a resize message (JSON with type:"resize")
if msgType == websocket.TextMessage && len(data) > 0 && data[0] == '{' {
var resize terminalResize
if json.Unmarshal(data, &resize) == nil && resize.Type == "resize" && resize.Cols > 0 && resize.Rows > 0 {
_ = pty.Setsize(ptmx, &pty.Winsize{Cols: resize.Cols, Rows: resize.Rows})
continue
}
}
if _, err := ptmx.Write(data); err != nil {
_ = cmd.Process.Kill()
break
}
}
<-doneChan
}
+533
View File
@@ -0,0 +1,533 @@
package handler
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// VulnerabilityHandler 漏洞处理器
type VulnerabilityHandler struct {
db *database.DB
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewVulnerabilityHandler 创建新的漏洞处理器
func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler {
return &VulnerabilityHandler{
db: db,
logger: logger,
}
}
// CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
ProjectID string `json:"project_id"`
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title" binding:"required"`
Description string `json:"description"`
Severity string `json:"severity" binding:"required"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// CreateVulnerability 创建漏洞
func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
var req CreateVulnerabilityRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
vuln := &database.Vulnerability{
ConversationID: req.ConversationID,
ProjectID: strings.TrimSpace(req.ProjectID),
ConversationTag: req.ConversationTag,
TaskTag: req.TaskTag,
Title: req.Title,
Description: req.Description,
Severity: req.Severity,
Status: req.Status,
Type: req.Type,
Target: req.Target,
Proof: req.Proof,
Impact: req.Impact,
Recommendation: req.Recommendation,
}
created, err := h.db.CreateVulnerability(vuln)
if err != nil {
h.logger.Error("创建漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{
"severity": created.Severity, "title": created.Title,
})
}
c.JSON(http.StatusOK, created)
}
// GetVulnerability 获取漏洞
func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) {
id := c.Param("id")
vuln, err := h.db.GetVulnerability(id)
if err != nil {
h.logger.Error("获取漏洞失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
return
}
c.JSON(http.StatusOK, vuln)
}
// ListVulnerabilitiesResponse 漏洞列表响应
type ListVulnerabilitiesResponse struct {
Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
q := strings.TrimSpace(c.Query("q"))
if q == "" {
q = strings.TrimSpace(c.Query("search"))
}
return database.VulnerabilityListFilter{
ProjectID: c.Query("project_id"),
ID: c.Query("id"),
Search: q,
ConversationID: c.Query("conversation_id"),
Severity: c.Query("severity"),
Status: c.Query("status"),
TaskID: c.Query("task_id"),
ConversationTag: c.Query("conversation_tag"),
TaskTag: c.Query("task_tag"),
}
}
// ListVulnerabilities 列出漏洞
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "20")
offsetStr := c.DefaultQuery("offset", "0")
pageStr := c.Query("page")
filter := parseVulnerabilityListFilter(c)
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
page := 1
// 如果提供了page参数,优先使用page计算offset
if pageStr != "" {
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
page = p
offset = (page - 1) * limit
}
}
if limit <= 0 || limit > 100 {
limit = 20
}
if offset < 0 {
offset = 0
}
// 获取总数
total, err := h.db.CountVulnerabilities(filter)
if err != nil {
h.logger.Error("获取漏洞总数失败", zap.Error(err))
// 继续执行,使用0作为总数
total = 0
}
// 获取漏洞列表
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
if err != nil {
h.logger.Error("获取漏洞列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 计算总页数
totalPages := (total + limit - 1) / limit
if totalPages == 0 {
totalPages = 1
}
// 如果使用offset计算page,需要重新计算
if pageStr == "" {
page = (offset / limit) + 1
}
response := ListVulnerabilitiesResponse{
Vulnerabilities: vulnerabilities,
Total: total,
Page: page,
PageSize: limit,
TotalPages: totalPages,
}
c.JSON(http.StatusOK, response)
}
// UpdateVulnerabilityRequest 更新漏洞请求
type UpdateVulnerabilityRequest struct {
ProjectID *string `json:"project_id"`
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// UpdateVulnerability 更新漏洞
func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
id := c.Param("id")
var req UpdateVulnerabilityRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 获取现有漏洞
existing, err := h.db.GetVulnerability(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
return
}
// 更新字段
if req.ProjectID != nil {
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
}
if req.ConversationTag != "" {
existing.ConversationTag = req.ConversationTag
}
if req.TaskTag != "" {
existing.TaskTag = req.TaskTag
}
if req.Title != "" {
existing.Title = req.Title
}
if req.Description != "" {
existing.Description = req.Description
}
if req.Severity != "" {
existing.Severity = req.Severity
}
if req.Status != "" {
existing.Status = req.Status
}
if req.Type != "" {
existing.Type = req.Type
}
if req.Target != "" {
existing.Target = req.Target
}
if req.Proof != "" {
existing.Proof = req.Proof
}
if req.Impact != "" {
existing.Impact = req.Impact
}
if req.Recommendation != "" {
existing.Recommendation = req.Recommendation
}
if err := h.db.UpdateVulnerability(id, existing); err != nil {
h.logger.Error("更新漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 返回更新后的漏洞
updated, err := h.db.GetVulnerability(id)
if err != nil {
h.logger.Error("获取更新后的漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
"severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID,
})
}
c.JSON(http.StatusOK, updated)
}
// DeleteVulnerability 删除漏洞
func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteVulnerability(id); err != nil {
h.logger.Error("删除漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "vulnerability",
Action: "delete",
Result: "success",
ResourceType: "vulnerability",
ResourceID: id,
Message: "删除漏洞记录",
})
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// BatchDeleteVulnerabilities 按当前筛选条件批量删除漏洞
func (h *VulnerabilityHandler) BatchDeleteVulnerabilities(c *gin.Context) {
filter := parseVulnerabilityListFilter(c)
total, err := h.db.CountVulnerabilities(filter)
if err != nil {
h.logger.Error("统计待删除漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if total == 0 {
c.JSON(http.StatusOK, gin.H{"message": "当前筛选条件下没有可删除的漏洞", "deleted": 0})
return
}
deleted, err := h.db.DeleteVulnerabilitiesByFilter(filter)
if err != nil {
h.logger.Error("批量删除漏洞失败", zap.Error(err), zap.Int("count", total))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "vulnerability", "delete_batch", "批量删除漏洞记录", "vulnerability", "", map[string]interface{}{
"deleted": deleted,
"filter": filter,
})
}
c.JSON(http.StatusOK, gin.H{"message": "批量删除成功", "deleted": deleted})
}
// GetVulnerabilityStats 获取漏洞统计
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
filter := parseVulnerabilityListFilter(c)
stats, err := h.db.GetVulnerabilityStats(filter)
if err != nil {
h.logger.Error("获取漏洞统计失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, stats)
}
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
func (h *VulnerabilityHandler) GetVulnerabilityFilterOptions(c *gin.Context) {
options, err := h.db.GetVulnerabilityFilterOptions()
if err != nil {
h.logger.Error("获取漏洞筛选建议失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, options)
}
// ExportVulnerabilities 导出漏洞(支持按对话/任务分组,汇总或拆分)
func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
groupBy := c.DefaultQuery("group_by", "conversation")
mode := c.DefaultQuery("mode", "summary")
if groupBy != "conversation" && groupBy != "task" {
c.JSON(http.StatusBadRequest, gin.H{"error": "group_by 仅支持 conversation 或 task"})
return
}
if mode != "summary" && mode != "split" {
c.JSON(http.StatusBadRequest, gin.H{"error": "mode 仅支持 summary 或 split"})
return
}
filter := parseVulnerabilityListFilter(c)
total, err := h.db.CountVulnerabilities(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if total == 0 {
c.JSON(http.StatusOK, gin.H{"mode": mode, "group_by": groupBy, "total": 0, "files": []any{}})
return
}
items, err := h.db.ListVulnerabilities(total, 0, filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
type exportFile struct {
FileName string `json:"filename"`
Content string `json:"content"`
}
grouped := map[string][]*database.Vulnerability{}
for _, v := range items {
key := v.ConversationID
if groupBy == "conversation" {
if strings.TrimSpace(v.ConversationTag) != "" {
key = strings.TrimSpace(v.ConversationTag)
}
} else {
key = firstNonEmpty(v.TaskTag, v.TaskID, v.TaskQueueID, "unassigned-task")
}
grouped[key] = append(grouped[key], v)
}
files := make([]exportFile, 0)
nowStr := time.Now().Format("20060102-150405")
if mode == "summary" {
var b strings.Builder
b.WriteString("# 漏洞批量导出报告\n\n")
b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05")))
b.WriteString(fmt.Sprintf("- 分组维度: %s\n", groupBy))
b.WriteString(fmt.Sprintf("- 漏洞总数: %d\n", len(items)))
b.WriteString(fmt.Sprintf("- 分组数: %d\n\n", len(grouped)))
for group, list := range grouped {
b.WriteString(fmt.Sprintf("## %s (%d)\n\n", group, len(list)))
for _, v := range list {
appendVulnerabilityMarkdown(&b, v, "###")
}
}
files = append(files, exportFile{
FileName: fmt.Sprintf("vulnerability-report-%s-%s.md", groupBy, nowStr),
Content: b.String(),
})
} else {
for group, list := range grouped {
var b strings.Builder
b.WriteString(fmt.Sprintf("# 漏洞报告 - %s\n\n", group))
b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05")))
b.WriteString(fmt.Sprintf("- 漏洞数量: %d\n\n", len(list)))
for _, v := range list {
appendVulnerabilityMarkdown(&b, v, "##")
}
files = append(files, exportFile{
FileName: fmt.Sprintf("vulnerability-%s-%s.md", sanitizeExportName(group), nowStr),
Content: b.String(),
})
}
}
c.JSON(http.StatusOK, gin.H{
"mode": mode,
"group_by": groupBy,
"total": len(items),
"files": files,
})
}
// appendVulnerabilityMarkdown 单条漏洞的 Markdown 片段(与单文件下载字段对齐,缺省字段不写)
func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability, titleHeading string) {
b.WriteString(fmt.Sprintf("%s %s\n\n", titleHeading, v.Title))
b.WriteString(fmt.Sprintf("- 漏洞ID: `%s`\n", v.ID))
b.WriteString(fmt.Sprintf("- 严重程度: %s\n", v.Severity))
b.WriteString(fmt.Sprintf("- 状态: %s\n", v.Status))
if v.Type != "" {
b.WriteString(fmt.Sprintf("- 类型: %s\n", v.Type))
}
if v.Target != "" {
b.WriteString(fmt.Sprintf("- 目标: %s\n", v.Target))
}
b.WriteString(fmt.Sprintf("- 对话ID: `%s`\n", v.ConversationID))
if v.ConversationTag != "" {
b.WriteString(fmt.Sprintf("- 对话标签: %s\n", v.ConversationTag))
}
if v.TaskTag != "" {
b.WriteString(fmt.Sprintf("- 任务标签: %s\n", v.TaskTag))
}
if v.TaskID != "" {
b.WriteString(fmt.Sprintf("- 任务ID: `%s`\n", v.TaskID))
}
if v.TaskQueueID != "" {
b.WriteString(fmt.Sprintf("- 任务队列ID: `%s`\n", v.TaskQueueID))
}
if !v.CreatedAt.IsZero() {
b.WriteString(fmt.Sprintf("- 创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
}
if !v.UpdatedAt.IsZero() {
b.WriteString(fmt.Sprintf("- 更新时间: %s\n", v.UpdatedAt.Format("2006-01-02 15:04:05")))
}
if v.Description != "" {
b.WriteString("\n#### 描述\n\n")
b.WriteString(v.Description)
b.WriteString("\n")
}
if v.Proof != "" {
b.WriteString("\n#### 证明(POC\n\n```\n")
b.WriteString(v.Proof)
b.WriteString("\n```\n")
}
if v.Impact != "" {
b.WriteString("\n#### 影响\n\n")
b.WriteString(v.Impact)
b.WriteString("\n")
}
if v.Recommendation != "" {
b.WriteString("\n#### 修复建议\n\n")
b.WriteString(v.Recommendation)
b.WriteString("\n")
}
b.WriteString("\n")
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
trimmed := strings.TrimSpace(v)
if trimmed != "" {
return trimmed
}
}
return ""
}
func sanitizeExportName(raw string) string {
name := strings.TrimSpace(raw)
if name == "" {
return "unknown"
}
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
return replacer.Replace(name)
}
+993
View File
@@ -0,0 +1,993 @@
package handler
import (
"bytes"
"crypto/tls"
"database/sql"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
// webshellSupportedEncodings 允许的 WebShell 响应编码取值(小写,含空串代表 auto)
// 仅暴露目前最常见的几种,其他需求可后续扩展(如 Big5、Shift_JIS 等)。
var webshellSupportedEncodings = map[string]struct{}{
"": {}, // 未配置,按 auto 处理
"auto": {},
"utf-8": {},
"utf8": {},
"gbk": {},
"gb18030": {},
}
// normalizeWebshellEncoding 归一化编码标识:统一为小写,未知值回退为 auto,供持久化使用
func normalizeWebshellEncoding(enc string) string {
enc = strings.ToLower(strings.TrimSpace(enc))
if _, ok := webshellSupportedEncodings[enc]; !ok {
return "auto"
}
if enc == "" {
return "auto"
}
if enc == "utf8" {
return "utf-8"
}
return enc
}
// decodeWebshellOutput 把 WebShell 返回的字节按指定编码转换为合法 UTF-8 字符串。
// 约定:
// - "" / "auto":若已是合法 UTF-8 原样返回,否则依次尝试 GB18030(GBK 超集)解码。
// - "utf-8" / "utf8":原样返回,非法字节交由 JSON 层按 U+FFFD 处理(保持原有行为)。
// - "gbk" / "gb18030":强制按对应编码解码;失败则回退原始字节。
//
// 该函数对空输入直接返回空串,避免不必要的转换。
func decodeWebshellOutput(raw []byte, encoding string) string {
if len(raw) == 0 {
return ""
}
enc := normalizeWebshellEncoding(encoding)
switch enc {
case "utf-8":
return string(raw)
case "gbk":
if out, _, err := transform.Bytes(simplifiedchinese.GBK.NewDecoder(), raw); err == nil {
return string(out)
}
return string(raw)
case "gb18030":
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
return string(out)
}
return string(raw)
default: // auto
if utf8.Valid(raw) {
return string(raw)
}
// GB18030 是 GBK 的超集,覆盖范围最广,auto 模式统一用它兜底
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
return string(out)
}
return string(raw)
}
}
// webshellSupportedOS 允许的 WebShell 目标操作系统(小写,空串代表 auto)
var webshellSupportedOS = map[string]struct{}{
"": {},
"auto": {},
"linux": {},
"windows": {},
}
// normalizeWebshellOS 归一化 OS 标识,未知值回退为 auto,供持久化使用
func normalizeWebshellOS(osTag string) string {
osTag = strings.ToLower(strings.TrimSpace(osTag))
if _, ok := webshellSupportedOS[osTag]; !ok {
return "auto"
}
if osTag == "" {
return "auto"
}
return osTag
}
// resolveWebshellOS 根据连接的 os 与 shellType 推断最终目标 OS(仅返回 "linux" 或 "windows")。
// 规则:
// - 显式 linux / windows:按用户选择。
// - auto 或未知:asp/aspx → windows,其他 → linux。保持历史行为,平滑向后兼容。
func resolveWebshellOS(osTag, shellType string) string {
osTag = strings.ToLower(strings.TrimSpace(osTag))
switch osTag {
case "linux":
return "linux"
case "windows":
return "windows"
}
t := strings.ToLower(strings.TrimSpace(shellType))
if t == "asp" || t == "aspx" {
return "windows"
}
return "linux"
}
// quoteCmdPath 把路径按 Windows cmd.exe 规则转义。
// 使用双引号包裹,内部双引号转义为 ""(cmd 接受的写法)。
func quoteCmdPath(p string) string {
if p == "" {
return "\".\""
}
return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\""
}
// normalizeWindowsCmdPath 把前端统一的 "/" 路径转换为 cmd 更稳定识别的 "\"。
// 仅用于 Windows 命令构造,不改变语义(例如 "." / ".." 会保持不变)。
func normalizeWindowsCmdPath(p string) string {
s := strings.TrimSpace(p)
if s == "" {
return s
}
return strings.ReplaceAll(s, "/", "\\")
}
// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。
// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。
func quotePsSingle(s string) string {
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}
// quoteShellSinglePosix 把路径按 POSIX sh 单引号规则转义(内部 ' → '\''
func quoteShellSinglePosix(p string) string {
if p == "" {
return "."
}
return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'"
}
// quoteWebshellPath 按目标 OS 选择转义方案:Linux 用 POSIX 单引号,Windows 用 cmd 双引号
func quoteWebshellPath(path, osTag string) string {
if resolveWebshellOS(osTag, "") == "windows" {
return quoteCmdPath(path)
}
return quoteShellSinglePosix(path)
}
// buildWindowsPowerShellWrite 构造 Windows 端把 base64 内容一次性写入目标路径的 cmd 命令。
// 外层走 cmd.exe 的 powershell 调用,PowerShell 脚本里只用单引号字符串,避免嵌套引号陷阱。
func buildWindowsPowerShellWrite(path, b64 string) string {
script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" +
"[IO.File]::WriteAllBytes(" + quotePsSingle(path) + ",$b)"
return "powershell -NoProfile -NonInteractive -Command \"" + script + "\""
}
// buildWindowsPowerShellAppend 构造 Windows 端把 base64 内容追加写入目标路径的 cmd 命令(用于分块上传)
func buildWindowsPowerShellAppend(path, b64 string) string {
script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" +
"$f=[IO.File]::Open(" + quotePsSingle(path) + ",[IO.FileMode]::Append,[IO.FileAccess]::Write,[IO.FileShare]::None);" +
"try{$f.Write($b,0,$b.Length)}finally{$f.Close()}"
return "powershell -NoProfile -NonInteractive -Command \"" + script + "\""
}
// fileCommandInput 封装 buildFileCommand 的输入,避免长参数列表
type fileCommandInput struct {
Action string
Path string
TargetPath string
Content string
ChunkIndex int
OS string
ShellType string
}
// buildFileCommand 根据目标 OS 与文件操作类型生成具体的远端命令字符串。
// 同一份实现供 HTTP 入口(FileOp)与 MCP 入口(FileOpWithConnection)共用,避免双份维护。
// 返回值第二位是用户可见的业务错误(如 "path is required")。
func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error) {
targetOS := resolveWebshellOS(in.OS, in.ShellType)
action := strings.ToLower(strings.TrimSpace(in.Action))
path := strings.TrimSpace(in.Path)
switch action {
case "list":
p := path
if p == "" {
p = "."
}
if targetOS == "windows" {
p = normalizeWindowsCmdPath(p)
return "dir /a " + quoteCmdPath(p), nil
}
return "ls -la " + quoteShellSinglePosix(p), nil
case "read":
if path == "" {
return "", errFileOpPathRequired
}
if targetOS == "windows" {
path = normalizeWindowsCmdPath(path)
return "type " + quoteCmdPath(path), nil
}
return "cat " + quoteShellSinglePosix(path), nil
case "delete":
if path == "" {
return "", errFileOpPathRequired
}
if targetOS == "windows" {
path = normalizeWindowsCmdPath(path)
return "del /q /f " + quoteCmdPath(path), nil
}
return "rm -f " + quoteShellSinglePosix(path), nil
case "mkdir":
if path == "" {
return "", errFileOpPathRequired
}
if targetOS == "windows" {
path = normalizeWindowsCmdPath(path)
// cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p
return "md " + quoteCmdPath(path), nil
}
return "mkdir -p " + quoteShellSinglePosix(path), nil
case "rename":
oldPath := path
newPath := strings.TrimSpace(in.TargetPath)
if oldPath == "" || newPath == "" {
return "", errFileOpRenameNeedsBothPaths
}
if targetOS == "windows" {
oldPath = normalizeWindowsCmdPath(oldPath)
newPath = normalizeWindowsCmdPath(newPath)
return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil
}
return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil
case "write":
if path == "" {
return "", errFileOpPathRequired
}
// 统一策略:先把内容 base64 编码,再用目标平台对应方式解码写回,
// 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。
b64 := base64.StdEncoding.EncodeToString([]byte(in.Content))
if targetOS == "windows" {
path = normalizeWindowsCmdPath(path)
return buildWindowsPowerShellWrite(path, b64), nil
}
return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil
case "upload":
if path == "" {
return "", errFileOpPathRequired
}
if len(in.Content) > 512*1024 {
return "", errFileOpUploadTooLarge
}
if targetOS == "windows" {
path = normalizeWindowsCmdPath(path)
return buildWindowsPowerShellWrite(path, in.Content), nil
}
return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil
case "upload_chunk":
if path == "" {
return "", errFileOpPathRequired
}
if targetOS == "windows" {
path = normalizeWindowsCmdPath(path)
if in.ChunkIndex == 0 {
return buildWindowsPowerShellWrite(path, in.Content), nil
}
return buildWindowsPowerShellAppend(path, in.Content), nil
}
redir := ">>"
if in.ChunkIndex == 0 {
redir = ">"
}
return "echo '" + in.Content + "' | base64 -d " + redir + " " + quoteShellSinglePosix(path), nil
}
return "", errFileOpUnsupportedAction(action)
}
// 业务错误常量,便于上层统一返回用户可见提示
var (
errFileOpPathRequired = simpleError("path is required")
errFileOpRenameNeedsBothPaths = simpleError("path and target_path are required for rename")
errFileOpUploadTooLarge = simpleError("upload content too large (max 512KB base64)")
)
func errFileOpUnsupportedAction(action string) error {
return simpleError("unsupported action: " + action)
}
// simpleError 是不带堆栈的轻量错误类型,供 buildFileCommand 报可预期的参数校验错误
type simpleError string
func (e simpleError) Error() string { return string(e) }
// WebShellHandler 代理执行 WebShell 命令(类似冰蝎/蚁剑),避免前端跨域并统一构建请求
type WebShellHandler struct {
logger *zap.Logger
client *http.Client
db *database.DB
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *WebShellHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
return &WebShellHandler{
logger: logger,
client: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: false,
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
},
},
db: db,
}
}
// CreateConnectionRequest 创建连接请求
type CreateConnectionRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"`
CmdParam string `json:"cmd_param"`
Remark string `json:"remark"`
Encoding string `json:"encoding"`
OS string `json:"os"`
}
// UpdateConnectionRequest 更新连接请求
type UpdateConnectionRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"`
CmdParam string `json:"cmd_param"`
Remark string `json:"remark"`
Encoding string `json:"encoding"`
OS string `json:"os"`
}
// ListConnections 列出所有 WebShell 连接(GET /api/webshell/connections
func (h *WebShellHandler) ListConnections(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
list, err := h.db.ListWebshellConnections()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []database.WebShellConnection{}
}
c.JSON(http.StatusOK, list)
}
// CreateConnection 创建 WebShell 连接(POST /api/webshell/connections
func (h *WebShellHandler) CreateConnection(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
var req CreateConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
if req.URL == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
return
}
if _, err := url.Parse(req.URL); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
method := strings.ToLower(strings.TrimSpace(req.Method))
if method != "get" && method != "post" {
method = "post"
}
shellType := strings.ToLower(strings.TrimSpace(req.Type))
if shellType == "" {
shellType = "php"
}
conn := &database.WebShellConnection{
ID: "ws_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:12],
URL: req.URL,
Password: strings.TrimSpace(req.Password),
Type: shellType,
Method: method,
CmdParam: strings.TrimSpace(req.CmdParam),
Remark: strings.TrimSpace(req.Remark),
Encoding: normalizeWebshellEncoding(req.Encoding),
OS: normalizeWebshellOS(req.OS),
CreatedAt: time.Now(),
}
if err := h.db.CreateWebshellConnection(conn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
host := req.URL
if u, err := url.Parse(req.URL); err == nil {
host = u.Host
}
h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{
"host": host, "type": shellType,
})
}
c.JSON(http.StatusOK, conn)
}
// UpdateConnection 更新 WebShell 连接(PUT /api/webshell/connections/:id
func (h *WebShellHandler) UpdateConnection(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
var req UpdateConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
if req.URL == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
return
}
if _, err := url.Parse(req.URL); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
method := strings.ToLower(strings.TrimSpace(req.Method))
if method != "get" && method != "post" {
method = "post"
}
shellType := strings.ToLower(strings.TrimSpace(req.Type))
if shellType == "" {
shellType = "php"
}
conn := &database.WebShellConnection{
ID: id,
URL: req.URL,
Password: strings.TrimSpace(req.Password),
Type: shellType,
Method: method,
CmdParam: strings.TrimSpace(req.CmdParam),
Remark: strings.TrimSpace(req.Remark),
Encoding: normalizeWebshellEncoding(req.Encoding),
OS: normalizeWebshellOS(req.OS),
}
if err := h.db.UpdateWebshellConnection(conn); err != nil {
if err == sql.ErrNoRows {
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
updated, _ := h.db.GetWebshellConnection(id)
if updated != nil {
c.JSON(http.StatusOK, updated)
} else {
c.JSON(http.StatusOK, conn)
}
}
// DeleteConnection 删除 WebShell 连接(DELETE /api/webshell/connections/:id
func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
if err := h.db.DeleteWebshellConnection(id); err != nil {
if err == sql.ErrNoRows {
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil)
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
// GetConnectionState 获取 WebShell 连接关联的前端持久化状态(GET /api/webshell/connections/:id/state
func (h *WebShellHandler) GetConnectionState(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
conn, err := h.db.GetWebshellConnection(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if conn == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
return
}
stateJSON, err := h.db.GetWebshellConnectionState(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var state interface{}
if err := json.Unmarshal([]byte(stateJSON), &state); err != nil {
state = map[string]interface{}{}
}
c.JSON(http.StatusOK, gin.H{"state": state})
}
// SaveConnectionState 保存 WebShell 连接关联的前端持久化状态(PUT /api/webshell/connections/:id/state
func (h *WebShellHandler) SaveConnectionState(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
conn, err := h.db.GetWebshellConnection(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if conn == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
return
}
var req struct {
State json.RawMessage `json:"state"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
raw := req.State
if len(raw) == 0 {
raw = json.RawMessage(`{}`)
}
if len(raw) > 2*1024*1024 {
c.JSON(http.StatusBadRequest, gin.H{"error": "state payload too large (max 2MB)"})
return
}
var anyJSON interface{}
if err := json.Unmarshal(raw, &anyJSON); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "state must be valid json"})
return
}
if err := h.db.UpsertWebshellConnectionState(id, string(raw)); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
// GetAIHistory 获取指定 WebShell 连接的 AI 助手对话历史(GET /api/webshell/connections/:id/ai-history
func (h *WebShellHandler) GetAIHistory(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
conv, err := h.db.GetConversationByWebshellConnectionID(id)
if err != nil {
h.logger.Warn("获取 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err))
c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}})
return
}
if conv == nil {
c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}})
return
}
c.JSON(http.StatusOK, gin.H{"conversationId": conv.ID, "messages": conv.Messages})
}
// ListAIConversations 列出该 WebShell 连接下的所有 AI 对话(供侧边栏)
func (h *WebShellHandler) ListAIConversations(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
list, err := h.db.ListConversationsByWebshellConnectionID(id)
if err != nil {
h.logger.Warn("列出 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err))
c.JSON(http.StatusOK, []database.WebShellConversationItem{})
return
}
if list == nil {
list = []database.WebShellConversationItem{}
}
c.JSON(http.StatusOK, list)
}
// ExecRequest 执行命令请求(前端传入连接信息 + 命令)
type ExecRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"` // php, asp, aspx, jsp, custom
Method string `json:"method"` // GET 或 POST,空则默认 POST
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto
OS string `json:"os"` // 目标操作系统:auto / linux / windows,当前 exec 不用它,保留字段便于未来扩展
Command string `json:"command" binding:"required"`
}
// ExecResponse 执行命令响应
type ExecResponse struct {
OK bool `json:"ok"`
Output string `json:"output"`
Error string `json:"error,omitempty"`
HTTPCode int `json:"http_code,omitempty"`
}
// FileOpRequest 文件操作请求
type FileOpRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"` // GET 或 POST,空则默认 POST
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto
OS string `json:"os"` // 目标操作系统:auto / linux / windows,空则按 shellType 推断
ConnectionID string `json:"connection_id,omitempty"` // 可选:连接 ID;服务端探活出 OS 后会回写到此连接
Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk
Path string `json:"path"`
TargetPath string `json:"target_path"` // rename 时目标路径
Content string `json:"content"` // write/upload 时使用
ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块
}
// FileOpResponse 文件操作响应
type FileOpResponse struct {
OK bool `json:"ok"`
Output string `json:"output"`
Error string `json:"error,omitempty"`
DetectedOS string `json:"detected_os,omitempty"` // 仅在 auto 模式且探活成功时返回,前端应更新本地缓存
}
func (h *WebShellHandler) Exec(c *gin.Context) {
var req ExecRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
req.Command = strings.TrimSpace(req.Command)
if req.URL == "" || req.Command == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "url and command are required"})
return
}
parsed, err := url.Parse(req.URL)
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"})
return
}
useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET"
cmdParam := strings.TrimSpace(req.CmdParam)
if cmdParam == "" {
cmdParam = "cmd"
}
var httpReq *http.Request
if useGET {
targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, req.Command)
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
} else {
body := h.buildExecBody(req.Type, req.Password, cmdParam, req.Command)
httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if err != nil {
h.logger.Warn("webshell exec NewRequest", zap.Error(err))
c.JSON(http.StatusInternalServerError, ExecResponse{OK: false, Error: err.Error()})
return
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
h.logger.Warn("webshell exec Do", zap.String("url", req.URL), zap.Error(err))
c.JSON(http.StatusOK, ExecResponse{OK: false, Error: err.Error()})
return
}
defer resp.Body.Close()
out, readErr := io.ReadAll(resp.Body)
if readErr != nil {
h.logger.Warn("webshell exec read body", zap.Error(readErr))
}
output := decodeWebshellOutput(out, req.Encoding)
httpCode := resp.StatusCode
ok := resp.StatusCode == http.StatusOK
c.JSON(http.StatusOK, ExecResponse{
OK: ok,
Output: output,
HTTPCode: httpCode,
})
}
// buildExecBody 按常见 WebShell 约定构建 POST 体(多数使用 pass + cmd,可配置命令参数名)
func (h *WebShellHandler) buildExecBody(shellType, password, cmdParam, command string) []byte {
form := h.execParams(shellType, password, cmdParam, command)
return []byte(form.Encode())
}
// buildExecURL 构建 GET 请求的完整 URLbaseURL + ?pass=xxx&cmd=yyycmd 可配置)
func (h *WebShellHandler) buildExecURL(baseURL, shellType, password, cmdParam, command string) string {
form := h.execParams(shellType, password, cmdParam, command)
if parsed, err := url.Parse(baseURL); err == nil {
parsed.RawQuery = form.Encode()
return parsed.String()
}
return baseURL + "?" + form.Encode()
}
func (h *WebShellHandler) execParams(shellType, password, cmdParam, command string) url.Values {
shellType = strings.ToLower(strings.TrimSpace(shellType))
if shellType == "" {
shellType = "php"
}
if strings.TrimSpace(cmdParam) == "" {
cmdParam = "cmd"
}
form := url.Values{}
form.Set("pass", password)
form.Set(cmdParam, command)
return form
}
func (h *WebShellHandler) FileOp(c *gin.Context) {
var req FileOpRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
req.Action = strings.ToLower(strings.TrimSpace(req.Action))
if req.URL == "" || req.Action == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "url and action are required"})
return
}
parsed, err := url.Parse(req.URL)
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"})
return
}
// 若 OS 未显式配置,先发一次探活命令,识别出真实 OS 再构造文件操作命令。
// 这解决了 "Windows + PHP + OS=auto" 场景下旧 fallback 错发 `ls -la` 导致目录列不出来的问题。
osTag := req.OS
detectedOS := ""
if normalizeWebshellOS(osTag) == "auto" {
if probed := probeWebshellOSViaExec(h.newHTTPExecFn(req.URL, req.Password, req.Type, req.Method, req.CmdParam, req.Encoding)); probed != "" {
osTag = probed
detectedOS = probed
// 若前端带了 connection_id,顺带把探活结果持久化到该连接,后续刷新零成本
if cid := strings.TrimSpace(req.ConnectionID); cid != "" {
h.persistDetectedOS(cid, probed)
}
}
}
command, cmdErr := h.buildFileCommand(fileCommandInput{
Action: req.Action,
Path: req.Path,
TargetPath: req.TargetPath,
Content: req.Content,
ChunkIndex: req.ChunkIndex,
OS: osTag,
ShellType: req.Type,
})
if cmdErr != nil {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: cmdErr.Error()})
return
}
useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET"
cmdParam := strings.TrimSpace(req.CmdParam)
if cmdParam == "" {
cmdParam = "cmd"
}
var httpReq *http.Request
if useGET {
targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
} else {
body := h.buildExecBody(req.Type, req.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if err != nil {
c.JSON(http.StatusInternalServerError, FileOpResponse{OK: false, Error: err.Error()})
return
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
c.JSON(http.StatusOK, FileOpResponse{OK: false, Error: err.Error()})
return
}
defer resp.Body.Close()
out, readErr := io.ReadAll(resp.Body)
if readErr != nil {
h.logger.Warn("webshell fileop read body", zap.Error(readErr))
}
output := decodeWebshellOutput(out, req.Encoding)
c.JSON(http.StatusOK, FileOpResponse{
OK: resp.StatusCode == http.StatusOK,
Output: output,
DetectedOS: detectedOS,
})
}
// ExecWithConnection 在指定 WebShell 连接上执行命令(供 MCP/Agent 等非 HTTP 调用)
func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, command string) (output string, ok bool, errMsg string) {
if conn == nil {
return "", false, "connection is nil"
}
command = strings.TrimSpace(command)
if command == "" {
return "", false, "command is required"
}
useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET"
cmdParam := strings.TrimSpace(conn.CmdParam)
if cmdParam == "" {
cmdParam = "cmd"
}
var httpReq *http.Request
var err error
if useGET {
targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
} else {
body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if err != nil {
return "", false, err.Error()
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
return "", false, err.Error()
}
defer resp.Body.Close()
out, readErr := io.ReadAll(resp.Body)
if readErr != nil {
h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr))
}
return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, ""
}
// FileOpWithConnection 在指定 WebShell 连接上执行文件操作(供 MCP/Agent 调用),支持 list / read / write
func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection, action, path, content, targetPath string) (output string, ok bool, errMsg string) {
if conn == nil {
return "", false, "connection is nil"
}
action = strings.ToLower(strings.TrimSpace(action))
// MCP 入口仅开放 list / read / write 三种动作,与工具文档的承诺保持一致
switch action {
case "list", "read", "write":
// 支持的动作
default:
return "", false, "unsupported action: " + action + " (supported: list, read, write)"
}
// 若连接的 OS 为 auto,先探活并持久化,避免 AI/MCP 每次都对 Windows 发 `ls -la`
osTag := conn.OS
if normalizeWebshellOS(osTag) == "auto" {
if probed := probeWebshellOSViaExec(func(cmd string) (string, bool) {
out, exOk, _ := h.ExecWithConnection(conn, cmd)
return out, exOk
}); probed != "" {
osTag = probed
conn.OS = probed // 本次请求内使用探活结果
h.persistDetectedOS(conn.ID, probed)
}
}
command, cmdErr := h.buildFileCommand(fileCommandInput{
Action: action,
Path: path,
TargetPath: targetPath,
Content: content,
OS: osTag,
ShellType: conn.Type,
})
if cmdErr != nil {
return "", false, cmdErr.Error()
}
useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET"
cmdParam := strings.TrimSpace(conn.CmdParam)
if cmdParam == "" {
cmdParam = "cmd"
}
var httpReq *http.Request
var err error
if useGET {
targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
} else {
body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if err != nil {
return "", false, err.Error()
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
return "", false, err.Error()
}
defer resp.Body.Close()
out, readErr := io.ReadAll(resp.Body)
if readErr != nil {
h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr))
}
return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, ""
}
+106
View File
@@ -0,0 +1,106 @@
package handler
import (
"strings"
"cyberstrike-ai/internal/database"
)
// WebshellSkillHintDefault 对话页 / Eino 单代理共用的 Skills 说明,放在 webshell 上下文末尾,
// 供 AI 选择 skill 加载入口时参考。
const WebshellSkillHintDefault = "Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。"
// WebshellSkillHintMultiAgent 多代理 / Eino 多代理准备阶段使用的 Skills 说明
const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `skill` 工具。"
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、restore_project_fact、list_knowledge_risk_types、search_knowledge_base"
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
// 以及最终的用户请求。调用方只需要决定 skillHint 的文案(默认使用 WebshellSkillHintDefault)。
//
// 之所以把这段逻辑抽到共享函数里,是为了避免 agent.go / multi_agent_prepare.go 等多处复制粘贴,
// 并确保当我们升级 OS / Encoding 文案时只需要改一处、测一处、同步生效。
func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint, userMsg string) string {
if conn == nil {
// 兜底:调用方已保证 conn 非 nil,这里只是防御性返回原消息
return userMsg
}
remark := conn.Remark
if remark == "" {
remark = conn.URL
}
targetOS := resolveWebshellOS(conn.OS, conn.Type) // 归一为 "linux" / "windows"
encoding := normalizeWebshellEncoding(conn.Encoding)
if skillHint == "" {
skillHint = WebshellSkillHintDefault
}
var b strings.Builder
b.Grow(512 + len(userMsg))
b.WriteString("[WebShell 助手上下文] 连接 ID")
b.WriteString(conn.ID)
b.WriteString(",备注:")
b.WriteString(remark)
b.WriteByte('\n')
// 目标系统:明确告诉 AI 能用/不能用的命令集,避免它对着 Windows 发 ls/cat/rm
b.WriteString("- 目标系统:")
b.WriteString(describeTargetOSForPrompt(targetOS))
b.WriteByte('\n')
// 响应编码:仅在非 auto 时显式告知,auto 模式由后端自适应,不打扰模型
if encHint := describeEncodingForPrompt(encoding); encHint != "" {
b.WriteString("- 响应编码:")
b.WriteString(encHint)
b.WriteByte('\n')
}
// 工具清单 & connection_id 约束:保持旧有表达,AI 已熟悉
b.WriteString("可用工具(仅在该连接上操作时使用,connection_id 填 \"")
b.WriteString(conn.ID)
b.WriteString("\"):")
b.WriteString(webshellAssistantToolList)
b.WriteString("。边渗透边记录:每确认新认知即 upsert_project_fact,每验证漏洞即 record_vulnerability,勿等会话结束。")
b.WriteString(skillHint)
b.WriteString("\n\n用户请求:")
b.WriteString(userMsg)
return b.String()
}
// describeTargetOSForPrompt 返回某个 OS 对应的中文描述 + 推荐命令集 + 反例,
// 命令列表覆盖文件管理最常用的 6 类动作(查看/读/删/改名/建目录/查找),让 AI 能直接照抄。
func describeTargetOSForPrompt(targetOS string) string {
switch targetOS {
case "windows":
return "Windows(推荐 cmd/PowerShelldir /a、type、del /q /f、move /y、md、ren" +
"查找文件用 `dir /s /b 过滤词` 或 PowerShell `Get-ChildItem -Recurse`" +
"避免 ls / cat / rm / mv / find 等 Unix 命令,否则将返回 `不是内部或外部命令`)"
case "linux":
return "Linux/Unix(推荐 sh/bashls -la、cat、rm -f、mv、mkdir -p" +
"查找文件用 `find /path -name '*pattern*'`" +
"避免 dir、type、del、move 等 Windows 命令)"
default:
// 理论上不会走到这里,resolveWebshellOS 已经兜底
return "未知(请先执行 `uname || ver` 探测再决定命令集)"
}
}
// describeEncodingForPrompt 返回响应编码的人类可读描述;auto 返回空串以减少 token。
func describeEncodingForPrompt(encoding string) string {
switch encoding {
case "utf-8":
return "UTF-8(目标原生 UTF-8,无需额外解码)"
case "gbk":
return "GBK(中文 Windows;后端已自动转码为 UTF-8 返回,若仍出现大量 \\uFFFD 替换字符说明命令失败或编码识别错误)"
case "gb18030":
return "GB18030(后端已自动转码为 UTF-8 返回)"
default:
return ""
}
}
+170
View File
@@ -0,0 +1,170 @@
package handler
import (
"strings"
"testing"
"cyberstrike-ai/internal/database"
)
func TestBuildWebshellAssistantContext_WindowsExplicit(t *testing.T) {
conn := &database.WebShellConnection{
ID: "ws_win01",
Remark: "IIS Windows 靶机",
URL: "http://example.com/shell.php",
Type: "php",
OS: "windows",
Encoding: "gbk",
}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "列出当前目录并告诉我 flag 在哪")
mustContain(t, got,
"[WebShell 助手上下文]",
"ws_win01",
"IIS Windows 靶机",
"目标系统:Windows",
"dir /a",
"move /y",
"避免 ls / cat / rm",
"响应编码:GBK",
"后端已自动转码为 UTF-8",
"connection_id 填 \"ws_win01\"",
"webshell_exec、webshell_file_list",
WebshellSkillHintDefault,
"用户请求:列出当前目录并告诉我 flag 在哪",
)
// Windows 场景下不应出现 Linux 命令推荐
mustNotContain(t, got, "推荐 sh/bash")
}
func TestBuildWebshellAssistantContext_LinuxAutoFromPHP(t *testing.T) {
conn := &database.WebShellConnection{
ID: "ws_lnx01",
Remark: "", // 测试备注为空时 fallback URL
URL: "http://example.com/a.php",
Type: "php",
OS: "auto", // auto + php → linux
Encoding: "", // auto 编码不显式提示
}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "看看 /etc/passwd")
mustContain(t, got,
"连接 IDws_lnx01",
"备注:http://example.com/a.php", // 备注空时 fallback URL
"目标系统:Linux/Unix",
"ls -la",
"mkdir -p",
"避免 dir、type、del、move",
"用户请求:看看 /etc/passwd",
)
// encoding=auto 不应出现"响应编码:"这一行
mustNotContain(t, got, "响应编码:")
// Linux 场景不应出现 Windows 命令
mustNotContain(t, got, "推荐 cmd/PowerShell")
}
func TestBuildWebshellAssistantContext_AutoFromASPDefaultsToWindows(t *testing.T) {
// 保留向后兼容:旧连接没配 os,shellType=asp 时应视为 Windows
conn := &database.WebShellConnection{
ID: "ws_asp01",
Remark: "老 ASP 靶机",
Type: "asp",
OS: "", // 空串等同 auto
Encoding: "gb18030",
}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "查当前用户")
mustContain(t, got,
"目标系统:Windows",
"响应编码:GB18030",
"后端已自动转码为 UTF-8 返回",
WebshellSkillHintMultiAgent,
)
// 多代理 skill 文案里没有 DeepAgent,不应混入 default 文案
mustNotContain(t, got, "DeepAgent")
}
func TestBuildWebshellAssistantContext_MultiAgentSkillHint(t *testing.T) {
conn := &database.WebShellConnection{ID: "ws_m1", Remark: "x", Type: "php", OS: "linux"}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "hi")
mustContain(t, got, WebshellSkillHintMultiAgent)
mustNotContain(t, got, "DeepAgent")
}
func TestBuildWebshellAssistantContext_DefaultSkillHintFallback(t *testing.T) {
conn := &database.WebShellConnection{ID: "ws_d1", Remark: "x", Type: "php", OS: "linux"}
// skillHint 传空字符串时应回退到 default
got := BuildWebshellAssistantContext(conn, "", "hi")
mustContain(t, got, WebshellSkillHintDefault)
}
func TestBuildWebshellAssistantContext_UTF8EncodingIsAnnotated(t *testing.T) {
conn := &database.WebShellConnection{
ID: "ws_u1", Remark: "u", Type: "jsp", OS: "linux", Encoding: "utf-8",
}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "hi")
mustContain(t, got, "响应编码:UTF-8", "目标原生 UTF-8")
}
func TestBuildWebshellAssistantContext_NilConnReturnsUserMsg(t *testing.T) {
// 防御性:conn == nil 时不 panic,直接返回原消息
got := BuildWebshellAssistantContext(nil, WebshellSkillHintDefault, "just the message")
if got != "just the message" {
t.Errorf("nil conn should return userMsg as-is, got %q", got)
}
}
func TestDescribeTargetOSForPrompt(t *testing.T) {
cases := map[string][]string{
"windows": {"Windows", "dir /a", "move /y", "PowerShell"},
"linux": {"Linux/Unix", "ls -la", "mkdir -p"},
"": {"未知", "uname"}, // 防御性分支
}
for in, wants := range cases {
got := describeTargetOSForPrompt(in)
for _, w := range wants {
if !strings.Contains(got, w) {
t.Errorf("describeTargetOSForPrompt(%q) should contain %q, got: %s", in, w, got)
}
}
}
}
func TestDescribeEncodingForPrompt(t *testing.T) {
cases := map[string]string{
"utf-8": "UTF-8",
"gbk": "GBK",
"gb18030": "GB18030",
"auto": "",
"": "",
}
for in, want := range cases {
got := describeEncodingForPrompt(in)
if want == "" && got != "" {
t.Errorf("describeEncodingForPrompt(%q) should return empty string, got: %s", in, got)
}
if want != "" && !strings.Contains(got, want) {
t.Errorf("describeEncodingForPrompt(%q) should contain %q, got: %s", in, want, got)
}
}
}
// ---- 小工具 ----
func mustContain(t *testing.T, text string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if !strings.Contains(text, s) {
t.Errorf("expected text to contain %q\n--- text ---\n%s", s, text)
}
}
}
func mustNotContain(t *testing.T, text string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if strings.Contains(text, s) {
t.Errorf("text should not contain %q\n--- text ---\n%s", s, text)
}
}
}
+103
View File
@@ -0,0 +1,103 @@
package handler
import (
"testing"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
// mustEncode 使用指定编码对 UTF-8 字符串做编码,得到原始字节,用于构造测试输入
func mustEncode(t *testing.T, s string, enc string) []byte {
t.Helper()
var tr transform.Transformer
switch enc {
case "gbk":
tr = simplifiedchinese.GBK.NewEncoder()
case "gb18030":
tr = simplifiedchinese.GB18030.NewEncoder()
default:
t.Fatalf("unsupported test encoding: %s", enc)
}
out, _, err := transform.Bytes(tr, []byte(s))
if err != nil {
t.Fatalf("mustEncode(%s) failed: %v", enc, err)
}
return out
}
func TestNormalizeWebshellEncoding(t *testing.T) {
cases := map[string]string{
"": "auto",
" ": "auto",
"auto": "auto",
"AUTO": "auto",
"utf-8": "utf-8",
"UTF-8": "utf-8",
"utf8": "utf-8",
"gbk": "gbk",
"GBK": "gbk",
"gb18030": "gb18030",
"big5": "auto", // 未支持的回退到 auto
"anything": "auto",
}
for in, want := range cases {
if got := normalizeWebshellEncoding(in); got != want {
t.Errorf("normalizeWebshellEncoding(%q) = %q, want %q", in, got, want)
}
}
}
func TestDecodeWebshellOutput_AutoDetectsGBK(t *testing.T) {
// 模拟 Windows 中文 cmd 输出的 GBK 字节流
want := "用户名 SID 类型"
raw := mustEncode(t, want, "gbk")
// auto 模式:UTF-8 校验失败后应当回退 GB18030 解码,得到原始中文
got := decodeWebshellOutput(raw, "auto")
if got != want {
t.Errorf("decodeWebshellOutput(auto) = %q, want %q", got, want)
}
// 显式 GBK 模式:同样应当正确解码
got = decodeWebshellOutput(raw, "gbk")
if got != want {
t.Errorf("decodeWebshellOutput(gbk) = %q, want %q", got, want)
}
// 显式 GB18030 模式:GBK 是 GB18030 子集,也应正确解码
got = decodeWebshellOutput(raw, "gb18030")
if got != want {
t.Errorf("decodeWebshellOutput(gb18030) = %q, want %q", got, want)
}
}
func TestDecodeWebshellOutput_PassthroughUTF8(t *testing.T) {
// 已经是 UTF-8 的中文字符串,各模式都应返回原串(不破坏)
want := "hello 世界"
for _, enc := range []string{"", "auto", "utf-8"} {
if got := decodeWebshellOutput([]byte(want), enc); got != want {
t.Errorf("decodeWebshellOutput(%q) passthrough = %q, want %q", enc, got, want)
}
}
}
func TestDecodeWebshellOutput_ASCIIStable(t *testing.T) {
// 纯 ASCII 在任何模式下都必须保持原样
want := "whoami\nAdministrator\n"
for _, enc := range []string{"", "auto", "utf-8", "gbk", "gb18030"} {
if got := decodeWebshellOutput([]byte(want), enc); got != want {
t.Errorf("decodeWebshellOutput(%q) ASCII = %q, want %q", enc, got, want)
}
}
}
func TestDecodeWebshellOutput_EmptyInput(t *testing.T) {
// 空输入直接返回空串,不做额外分配
if got := decodeWebshellOutput(nil, "gbk"); got != "" {
t.Errorf("decodeWebshellOutput(nil) = %q, want empty", got)
}
if got := decodeWebshellOutput([]byte{}, "auto"); got != "" {
t.Errorf("decodeWebshellOutput([]) = %q, want empty", got)
}
}
+348
View File
@@ -0,0 +1,348 @@
package handler
import (
"encoding/base64"
"strings"
"testing"
"go.uber.org/zap"
)
func newTestWebShellHandler() *WebShellHandler {
return NewWebShellHandler(zap.NewNop(), nil)
}
func TestNormalizeWebshellOS(t *testing.T) {
cases := map[string]string{
"": "auto",
" ": "auto",
"auto": "auto",
"AUTO": "auto",
"linux": "linux",
"Linux": "linux",
"windows": "windows",
"WINDOWS": "windows",
"macos": "auto", // 未支持的回退 auto
"solaris": "auto",
}
for in, want := range cases {
if got := normalizeWebshellOS(in); got != want {
t.Errorf("normalizeWebshellOS(%q) = %q, want %q", in, got, want)
}
}
}
func TestResolveWebshellOS(t *testing.T) {
type testCase struct {
osTag string
shellType string
want string
}
cases := []testCase{
// 显式 OS:按用户选择,忽略 shellType
{"linux", "asp", "linux"},
{"windows", "php", "windows"},
{"LINUX", "jsp", "linux"},
// auto + 各种 shellTypeasp/aspx → windows,其他 → linux
{"auto", "asp", "windows"},
{"auto", "aspx", "windows"},
{"auto", "ASP", "windows"},
{"auto", "php", "linux"},
{"auto", "jsp", "linux"},
{"auto", "custom", "linux"},
{"auto", "", "linux"},
// 空/未知 OS 等价 auto
{"", "asp", "windows"},
{"", "php", "linux"},
{"unknown", "aspx", "windows"},
}
for _, c := range cases {
got := resolveWebshellOS(c.osTag, c.shellType)
if got != c.want {
t.Errorf("resolveWebshellOS(%q,%q) = %q, want %q", c.osTag, c.shellType, got, c.want)
}
}
}
func TestQuoteCmdPath(t *testing.T) {
cases := map[string]string{
"": `"."`,
`C:\Windows\Temp`: `"C:\Windows\Temp"`,
`C:\Program Files\a`: `"C:\Program Files\a"`,
`C:\weird"name\f.txt`: `"C:\weird""name\f.txt"`,
`.`: `"."`,
}
for in, want := range cases {
if got := quoteCmdPath(in); got != want {
t.Errorf("quoteCmdPath(%q) = %q, want %q", in, got, want)
}
}
}
func TestQuoteShellSinglePosix(t *testing.T) {
cases := map[string]string{
"": ".",
"/tmp/a b": "'/tmp/a b'",
"/tmp/it's.txt": `'/tmp/it'\''s.txt'`,
}
for in, want := range cases {
if got := quoteShellSinglePosix(in); got != want {
t.Errorf("quoteShellSinglePosix(%q) = %q, want %q", in, got, want)
}
}
}
// TestBuildFileCommand_LinuxBranch 覆盖 Linux 目标下每个 action 产出的命令
func TestBuildFileCommand_LinuxBranch(t *testing.T) {
h := newTestWebShellHandler()
base := fileCommandInput{OS: "linux", ShellType: "php"}
mustContain := func(t *testing.T, cmd string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if !strings.Contains(cmd, s) {
t.Errorf("expected command to contain %q, got: %s", s, cmd)
}
}
}
mustNotContain := func(t *testing.T, cmd string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if strings.Contains(cmd, s) {
t.Errorf("command should not contain %q, got: %s", s, cmd)
}
}
}
// list with empty path defaults to '.'
in := base
in.Action = "list"
cmd, err := h.buildFileCommand(in)
if err != nil {
t.Fatalf("list linux: unexpected err: %v", err)
}
mustContain(t, cmd, "ls -la", "'.'")
// list with path containing spaces
in.Path = "/tmp/my files"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "ls -la ", "'/tmp/my files'")
// read with path
in = base
in.Action = "read"
in.Path = "/etc/passwd"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "cat ", "'/etc/passwd'")
// read without path → error
in.Path = ""
if _, err := h.buildFileCommand(in); err != errFileOpPathRequired {
t.Errorf("read empty path: want errFileOpPathRequired, got %v", err)
}
// delete
in = base
in.Action = "delete"
in.Path = "/tmp/a.txt"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "rm -f ", "'/tmp/a.txt'")
mustNotContain(t, cmd, "del")
// mkdir
in.Action = "mkdir"
in.Path = "/tmp/new/sub"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "mkdir -p ", "'/tmp/new/sub'")
// rename
in = base
in.Action = "rename"
in.Path = "/tmp/a"
in.TargetPath = "/tmp/b"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "mv -f ", "'/tmp/a'", "'/tmp/b'")
// rename missing target → error
in.TargetPath = ""
if _, err := h.buildFileCommand(in); err != errFileOpRenameNeedsBothPaths {
t.Errorf("rename empty target: want errFileOpRenameNeedsBothPaths, got %v", err)
}
// write
in = base
in.Action = "write"
in.Path = "/tmp/w.txt"
in.Content = "hello 世界"
cmd, _ = h.buildFileCommand(in)
b64 := base64.StdEncoding.EncodeToString([]byte("hello 世界"))
mustContain(t, cmd, "echo '"+b64+"'", "| base64 -d", "> '/tmp/w.txt'")
// upload
in = base
in.Action = "upload"
in.Path = "/tmp/bin"
in.Content = "YWJjZA==" // base64 of "abcd"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "echo 'YWJjZA=='", "| base64 -d", "> '/tmp/bin'")
// upload oversized content → error
in.Content = strings.Repeat("A", 513*1024)
if _, err := h.buildFileCommand(in); err != errFileOpUploadTooLarge {
t.Errorf("upload too large: want errFileOpUploadTooLarge, got %v", err)
}
// upload_chunk with chunk_index=0 uses single redirect
in = base
in.Action = "upload_chunk"
in.Path = "/tmp/bin"
in.Content = "YWJj"
in.ChunkIndex = 0
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "base64 -d > '/tmp/bin'")
mustNotContain(t, cmd, ">>")
// upload_chunk with chunk_index>0 uses append redirect
in.ChunkIndex = 1
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "base64 -d >> '/tmp/bin'")
// unsupported action
in = base
in.Action = "nope"
if _, err := h.buildFileCommand(in); err == nil || !strings.Contains(err.Error(), "unsupported action") {
t.Errorf("unknown action: want unsupported action error, got %v", err)
}
}
// TestBuildFileCommand_WindowsBranch 覆盖 Windows 目标下每个 action 产出的命令
func TestBuildFileCommand_WindowsBranch(t *testing.T) {
h := newTestWebShellHandler()
base := fileCommandInput{OS: "windows", ShellType: "php"}
mustContain := func(t *testing.T, cmd string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if !strings.Contains(cmd, s) {
t.Errorf("expected command to contain %q, got: %s", s, cmd)
}
}
}
mustNotContain := func(t *testing.T, cmd string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if strings.Contains(cmd, s) {
t.Errorf("command should not contain %q, got: %s", s, cmd)
}
}
}
// list
in := base
in.Action = "list"
cmd, _ := h.buildFileCommand(in)
mustContain(t, cmd, "dir /a ", `"."`)
mustNotContain(t, cmd, "ls -la")
in.Path = `C:\Users\Public Docs`
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "dir /a ", `"C:\Users\Public Docs"`)
// read
in = base
in.Action = "read"
in.Path = `C:\flag.txt`
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "type ", `"C:\flag.txt"`)
// delete
in.Action = "delete"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "del /q /f ", `"C:\flag.txt"`)
mustNotContain(t, cmd, "rm -f")
// mkdir
in.Action = "mkdir"
in.Path = `C:\a\b\c`
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "md ", `"C:\a\b\c"`)
// rename
in = base
in.Action = "rename"
in.Path = `C:\a.txt`
in.TargetPath = `C:\b.txt`
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "move /y ", `"C:\a.txt"`, `"C:\b.txt"`)
// write → PowerShell base64 one-liner
in = base
in.Action = "write"
in.Path = `C:\out.txt`
in.Content = "hello 世界"
cmd, _ = h.buildFileCommand(in)
wantB64 := base64.StdEncoding.EncodeToString([]byte("hello 世界"))
mustContain(t, cmd,
"powershell -NoProfile -NonInteractive -Command",
"[Convert]::FromBase64String('"+wantB64+"')",
"[IO.File]::WriteAllBytes('C:\\out.txt'",
)
mustNotContain(t, cmd, "echo ", "base64 -d")
// upload (chunk_index=0 equivalent) uses WriteAllBytes
in = base
in.Action = "upload"
in.Path = `C:\bin\f`
in.Content = "YWJjZA=="
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "WriteAllBytes('C:\\bin\\f'", "FromBase64String('YWJjZA==')")
// upload_chunk index=0 → WriteAllBytes
in.Action = "upload_chunk"
in.ChunkIndex = 0
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "WriteAllBytes(")
mustNotContain(t, cmd, "FileMode]::Append")
// upload_chunk index>0 → append (Open with Append mode)
in.ChunkIndex = 1
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "[IO.FileMode]::Append", "FromBase64String('YWJjZA==')")
}
// TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior 确保 os=auto 时与旧版 shellType 判定行为完全一致
// asp/aspx 视为 Windows(旧行为),其他视为 Linux。
func TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior(t *testing.T) {
h := newTestWebShellHandler()
// asp + auto → windows 命令
cmd, _ := h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "asp"})
if !strings.Contains(cmd, "dir /a") {
t.Errorf("auto + asp should use Windows cmd, got: %s", cmd)
}
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "aspx"})
if !strings.Contains(cmd, "dir /a") {
t.Errorf("auto + aspx should use Windows cmd, got: %s", cmd)
}
// php/jsp/custom + auto → linux 命令(与历史行为一致)
for _, st := range []string{"php", "jsp", "custom", ""} {
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: st})
if !strings.Contains(cmd, "ls -la") {
t.Errorf("auto + %q should use Linux cmd, got: %s", st, cmd)
}
}
// 显式 OS 覆盖 shellType
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "windows", ShellType: "php"})
if !strings.Contains(cmd, "dir /a") {
t.Errorf("explicit windows should override php shellType, got: %s", cmd)
}
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "linux", ShellType: "asp"})
if !strings.Contains(cmd, "ls -la") {
t.Errorf("explicit linux should override asp shellType, got: %s", cmd)
}
}
+127
View File
@@ -0,0 +1,127 @@
package handler
import (
"bytes"
"io"
"net/http"
"strings"
"go.uber.org/zap"
)
// webshellOSProbeCommand 探活命令:利用 Windows cmd 与 POSIX shell 对 `%OS%` 展开差异进行判定。
// - Windows cmd`%OS%` 被展开为 `Windows_NT`,回显 `:OSPROBE_Windows_NT:END`
// - POSIX sh/bash`%OS%` 不是变量语法,作为字面量原样保留,回显 `:OSPROBE_%OS%:END`
//
// 一条命令即可得到明确的、互斥的信号,避免探活成本(相比发两次命令)。
// 冒号包裹是为了避免部分 shell 输出多余空白/BOM 时字符串匹配失效。
const webshellOSProbeCommand = "echo :OSPROBE_%OS%:END"
// probeWebshellOSViaExec 通过一次命令执行的回显推断目标操作系统。
//
// 返回值:
// - "windows" / "linux":识别成功
// - "":无法判定(调用方应保留既有 fallback 逻辑)
//
// 入参 execFn 是一个"发命令并拿到回显"的闭包;让 HTTP 入口和 MCP 入口可以共用同一套探活逻辑
// 而不必关心底层是如何发包的。
func probeWebshellOSViaExec(execFn func(cmd string) (output string, ok bool)) string {
if execFn == nil {
return ""
}
out, ok := execFn(webshellOSProbeCommand)
if !ok {
return ""
}
return classifyWebshellOSProbeOutput(out)
}
// classifyWebshellOSProbeOutput 纯函数:根据探活命令的回显判定 OS。
// 抽出来是为了单测可直接覆盖所有分支,无需真实 HTTP 调用。
func classifyWebshellOSProbeOutput(out string) string {
if out == "" {
return ""
}
lower := strings.ToLower(out)
// Windows 强信号:cmd.exe 成功展开了 %OS% 变量
if strings.Contains(out, "Windows_NT") {
return "windows"
}
// 容错:部分老版本 Windows 可能 `%OS%` 展开为其他字样(极少见),再看 PATH/OS 等次级线索
if strings.Contains(lower, "microsoft windows") {
return "windows"
}
// Linux/Unix 强信号:`%OS%` 字面量被原样回显,说明 shell 不是 cmd.exe
if strings.Contains(out, "%OS%") {
return "linux"
}
// 次级线索:部分 webshell 在 Linux 上可能走了其他外壳(如 zsh/ash),
// 但它们对 `%OS%` 同样不展开;若命中 OSPROBE 头部却没拿到 %OS% 字面量,
// 说明回显被中途截断或过滤,保守返回空让上层 fallback。
return ""
}
// newHTTPExecFn 为 HTTP FileOp 路径构造"发命令取回显"的闭包,供探活复用。
// 参数来自 HTTP 请求,复用 buildExecURL / buildExecBody 两个已有的命令编排器,
// 确保探活包与实际文件操作包走完全一致的 webshell 协议(GET/POST、参数名、编码)。
func (h *WebShellHandler) newHTTPExecFn(targetURL, password, shellType, method, cmdParam, encoding string) func(string) (string, bool) {
useGET := strings.ToUpper(strings.TrimSpace(method)) == "GET"
if strings.TrimSpace(cmdParam) == "" {
cmdParam = "cmd"
}
return func(cmd string) (string, bool) {
var (
httpReq *http.Request
err error
)
if useGET {
u := h.buildExecURL(targetURL, shellType, password, cmdParam, cmd)
httpReq, err = http.NewRequest(http.MethodGet, u, nil)
} else {
body := h.buildExecBody(shellType, password, cmdParam, cmd)
httpReq, err = http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body))
if err == nil {
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
}
if err != nil {
return "", false
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
return "", false
}
defer resp.Body.Close()
raw, _ := io.ReadAll(resp.Body)
return decodeWebshellOutput(raw, encoding), resp.StatusCode == http.StatusOK
}
}
// persistDetectedOS 把探活结果回写到连接表;失败只记日志不阻断主流程。
// 设计上故意只触发 UPDATE,不会新建记录,因此即便 connectionID 不存在也只是悄悄放弃。
func (h *WebShellHandler) persistDetectedOS(connectionID, detected string) {
connectionID = strings.TrimSpace(connectionID)
detected = normalizeWebshellOS(detected)
if connectionID == "" || detected == "" || detected == "auto" {
return
}
conn, err := h.db.GetWebshellConnection(connectionID)
if err != nil || conn == nil {
// 不是所有调用方都能提供有效 ID(比如临时测试),这里静默返回
return
}
if normalizeWebshellOS(conn.OS) != "auto" {
// 用户已经显式选过 OS,尊重用户选择,不自动覆盖
return
}
conn.OS = detected
if err := h.db.UpdateWebshellConnection(conn); err != nil {
h.logger.Warn("webshell 探活结果持久化失败", zap.String("id", connectionID), zap.String("os", detected), zap.Error(err))
return
}
h.logger.Info("webshell auto OS 探活成功并持久化", zap.String("id", connectionID), zap.String("os", detected))
}
+68
View File
@@ -0,0 +1,68 @@
package handler
import "testing"
func TestClassifyWebshellOSProbeOutput(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{"Windows cmd 回显完整", ":OSPROBE_Windows_NT:END\r\n", "windows"},
{"Windows cmd 回显带额外空行", "\r\n:OSPROBE_Windows_NT:END\r\n", "windows"},
{"Windows 次级线索 - ver banner", "Microsoft Windows [版本 10.0.19045]\r\n", "windows"},
{"Linux sh 字面量回显", ":OSPROBE_%OS%:END\n", "linux"},
{"Linux 紧凑输出(无换行)", ":OSPROBE_%OS%:END", "linux"},
{"空输出 - 无法判定", "", ""},
{"被过滤的输出 - 无法判定", "something weird", ""},
{"仅有 OSPROBE 前缀但被截断 - 保守返回空", ":OSPROBE_:END", ""},
}
for _, c := range cases {
if got := classifyWebshellOSProbeOutput(c.in); got != c.want {
t.Errorf("case %q: got %q, want %q", c.name, got, c.want)
}
}
}
func TestProbeWebshellOSViaExec_SendsOneCommandOnly(t *testing.T) {
var calls []string
fn := func(cmd string) (string, bool) {
calls = append(calls, cmd)
return ":OSPROBE_Windows_NT:END", true
}
got := probeWebshellOSViaExec(fn)
if got != "windows" {
t.Fatalf("want windows, got %q", got)
}
if len(calls) != 1 {
t.Fatalf("probe should issue exactly one exec call, got %d: %v", len(calls), calls)
}
if calls[0] != webshellOSProbeCommand {
t.Errorf("probe command mismatch: got %q", calls[0])
}
}
func TestProbeWebshellOSViaExec_NotOkReturnsEmpty(t *testing.T) {
// HTTP 非 200 的场景:execFn 返回 ok=false,探活应放弃
fn := func(cmd string) (string, bool) { return "whatever", false }
if got := probeWebshellOSViaExec(fn); got != "" {
t.Errorf("want empty when exec not ok, got %q", got)
}
}
func TestProbeWebshellOSViaExec_NilSafeguard(t *testing.T) {
if got := probeWebshellOSViaExec(nil); got != "" {
t.Errorf("nil execFn should return empty, got %q", got)
}
}
func TestProbeWebshellOSViaExec_LinuxUname(t *testing.T) {
// 某些 webshell 对 `%OS%` 字面量也会过滤(例如安全规则),
// 但主要路径是"%OS% 字面量被原样回显"。这里覆盖标准 Linux 场景。
fn := func(cmd string) (string, bool) {
return ":OSPROBE_%OS%:END\n", true
}
if got := probeWebshellOSViaExec(fn); got != "linux" {
t.Errorf("Linux case: want linux, got %q", got)
}
}
+293
View File
@@ -0,0 +1,293 @@
package handler
import (
"context"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/robot/ilink"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
const wechatLoginTTL = 5 * time.Minute
// WechatConfigSaver 绑定成功后写入配置并重启机器人连接
type WechatConfigSaver interface {
ApplyWechatRobotBinding(cfg config.RobotWechatConfig) error
}
type wechatLoginSession struct {
QRCode string
QRCodeImgURL string
PendingVerify string
CurrentBaseURL string
StartedAt time.Time
}
// WechatRobotHandler 微信 iLink 机器人(扫码绑定 + 配置)
type WechatRobotHandler struct {
config *config.Config
configSaver WechatConfigSaver
logger *zap.Logger
mu sync.Mutex
logins map[string]*wechatLoginSession
}
// NewWechatRobotHandler 创建微信机器人处理器
func NewWechatRobotHandler(cfg *config.Config, saver WechatConfigSaver, logger *zap.Logger) *WechatRobotHandler {
return &WechatRobotHandler{
config: cfg,
configSaver: saver,
logger: logger,
logins: make(map[string]*wechatLoginSession),
}
}
func (h *WechatRobotHandler) purgeExpiredLogins() {
now := time.Now()
for k, v := range h.logins {
if now.Sub(v.StartedAt) > wechatLoginTTL {
delete(h.logins, k)
}
}
}
func (h *WechatRobotHandler) ilinkClient(baseURL string) *ilink.Client {
ver := h.config.Version
if ver == "" {
ver = "1.0.0"
}
ver = strings.TrimPrefix(strings.TrimSpace(ver), "v")
ver = strings.TrimPrefix(ver, "V")
wc := h.config.Robots.Wechat
return ilink.NewClient(baseURL, wc.BotToken, wc.BotAgent, ilink.BuildClientVersion(ver))
}
// HandleWechatQRCode POST /api/robot/wechat/qrcode — 生成绑定二维码
func (h *WechatRobotHandler) HandleWechatQRCode(c *gin.Context) {
h.mu.Lock()
h.purgeExpiredLogins()
h.mu.Unlock()
var req struct {
BotType string `json:"bot_type"`
}
_ = c.ShouldBindJSON(&req)
botType := req.BotType
if botType == "" {
botType = h.config.Robots.Wechat.BotType
}
if botType == "" {
botType = ilink.DefaultBotType
}
baseURL := h.config.Robots.Wechat.BaseURL
if baseURL == "" {
baseURL = ilink.DefaultBaseURL
}
var localTokens []string
if t := h.config.Robots.Wechat.BotToken; t != "" {
localTokens = []string{t}
}
client := h.ilinkClient(baseURL)
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
defer cancel()
qr, err := client.GetBotQRCode(ctx, botType, localTokens)
if err != nil {
h.logger.Warn("获取微信二维码失败", zap.Error(err))
c.JSON(http.StatusBadGateway, gin.H{"error": "获取二维码失败: " + err.Error()})
return
}
if qr.QRCode == "" || qr.QRCodeImgContent == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "微信服务器未返回有效二维码"})
return
}
sessionKey := uuid.New().String()
h.mu.Lock()
h.logins[sessionKey] = &wechatLoginSession{
QRCode: qr.QRCode,
QRCodeImgURL: qr.QRCodeImgContent,
CurrentBaseURL: baseURL,
StartedAt: time.Now(),
}
h.mu.Unlock()
resp := gin.H{
"session_key": sessionKey,
"qrcode": qr.QRCode,
"qrcode_open_url": qr.QRCodeImgContent,
"message": "请使用微信扫描二维码并确认绑定",
}
if dataURL, err := ilink.QRCodeDataURL(qr.QRCodeImgContent, 256); err != nil {
h.logger.Warn("生成二维码图片失败", zap.Error(err))
} else {
resp["qrcode_image_data_url"] = dataURL
}
c.JSON(http.StatusOK, resp)
}
// HandleWechatQRCodeStatus GET /api/robot/wechat/qrcode/status — 轮询扫码状态
func (h *WechatRobotHandler) HandleWechatQRCodeStatus(c *gin.Context) {
sessionKey := c.Query("session_key")
verifyCode := c.Query("verify_code")
if sessionKey == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 session_key"})
return
}
h.mu.Lock()
sess, ok := h.logins[sessionKey]
h.mu.Unlock()
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期,请重新生成二维码"})
return
}
if time.Since(sess.StartedAt) > wechatLoginTTL {
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusGone, gin.H{"error": "二维码已过期,请重新生成"})
return
}
baseURL := sess.CurrentBaseURL
if baseURL == "" {
baseURL = ilink.DefaultBaseURL
}
vc := verifyCode
if vc == "" {
vc = sess.PendingVerify
}
client := h.ilinkClient(baseURL)
ctx, cancel := context.WithTimeout(c.Request.Context(), 40*time.Second)
defer cancel()
st, err := client.GetQRCodeStatus(ctx, sess.QRCode, vc)
if err != nil {
h.logger.Warn("轮询微信二维码状态失败", zap.Error(err))
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
switch st.Status {
case "wait", "scaned":
c.JSON(http.StatusOK, gin.H{"status": st.Status})
return
case "need_verifycode":
c.JSON(http.StatusOK, gin.H{
"status": st.Status,
"message": "请在手机微信查看配对数字,并在下方输入",
})
return
case "scaned_but_redirect":
if st.RedirectHost != "" {
h.mu.Lock()
if s, ok := h.logins[sessionKey]; ok {
s.CurrentBaseURL = "https://" + st.RedirectHost
}
h.mu.Unlock()
}
c.JSON(http.StatusOK, gin.H{"status": st.Status})
return
case "binded_redirect":
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusOK, gin.H{
"status": st.Status,
"already_connected": true,
"message": "该微信已绑定过,无需重复绑定",
})
return
case "confirmed":
if st.BotToken == "" || st.ILinkBotID == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "绑定确认成功但缺少 bot_token"})
return
}
saveBase := st.BaseURL
if saveBase == "" {
saveBase = baseURL
}
wc := h.config.Robots.Wechat
wc.Enabled = true
wc.BotToken = st.BotToken
wc.ILinkBotID = st.ILinkBotID
wc.ILinkUserID = st.ILinkUserID
wc.BaseURL = saveBase
if wc.BotType == "" {
wc.BotType = ilink.DefaultBotType
}
if wc.BotAgent == "" {
wc.BotAgent = ilink.DefaultBotAgent
}
if h.configSaver != nil {
if err := h.configSaver.ApplyWechatRobotBinding(wc); err != nil {
h.logger.Warn("保存微信机器人配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
} else {
h.config.Robots.Wechat = wc
}
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusOK, gin.H{
"status": "confirmed",
"message": "绑定成功,微信机器人已启用",
"ilink_bot_id": st.ILinkBotID,
"ilink_user_id": st.ILinkUserID,
})
return
default:
c.JSON(http.StatusOK, gin.H{"status": st.Status})
}
}
// HandleWechatVerifyCode POST /api/robot/wechat/qrcode/verify — 提交手机配对数字
func (h *WechatRobotHandler) HandleWechatVerifyCode(c *gin.Context) {
var req struct {
SessionKey string `json:"session_key"`
VerifyCode string `json:"verify_code"`
}
if err := c.ShouldBindJSON(&req); err != nil || req.SessionKey == "" || req.VerifyCode == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 session_key 与 verify_code"})
return
}
h.mu.Lock()
sess, ok := h.logins[req.SessionKey]
if ok {
sess.PendingVerify = req.VerifyCode
}
h.mu.Unlock()
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "已提交配对码,请继续等待绑定"})
}
// HandleWechatStatus GET /api/robot/wechat/status — 当前绑定状态(供前端展示)
func (h *WechatRobotHandler) HandleWechatStatus(c *gin.Context) {
wc := h.config.Robots.Wechat
bound := wc.BotToken != "" && wc.ILinkBotID != ""
c.JSON(http.StatusOK, gin.H{
"enabled": wc.Enabled,
"bound": bound,
"ilink_bot_id": wc.ILinkBotID,
"ilink_user_id": wc.ILinkUserID,
"base_url": wc.BaseURL,
})
}
+67
View File
@@ -0,0 +1,67 @@
package knowledge
import (
"context"
"fmt"
"strings"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
"github.com/cloudwego/eino/components/document"
"github.com/pkoukk/tiktoken-go"
)
func tokenizerLenFunc(embeddingModel string) func(string) int {
fallback := func(s string) int {
r := []rune(s)
if len(r) == 0 {
return 0
}
return (len(r) + 3) / 4
}
m := strings.TrimSpace(embeddingModel)
if m == "" {
return fallback
}
tok, err := tiktoken.EncodingForModel(m)
if err != nil {
return fallback
}
return func(s string) int {
return len(tok.Encode(s, nil, nil))
}
}
// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for
// embeddingModel when available, else rune/4 approximation.
func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) {
if chunkSize <= 0 {
return nil, fmt.Errorf("chunk size must be positive")
}
if overlap < 0 {
overlap = 0
}
return recursive.NewSplitter(context.Background(), &recursive.Config{
ChunkSize: chunkSize,
OverlapSize: overlap,
LenFunc: tokenizerLenFunc(embeddingModel),
Separators: []string{
"\n\n", "\n## ", "\n### ", "\n#### ", "\n",
"。", "", "", ". ", "? ", "! ",
" ",
},
})
}
// newMarkdownHeaderSplitter Eino-ext Markdown 按标题切分(#####),适合技术/Markdown 知识库。
func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) {
return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{
Headers: map[string]string{
"#": "h1",
"##": "h2",
"###": "h3",
"####": "h4",
},
TrimHeaders: false,
})
}
+129
View File
@@ -0,0 +1,129 @@
package knowledge
import (
"fmt"
"strings"
)
// Document metadata keys for Eino schema.Document flowing through the RAG pipeline.
const (
metaKBCategory = "kb_category"
metaKBTitle = "kb_title"
metaKBItemID = "kb_item_id"
metaKBChunkIndex = "kb_chunk_index"
metaSimilarity = "similarity"
)
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
const (
DSLRiskType = "risk_type"
DSLSimilarityThreshold = "similarity_threshold"
DSLSubIndexFilter = "sub_index_filter"
)
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
// stay comparable if users skip reindex; new indexes use the same string shape.
func FormatEmbeddingInput(category, title, chunkText string) string {
return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText)
}
// FormatQueryEmbeddingText builds the string embedded at query time so it matches
// [FormatEmbeddingInput] for the same risk category (title left empty for queries).
func FormatQueryEmbeddingText(riskType, query string) string {
q := strings.TrimSpace(query)
rt := strings.TrimSpace(riskType)
if rt != "" {
return FormatEmbeddingInput(rt, "", q)
}
return q
}
// MetaLookupString returns metadata string value or "" if absent.
func MetaLookupString(md map[string]any, key string) string {
if md == nil {
return ""
}
v, ok := md[key]
if !ok || v == nil {
return ""
}
switch t := v.(type) {
case string:
return t
default:
return strings.TrimSpace(fmt.Sprint(t))
}
}
// MetaStringOK returns trimmed non-empty string and true if present and non-empty.
func MetaStringOK(md map[string]any, key string) (string, bool) {
s := strings.TrimSpace(MetaLookupString(md, key))
if s == "" {
return "", false
}
return s, true
}
// RequireMetaString requires a non-empty string metadata field.
func RequireMetaString(md map[string]any, key string) (string, error) {
s, ok := MetaStringOK(md, key)
if !ok {
return "", fmt.Errorf("missing or empty metadata %q", key)
}
return s, nil
}
// RequireMetaInt requires an integer metadata field.
func RequireMetaInt(md map[string]any, key string) (int, error) {
if md == nil {
return 0, fmt.Errorf("missing metadata key %q", key)
}
v, ok := md[key]
if !ok {
return 0, fmt.Errorf("missing metadata key %q", key)
}
switch t := v.(type) {
case int:
return t, nil
case int32:
return int(t), nil
case int64:
return int(t), nil
case float64:
return int(t), nil
default:
return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v)
}
}
// DSLNumeric coerces DSL map values (e.g. from JSON) to float64.
func DSLNumeric(v any) (float64, bool) {
switch t := v.(type) {
case float64:
return t, true
case float32:
return float64(t), true
case int:
return float64(t), true
case int64:
return float64(t), true
case uint32:
return float64(t), true
case uint64:
return float64(t), true
default:
return 0, false
}
}
// MetaFloat64OK reads a float metadata value.
func MetaFloat64OK(md map[string]any, key string) (float64, bool) {
if md == nil {
return 0, false
}
v, ok := md[key]
if !ok {
return 0, false
}
return DSLNumeric(v)
}
+14
View File
@@ -0,0 +1,14 @@
package knowledge
import "testing"
func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) {
q := FormatQueryEmbeddingText("XSS", "payload")
want := FormatEmbeddingInput("XSS", "", "payload")
if q != want {
t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want)
}
if FormatQueryEmbeddingText("", "hello") != "hello" {
t.Fatalf("expected bare query without risk type")
}
}
+25
View File
@@ -0,0 +1,25 @@
package knowledge
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。
// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。
func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) {
if r == nil {
return nil, fmt.Errorf("retriever is nil")
}
ch := compose.NewChain[string, []*schema.Document]()
ch.AppendRetriever(r.AsEinoRetriever())
return ch.Compile(ctx)
}
// CompileRetrieveChain 等价于 [BuildKnowledgeRetrieveChain](ctx, r)。
func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) {
return BuildKnowledgeRetrieveChain(ctx, r)
}
@@ -0,0 +1,23 @@
package knowledge
import (
"context"
"testing"
"go.uber.org/zap"
)
func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) {
r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop())
_, err := BuildKnowledgeRetrieveChain(context.Background(), r)
if err != nil {
t.Fatal(err)
}
}
func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) {
_, err := BuildKnowledgeRetrieveChain(context.Background(), nil)
if err == nil {
t.Fatal("expected error for nil retriever")
}
}
@@ -0,0 +1,202 @@
package knowledge
import (
"context"
"fmt"
"strings"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity.
//
// Options:
// - [retriever.WithTopK]
// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 01), [DSLSubIndexFilter] (string)
//
// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric.
//
// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then
// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig].
type VectorEinoRetriever struct {
inner *Retriever
}
// NewVectorEinoRetriever wraps r for Eino compose / tooling.
func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever {
if r == nil {
return nil
}
return &VectorEinoRetriever{inner: r}
}
// GetType identifies this retriever for Eino callbacks.
func (h *VectorEinoRetriever) GetType() string {
return "SQLiteVectorKnowledgeRetriever"
}
// Retrieve runs vector search and returns [schema.Document] rows.
func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) {
if h == nil || h.inner == nil {
return nil, fmt.Errorf("VectorEinoRetriever: nil retriever")
}
q := strings.TrimSpace(query)
if q == "" {
return nil, fmt.Errorf("查询不能为空")
}
ro := retriever.GetCommonOptions(nil, opts...)
cfg := h.inner.config
req := &SearchRequest{Query: q}
if ro.TopK != nil && *ro.TopK > 0 {
req.TopK = *ro.TopK
} else if cfg != nil && cfg.TopK > 0 {
req.TopK = cfg.TopK
} else {
req.TopK = 5
}
req.Threshold = 0
if ro.DSLInfo != nil {
if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok {
req.RiskType = strings.TrimSpace(rt)
}
if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok {
if f, ok2 := DSLNumeric(v); ok2 && f > 0 {
req.Threshold = f
}
}
if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok {
req.SubIndexFilter = strings.TrimSpace(sf)
}
}
if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" {
req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter)
}
if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 {
req.Threshold = cfg.SimilarityThreshold
}
if req.Threshold <= 0 {
req.Threshold = 0.7
}
finalTopK := req.TopK
var postPO *config.PostRetrieveConfig
if cfg != nil {
postPO = &cfg.PostRetrieve
}
fetchK := EffectivePrefetchTopK(finalTopK, postPO)
searchReq := *req
searchReq.TopK = fetchK
ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever)
th := req.Threshold
st := &th
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
Query: q,
TopK: finalTopK,
ScoreThreshold: st,
Extra: ro.DSLInfo,
})
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
return
}
_ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out})
}()
results, err := h.inner.vectorSearch(ctx, &searchReq)
if err != nil {
return nil, err
}
out = retrievalResultsToDocuments(results)
if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 {
reranked, rerr := rr.Rerank(ctx, q, out)
if rerr != nil {
if h.inner.logger != nil {
h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr))
}
} else if len(reranked) > 0 {
out = reranked
}
}
tokenModel := ""
if h.inner.embedder != nil {
tokenModel = h.inner.embedder.EmbeddingModelName()
}
out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
if err != nil {
return nil, err
}
return out, nil
}
func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document {
out := make([]*schema.Document, 0, len(results))
for _, res := range results {
if res == nil || res.Chunk == nil || res.Item == nil {
continue
}
d := &schema.Document{
ID: res.Chunk.ID,
Content: res.Chunk.ChunkText,
MetaData: map[string]any{
metaKBItemID: res.Item.ID,
metaKBCategory: res.Item.Category,
metaKBTitle: res.Item.Title,
metaKBChunkIndex: res.Chunk.ChunkIndex,
metaSimilarity: res.Similarity,
},
}
d.WithScore(res.Score)
out = append(out, d)
}
return out
}
func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) {
out := make([]*RetrievalResult, 0, len(docs))
for i, d := range docs {
if d == nil {
continue
}
itemID, err := RequireMetaString(d.MetaData, metaKBItemID)
if err != nil {
return nil, fmt.Errorf("document %d: %w", i, err)
}
cat := MetaLookupString(d.MetaData, metaKBCategory)
title := MetaLookupString(d.MetaData, metaKBTitle)
chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex)
if err != nil {
return nil, fmt.Errorf("document %d: %w", i, err)
}
sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity)
item := &KnowledgeItem{ID: itemID, Category: cat, Title: title}
chunk := &KnowledgeChunk{
ID: d.ID,
ItemID: itemID,
ChunkIndex: chunkIdx,
ChunkText: d.Content,
}
out = append(out, &RetrievalResult{
Chunk: chunk,
Item: item,
Similarity: sim,
Score: d.Score(),
})
}
return out, nil
}
var _ retriever.Retriever = (*VectorEinoRetriever)(nil)
+142
View File
@@ -0,0 +1,142 @@
package knowledge
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
)
// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema.
type SQLiteIndexer struct {
db *sql.DB
batchSize int
embeddingModel string
}
// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call.
// batchSize is the embedding batch size; if <= 0, default 64 is used.
// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty).
func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer {
return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)}
}
// GetType implements eino callback run info.
func (s *SQLiteIndexer) GetType() string {
return "SQLiteKnowledgeIndexer"
}
// Store embeds documents and inserts rows. Each doc must carry MetaData:
// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only.
func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
options := indexer.GetCommonOptions(nil, opts...)
if options.Embedding == nil {
return nil, fmt.Errorf("sqlite indexer: embedding is required")
}
if len(docs) == 0 {
return nil, nil
}
ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer)
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
return
}
_ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids})
}()
subIdxStr := strings.Join(options.SubIndexes, ",")
texts := make([]string, len(docs))
for i, d := range docs {
if d == nil {
return nil, fmt.Errorf("sqlite indexer: nil document at %d", i)
}
cat := MetaLookupString(d.MetaData, metaKBCategory)
title := MetaLookupString(d.MetaData, metaKBTitle)
texts[i] = FormatEmbeddingInput(cat, title, d.Content)
}
bs := s.batchSize
if bs <= 0 {
bs = 64
}
var allVecs [][]float64
for start := 0; start < len(texts); start += bs {
end := start + bs
if end > len(texts) {
end = len(texts)
}
batch := texts[start:end]
vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch)
if embedErr != nil {
return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr)
}
if len(vecs) != len(batch) {
return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch))
}
allVecs = append(allVecs, vecs...)
}
embedDim := 0
if len(allVecs) > 0 {
embedDim = len(allVecs[0])
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err)
}
defer tx.Rollback()
ids = make([]string, 0, len(docs))
for i, d := range docs {
chunkID := uuid.New().String()
itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID)
if metaErr != nil {
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
}
chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex)
if metaErr != nil {
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
}
vec := allVecs[i]
if embedDim > 0 && len(vec) != embedDim {
return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim)
}
vec32 := make([]float32, len(vec))
for j, v := range vec {
vec32[j] = float32(v)
}
embeddingJSON, jsonErr := json.Marshal(vec32)
if jsonErr != nil {
return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr)
}
_, err = tx.ExecContext(ctx,
`INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`,
chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim,
)
if err != nil {
return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err)
}
ids = append(ids, chunkID)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite indexer: commit: %w", err)
}
return ids, nil
}
var _ indexer.Indexer = (*SQLiteIndexer)(nil)
+251
View File
@@ -0,0 +1,251 @@
package knowledge
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/embedding"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
// Embedder 使用 CloudWeGo Eino 的 OpenAI Embedding 组件,并保留速率限制与重试。
type Embedder struct {
eino embedding.Embedder
config *config.KnowledgeConfig
logger *zap.Logger
rateLimiter *rate.Limiter
rateLimitDelay time.Duration
maxRetries int
retryDelay time.Duration
mu sync.Mutex
}
// NewEmbedder 基于 Eino eino-ext OpenAI EmbedderopenAIConfig 用于在知识库未单独配置 key 时回退 API Key。
func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) {
if cfg == nil {
return nil, fmt.Errorf("knowledge config is nil")
}
var rateLimiter *rate.Limiter
var rateLimitDelay time.Duration
if cfg.Indexing.MaxRPM > 0 {
rpm := cfg.Indexing.MaxRPM
rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
if logger != nil {
logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
}
} else if cfg.Indexing.RateLimitDelayMs > 0 {
rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
if logger != nil {
logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
}
}
maxRetries := 3
retryDelay := 1000 * time.Millisecond
if cfg.Indexing.MaxRetries > 0 {
maxRetries = cfg.Indexing.MaxRetries
}
if cfg.Indexing.RetryDelayMs > 0 {
retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
}
model := strings.TrimSpace(cfg.Embedding.Model)
if model == "" {
model = "text-embedding-3-small"
}
baseURL := strings.TrimSpace(cfg.Embedding.BaseURL)
baseURL = strings.TrimSuffix(baseURL, "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
apiKey := strings.TrimSpace(cfg.Embedding.APIKey)
if apiKey == "" && openAIConfig != nil {
apiKey = strings.TrimSpace(openAIConfig.APIKey)
}
if apiKey == "" {
return nil, fmt.Errorf("embedding API key 未配置")
}
timeout := 120 * time.Second
if cfg.Indexing.RequestTimeoutSeconds > 0 {
timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second
}
httpClient := &http.Client{Timeout: timeout}
inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{
APIKey: apiKey,
BaseURL: baseURL,
ByAzure: false,
Model: model,
HTTPClient: httpClient,
})
if err != nil {
return nil, fmt.Errorf("eino OpenAI embedder: %w", err)
}
return &Embedder{
eino: inner,
config: cfg,
logger: logger,
rateLimiter: rateLimiter,
rateLimitDelay: rateLimitDelay,
maxRetries: maxRetries,
retryDelay: retryDelay,
}, nil
}
// EmbeddingModelName 返回配置的嵌入模型名(用于 tiktoken 分块与向量行元数据)。
func (e *Embedder) EmbeddingModelName() string {
if e == nil || e.config == nil {
return ""
}
s := strings.TrimSpace(e.config.Embedding.Model)
if s != "" {
return s
}
return "text-embedding-3-small"
}
func (e *Embedder) waitRateLimiter() {
e.mu.Lock()
defer e.mu.Unlock()
if e.rateLimiter != nil {
ctx := context.Background()
if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil {
e.logger.Warn("速率限制器等待失败", zap.Error(err))
}
}
if e.rateLimitDelay > 0 {
time.Sleep(e.rateLimitDelay)
}
}
// EmbedText 单条嵌入(float32,与历史存储格式一致)。
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
vecs, err := e.EmbedStrings(ctx, []string{text})
if err != nil {
return nil, err
}
if len(vecs) != 1 {
return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs))
}
return vecs[0], nil
}
// EmbedStrings 批量嵌入,带重试;实现 [embedding.Embedder],可供 Eino Indexer 使用。
func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) {
if e == nil || e.eino == nil {
return nil, fmt.Errorf("embedder not initialized")
}
if len(texts) == 0 {
return nil, nil
}
var lastErr error
for attempt := 0; attempt < e.maxRetries; attempt++ {
if attempt > 0 {
wait := e.retryDelay * time.Duration(attempt)
if e.logger != nil {
e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(wait):
}
} else {
e.waitRateLimiter()
}
raw, err := e.eino.EmbedStrings(ctx, texts, opts...)
if err == nil {
out := make([][]float32, len(raw))
for i, row := range raw {
out[i] = make([]float32, len(row))
for j, v := range row {
out[i][j] = float32(v)
}
}
return out, nil
}
lastErr = err
if !e.isRetryableError(err) {
return nil, err
}
if e.logger != nil {
e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err))
}
}
return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
}
// EmbedTexts 批量 float32 嵌入(兼容旧调用;单次请求批量以减小延迟)。
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
return e.EmbedStrings(ctx, texts)
}
func (e *Embedder) isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
return true
}
if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
return true
}
if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
return true
}
return false
}
// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store.
type einoFloatEmbedder struct {
inner *Embedder
}
func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
vec32, err := w.inner.EmbedStrings(ctx, texts, opts...)
if err != nil {
return nil, err
}
out := make([][]float64, len(vec32))
for i, row := range vec32 {
out[i] = make([]float64, len(row))
for j, v := range row {
out[i][j] = float64(v)
}
}
return out, nil
}
func (w *einoFloatEmbedder) GetType() string {
return "CyberStrikeKnowledgeEmbedder"
}
func (w *einoFloatEmbedder) IsCallbacksEnabled() bool {
return false
}
// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path
// and produces float64 vectors expected by generic Eino indexer helpers.
func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder {
return &einoFloatEmbedder{inner: e}
}
+91
View File
@@ -0,0 +1,91 @@
package knowledge
import (
"context"
"database/sql"
"fmt"
"strings"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive".
func normalizeChunkStrategy(s string) string {
v := strings.TrimSpace(strings.ToLower(s))
switch v {
case "recursive":
return "recursive"
case "markdown_then_recursive", "markdown_recursive", "markdown":
return "markdown_then_recursive"
case "":
return "markdown_then_recursive"
default:
return "markdown_then_recursive"
}
}
func buildKnowledgeIndexChain(
ctx context.Context,
indexingCfg *config.IndexingConfig,
db *sql.DB,
recursive document.Transformer,
embeddingModel string,
) (compose.Runnable[[]*schema.Document, []string], error) {
if recursive == nil {
return nil, fmt.Errorf("recursive transformer is nil")
}
if db == nil {
return nil, fmt.Errorf("db is nil")
}
strategy := normalizeChunkStrategy("markdown_then_recursive")
batch := 64
maxChunks := 0
if indexingCfg != nil {
strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy)
if indexingCfg.BatchSize > 0 {
batch = indexingCfg.BatchSize
}
maxChunks = indexingCfg.MaxChunksPerItem
}
si := NewSQLiteIndexer(db, batch, embeddingModel)
ch := compose.NewChain[[]*schema.Document, []string]()
if strategy != "recursive" {
md, err := newMarkdownHeaderSplitter(ctx)
if err != nil {
return nil, fmt.Errorf("markdown splitter: %w", err)
}
ch.AppendDocumentTransformer(md)
}
ch.AppendDocumentTransformer(recursive)
ch.AppendLambda(newChunkEnrichLambda(maxChunks))
ch.AppendIndexer(si)
return ch.Compile(ctx)
}
func newChunkEnrichLambda(maxChunks int) *compose.Lambda {
return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) {
_ = ctx
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil || strings.TrimSpace(d.Content) == "" {
continue
}
out = append(out, d)
}
if maxChunks > 0 && len(out) > maxChunks {
out = out[:maxChunks]
}
for i, d := range out {
if d.MetaData == nil {
d.MetaData = make(map[string]any)
}
d.MetaData[metaKBChunkIndex] = i
}
return out, nil
})
}
+21
View File
@@ -0,0 +1,21 @@
package knowledge
import "testing"
func TestNormalizeChunkStrategy(t *testing.T) {
cases := []struct {
in, want string
}{
{"", "markdown_then_recursive"},
{"recursive", "recursive"},
{"RECURSIVE", "recursive"},
{"markdown_then_recursive", "markdown_then_recursive"},
{"markdown", "markdown_then_recursive"},
{"unknown", "markdown_then_recursive"},
}
for _, tc := range cases {
if got := normalizeChunkStrategy(tc.in); got != tc.want {
t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want)
}
}
}
+352
View File
@@ -0,0 +1,352 @@
package knowledge
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// Indexer 使用 Eino Compose 索引链(Markdown/递归分块、Lambda enrich、SQLite 索引)与嵌入写入。
type Indexer struct {
db *sql.DB
embedder *Embedder
logger *zap.Logger
chunkSize int
overlap int
indexingCfg *config.IndexingConfig
indexChain compose.Runnable[[]*schema.Document, []string]
fileLoader *fileloader.FileLoader
mu sync.RWMutex
lastError string
lastErrorTime time.Time
errorCount int
rebuildMu sync.RWMutex
isRebuilding bool
rebuildTotalItems int
rebuildCurrent int
rebuildFailed int
rebuildStartTime time.Time
rebuildLastItemID string
rebuildLastChunks int
}
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) {
if db == nil {
return nil, fmt.Errorf("db is nil")
}
if embedder == nil {
return nil, fmt.Errorf("embedder is nil")
}
if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil {
return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err)
}
if kcfg == nil {
kcfg = &config.KnowledgeConfig{}
}
indexingCfg := &kcfg.Indexing
chunkSize := 512
overlap := 50
if indexingCfg.ChunkSize > 0 {
chunkSize = indexingCfg.ChunkSize
}
if indexingCfg.ChunkOverlap >= 0 {
overlap = indexingCfg.ChunkOverlap
}
embedModel := embedder.EmbeddingModelName()
splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel)
if err != nil {
return nil, fmt.Errorf("eino recursive splitter: %w", err)
}
chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel)
if err != nil {
return nil, fmt.Errorf("knowledge index chain: %w", err)
}
var fl *fileloader.FileLoader
fl, err = fileloader.NewFileLoader(ctx, nil)
if err != nil {
if logger != nil {
logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err))
}
fl = nil
err = nil
}
return &Indexer{
db: db,
embedder: embedder,
logger: logger,
chunkSize: chunkSize,
overlap: overlap,
indexingCfg: indexingCfg,
indexChain: chain,
fileLoader: fl,
}, nil
}
// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。
func (idx *Indexer) RecompileIndexChain(ctx context.Context) error {
if idx == nil || idx.db == nil || idx.embedder == nil {
return fmt.Errorf("indexer 未初始化")
}
if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil {
return err
}
embedModel := idx.embedder.EmbeddingModelName()
splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel)
if err != nil {
return fmt.Errorf("eino recursive splitter: %w", err)
}
chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel)
if err != nil {
return fmt.Errorf("knowledge index chain: %w", err)
}
idx.indexChain = chain
return nil
}
// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
if idx.indexChain == nil {
return fmt.Errorf("索引链未初始化")
}
if idx.embedder == nil {
return fmt.Errorf("嵌入器未初始化")
}
var content, category, title, filePath string
err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath)
if err != nil {
return fmt.Errorf("获取知识项失败:%w", err)
}
if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil {
return fmt.Errorf("删除旧向量失败:%w", err)
}
body := strings.TrimSpace(content)
if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil {
docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)})
if lerr == nil && len(docs) > 0 {
var b strings.Builder
for i, d := range docs {
if d == nil {
continue
}
if i > 0 {
b.WriteString("\n\n")
}
b.WriteString(d.Content)
}
if s := strings.TrimSpace(b.String()); s != "" {
body = s
}
} else if idx.logger != nil {
idx.logger.Warn("优先源文件读取失败,使用数据库正文",
zap.String("itemId", itemID),
zap.String("path", filePath),
zap.Error(lerr))
}
}
root := &schema.Document{
ID: itemID,
Content: body,
MetaData: map[string]any{
metaKBCategory: category,
metaKBTitle: title,
metaKBItemID: itemID,
},
}
idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())}
if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 {
idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes))
}
ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...))
if err != nil {
msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err)
idx.mu.Lock()
idx.lastError = msg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
return err
}
if idx.logger != nil {
idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids)))
}
idx.rebuildMu.Lock()
idx.rebuildLastItemID = itemID
idx.rebuildLastChunks = len(ids)
idx.rebuildMu.Unlock()
return nil
}
// HasIndex 检查是否存在索引
func (idx *Indexer) HasIndex() (bool, error) {
var count int
err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count)
if err != nil {
return false, fmt.Errorf("检查索引失败:%w", err)
}
return count > 0, nil
}
// RebuildIndex 重建所有索引
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
idx.rebuildMu.Lock()
idx.isRebuilding = true
idx.rebuildTotalItems = 0
idx.rebuildCurrent = 0
idx.rebuildFailed = 0
idx.rebuildStartTime = time.Now()
idx.rebuildLastItemID = ""
idx.rebuildLastChunks = 0
idx.rebuildMu.Unlock()
idx.mu.Lock()
idx.lastError = ""
idx.lastErrorTime = time.Time{}
idx.errorCount = 0
idx.mu.Unlock()
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
if err != nil {
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("查询知识项失败:%w", err)
}
defer rows.Close()
var itemIDs []string
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("扫描知识项 ID 失败:%w", err)
}
itemIDs = append(itemIDs, id)
}
idx.rebuildMu.Lock()
idx.rebuildTotalItems = len(itemIDs)
idx.rebuildMu.Unlock()
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
failedCount := 0
consecutiveFailures := 0
maxConsecutiveFailures := 5
firstFailureItemID := ""
var firstFailureError error
for i, itemID := range itemIDs {
if err := idx.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
idx.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemIDs)),
zap.Error(err),
)
}
if consecutiveFailures >= maxConsecutiveFailures {
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("连续索引失败次数过多,立即停止索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemIDs)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError)
}
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
zap.Int("failedCount", failedCount),
zap.Int("totalItems", len(itemIDs)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
}
continue
}
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
idx.rebuildMu.Lock()
idx.rebuildCurrent = i + 1
idx.rebuildFailed = failedCount
idx.rebuildMu.Unlock()
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
}
}
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
return nil
}
// GetLastError 获取最近一次错误信息
func (idx *Indexer) GetLastError() (string, time.Time) {
idx.mu.RLock()
defer idx.mu.RUnlock()
return idx.lastError, idx.lastErrorTime
}
// GetRebuildStatus 获取重建索引状态
func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) {
idx.rebuildMu.RLock()
defer idx.rebuildMu.RUnlock()
return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime
}
+885
View File
@@ -0,0 +1,885 @@
package knowledge
import (
"database/sql"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Manager 知识库管理器
type Manager struct {
db *sql.DB
basePath string
logger *zap.Logger
}
// NewManager 创建新的知识库管理器
func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
return &Manager{
db: db,
basePath: basePath,
logger: logger,
}
}
// ScanKnowledgeBase 扫描知识库目录,更新数据库
// 返回需要索引的知识项ID列表(新添加的或更新的)
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
if m.basePath == "" {
return nil, fmt.Errorf("知识库路径未配置")
}
// 确保目录存在
if err := os.MkdirAll(m.basePath, 0755); err != nil {
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
}
var itemsToIndex []string
// 遍历知识库目录
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
// 跳过目录和非markdown文件
if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") {
return nil
}
// 计算相对路径和分类
relPath, err := filepath.Rel(m.basePath, path)
if err != nil {
return err
}
// 第一个目录名作为分类(风险类型)
parts := strings.Split(relPath, string(filepath.Separator))
category := "未分类"
if len(parts) > 1 {
category = parts[0]
}
// 文件名为标题
title := strings.TrimSuffix(filepath.Base(path), ".md")
// 读取文件内容
content, err := os.ReadFile(path)
if err != nil {
m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err))
return nil // 继续处理其他文件
}
// 检查是否已存在
var existingID string
var existingContent string
var existingUpdatedAt time.Time
err = m.db.QueryRow(
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
path,
).Scan(&existingID, &existingContent, &existingUpdatedAt)
if err == sql.ErrNoRows {
// 创建新项
id := uuid.New().String()
now := time.Now()
_, err = m.db.Exec(
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, category, title, path, string(content), now, now,
)
if err != nil {
return fmt.Errorf("插入知识项失败: %w", err)
}
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
// 新添加的项需要索引
itemsToIndex = append(itemsToIndex, id)
} else if err == nil {
// 检查内容是否有变化
contentChanged := existingContent != string(content)
if contentChanged {
// 更新现有项
_, err = m.db.Exec(
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
category, title, string(content), time.Now(), existingID,
)
if err != nil {
return fmt.Errorf("更新知识项失败: %w", err)
}
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
// 内容已更新的项需要重新索引
itemsToIndex = append(itemsToIndex, existingID)
} else {
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
}
} else {
return fmt.Errorf("查询知识项失败: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
return itemsToIndex, nil
}
// GetCategories 获取所有分类(风险类型)
func (m *Manager) GetCategories() ([]string, error) {
rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category")
if err != nil {
return nil, fmt.Errorf("查询分类失败: %w", err)
}
defer rows.Close()
var categories []string
for rows.Next() {
var category string
if err := rows.Scan(&category); err != nil {
return nil, fmt.Errorf("扫描分类失败: %w", err)
}
categories = append(categories, category)
}
return categories, nil
}
// GetStats 获取知识库统计信息
func (m *Manager) GetStats() (int, int, error) {
// 获取分类总数
categories, err := m.GetCategories()
if err != nil {
return 0, 0, fmt.Errorf("获取分类失败: %w", err)
}
totalCategories := len(categories)
// 获取知识项总数
var totalItems int
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
if err != nil {
return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err)
}
return totalCategories, totalItems, nil
}
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
// limit: 每页分类数量(0表示不限制)
// offset: 偏移量(按分类偏移)
func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) {
// 首先获取所有分类(带数量统计)
rows, err := m.db.Query(`
SELECT category, COUNT(*) as item_count
FROM knowledge_base_items
GROUP BY category
ORDER BY category
`)
if err != nil {
return nil, 0, fmt.Errorf("查询分类失败: %w", err)
}
defer rows.Close()
// 收集所有分类信息
type categoryInfo struct {
name string
itemCount int
}
var allCategories []categoryInfo
for rows.Next() {
var info categoryInfo
if err := rows.Scan(&info.name, &info.itemCount); err != nil {
return nil, 0, fmt.Errorf("扫描分类失败: %w", err)
}
allCategories = append(allCategories, info)
}
totalCategories := len(allCategories)
// 应用分页(按分类分页)
var paginatedCategories []categoryInfo
if limit > 0 {
start := offset
end := offset + limit
if start >= totalCategories {
paginatedCategories = []categoryInfo{}
} else {
if end > totalCategories {
end = totalCategories
}
paginatedCategories = allCategories[start:end]
}
} else {
paginatedCategories = allCategories
}
// 为每个分类获取其下的知识项(只返回摘要,不包含完整内容)
result := make([]*CategoryWithItems, 0, len(paginatedCategories))
for _, catInfo := range paginatedCategories {
// 获取该分类下的所有知识项
items, _, err := m.GetItemsSummary(catInfo.name, 0, 0)
if err != nil {
return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err)
}
result = append(result, &CategoryWithItems{
Category: catInfo.name,
ItemCount: catInfo.itemCount,
Items: items,
})
}
return result, totalCategories, nil
}
// GetItems 获取知识项列表(完整内容,用于向后兼容)
func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
return m.GetItemsWithOptions(category, 0, 0, true)
}
// GetItemsWithOptions 获取知识项列表(支持分页和可选内容)
// category: 分类筛选(空字符串表示所有分类)
// limit: 每页数量(0表示不限制)
// offset: 偏移量
// includeContent: 是否包含完整内容(false时只返回摘要)
func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) {
var rows *sql.Rows
var err error
// 构建SQL查询
var query string
var args []interface{}
if includeContent {
query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items"
} else {
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
}
if category != "" {
query += " WHERE category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
if limit > 0 {
query += " LIMIT ?"
args = append(args, limit)
if offset > 0 {
query += " OFFSET ?"
args = append(args, offset)
}
}
rows, err = m.db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItem
for rows.Next() {
item := &KnowledgeItem{}
var createdAt, updatedAt string
if includeContent {
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
} else {
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
// 不包含内容时,Content为空字符串
item.Content = ""
}
// 解析时间 - 支持多种格式
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
// 解析创建时间
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
// 解析更新时间
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
// 如果更新时间为空,使用创建时间
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, nil
}
// GetItemsCount 获取知识项总数
func (m *Manager) GetItemsCount(category string) (int, error) {
var count int
var err error
if category != "" {
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count)
} else {
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count)
}
if err != nil {
return 0, fmt.Errorf("查询知识项总数失败: %w", err)
}
return count, nil
}
// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配)
func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) {
if keyword == "" {
return nil, fmt.Errorf("搜索关键字不能为空")
}
// 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写)
var query string
var args []interface{}
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
// 使用%keyword%进行模糊匹配
searchPattern := "%" + keyword + "%"
query = `
SELECT id, category, title, file_path, created_at, updated_at
FROM knowledge_base_items
WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?))
`
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern)
// 如果指定了分类,添加分类过滤
if category != "" {
query += " AND category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
rows, err := m.db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("搜索知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItemSummary
for rows.Next() {
item := &KnowledgeItemSummary{}
var createdAt, updatedAt string
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
// 解析时间
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, nil
}
// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页)
func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) {
// 获取总数
total, err := m.GetItemsCount(category)
if err != nil {
return nil, 0, err
}
// 获取列表数据(不包含内容)
var rows *sql.Rows
var query string
var args []interface{}
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
if category != "" {
query += " WHERE category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
if limit > 0 {
query += " LIMIT ?"
args = append(args, limit)
if offset > 0 {
query += " OFFSET ?"
args = append(args, offset)
}
}
rows, err = m.db.Query(query, args...)
if err != nil {
return nil, 0, fmt.Errorf("查询知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItemSummary
for rows.Next() {
item := &KnowledgeItemSummary{}
var createdAt, updatedAt string
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, 0, fmt.Errorf("扫描知识项失败: %w", err)
}
// 解析时间
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, total, nil
}
// GetItem 获取单个知识项
func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
item := &KnowledgeItem{}
var createdAt, updatedAt string
err := m.db.QueryRow(
"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?",
id,
).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("知识项不存在")
}
if err != nil {
return nil, fmt.Errorf("查询知识项失败: %w", err)
}
// 解析时间 - 支持多种格式
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
// 解析创建时间
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
// 解析更新时间
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
// 如果更新时间为空,使用创建时间
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
return item, nil
}
// CreateItem 创建知识项
func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) {
id := uuid.New().String()
now := time.Now()
// 构建文件路径
filePath := filepath.Join(m.basePath, category, title+".md")
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
return nil, fmt.Errorf("创建目录失败: %w", err)
}
// 写入文件
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
return nil, fmt.Errorf("写入文件失败: %w", err)
}
// 插入数据库
_, err := m.db.Exec(
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, category, title, filePath, content, now, now,
)
if err != nil {
return nil, fmt.Errorf("插入知识项失败: %w", err)
}
return &KnowledgeItem{
ID: id,
Category: category,
Title: title,
FilePath: filePath,
Content: content,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// UpdateItem 更新知识项
func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) {
// 获取现有项
item, err := m.GetItem(id)
if err != nil {
return nil, err
}
// 构建新文件路径
newFilePath := filepath.Join(m.basePath, category, title+".md")
// 如果路径改变,需要移动文件
if item.FilePath != newFilePath {
// 确保新目录存在
if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil {
return nil, fmt.Errorf("创建目录失败: %w", err)
}
// 移动文件
if err := os.Rename(item.FilePath, newFilePath); err != nil {
return nil, fmt.Errorf("移动文件失败: %w", err)
}
// 删除旧目录(如果为空)
oldDir := filepath.Dir(item.FilePath)
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if oldDir != m.basePath {
if err := os.Remove(oldDir); err != nil {
m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err))
}
}
}
}
// 写入文件
if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil {
return nil, fmt.Errorf("写入文件失败: %w", err)
}
// 更新数据库
_, err = m.db.Exec(
"UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?",
category, title, newFilePath, content, time.Now(), id,
)
if err != nil {
return nil, fmt.Errorf("更新知识项失败: %w", err)
}
// 删除旧的向量嵌入(需要重新索引)
_, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id)
if err != nil {
m.logger.Warn("删除旧向量嵌入失败", zap.Error(err))
}
return m.GetItem(id)
}
// DeleteItem 删除知识项
func (m *Manager) DeleteItem(id string) error {
// 获取文件路径
var filePath string
err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath)
if err != nil {
return fmt.Errorf("查询知识项失败: %w", err)
}
// 删除文件
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err))
}
// 删除数据库记录(级联删除向量)
_, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除知识项失败: %w", err)
}
// 删除空目录(如果为空)
dir := filepath.Dir(filePath)
if isEmpty, _ := isEmptyDir(dir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if dir != m.basePath {
if err := os.Remove(dir); err != nil {
m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err))
}
}
}
return nil
}
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
func isEmptyDir(dir string) (bool, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return false, err
}
for _, entry := range entries {
// 忽略隐藏文件(以 . 开头)
if !strings.HasPrefix(entry.Name(), ".") {
return false, nil
}
}
return true, nil
}
// LogRetrieval 记录检索日志
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
id := uuid.New().String()
itemsJSON, _ := json.Marshal(retrievedItems)
_, err := m.db.Exec(
"INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(),
)
return err
}
// GetIndexStatus 获取索引状态
func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
// 获取总知识项数
var totalItems int
err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
if err != nil {
return nil, fmt.Errorf("查询总知识项数失败: %w", err)
}
// 获取已索引的知识项数(有向量嵌入的)
var indexedItems int
err = m.db.QueryRow(`
SELECT COUNT(DISTINCT item_id)
FROM knowledge_embeddings
`).Scan(&indexedItems)
if err != nil {
return nil, fmt.Errorf("查询已索引项数失败: %w", err)
}
// 计算进度百分比
var progressPercent float64
if totalItems > 0 {
progressPercent = float64(indexedItems) / float64(totalItems) * 100
} else {
progressPercent = 100.0
}
// 判断是否完成
isComplete := indexedItems >= totalItems && totalItems > 0
return map[string]interface{}{
"total_items": totalItems,
"indexed_items": indexedItems,
"progress_percent": progressPercent,
"is_complete": isComplete,
}, nil
}
// GetRetrievalLogs 获取检索日志
func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) {
var rows *sql.Rows
var err error
if messageID != "" {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?",
messageID, limit,
)
} else if conversationID != "" {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?",
conversationID, limit,
)
} else {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?",
limit,
)
}
if err != nil {
return nil, fmt.Errorf("查询检索日志失败: %w", err)
}
defer rows.Close()
var logs []*RetrievalLog
for rows.Next() {
log := &RetrievalLog{}
var createdAt string
var itemsJSON sql.NullString
if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil {
return nil, fmt.Errorf("扫描检索日志失败: %w", err)
}
// 解析时间 - 支持多种格式
var err error
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
for _, format := range timeFormats {
log.CreatedAt, err = time.Parse(format, createdAt)
if err == nil && !log.CreatedAt.IsZero() {
break
}
}
// 如果所有格式都失败,记录警告但继续处理
if log.CreatedAt.IsZero() {
m.logger.Warn("解析检索日志时间失败",
zap.String("timeStr", createdAt),
zap.Error(err),
)
// 使用当前时间作为fallback
log.CreatedAt = time.Now()
}
// 解析检索项
if itemsJSON.Valid {
json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems)
}
logs = append(logs, log)
}
return logs, nil
}
// DeleteRetrievalLog 删除检索日志
func (m *Manager) DeleteRetrievalLog(id string) error {
result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除检索日志失败: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("获取删除行数失败: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("检索日志不存在")
}
return nil
}
+213
View File
@@ -0,0 +1,213 @@
package knowledge
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"unicode"
"unicode/utf8"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/schema"
"github.com/pkoukk/tiktoken-go"
)
// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。
const postRetrieveMaxPrefetchCap = 200
// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。
type DocumentReranker interface {
Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
}
// NopDocumentReranker 占位实现,便于测试或未启用重排时显式注入。
type NopDocumentReranker struct{}
// Rerank implements [DocumentReranker] as no-op.
func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) {
return docs, nil
}
var tiktokenEncMu sync.Mutex
var tiktokenEncCache = map[string]*tiktoken.Tiktoken{}
func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) {
m := strings.TrimSpace(model)
if m == "" {
m = "gpt-4"
}
tiktokenEncMu.Lock()
defer tiktokenEncMu.Unlock()
if enc, ok := tiktokenEncCache[m]; ok {
return enc, nil
}
enc, err := tiktoken.EncodingForModel(m)
if err != nil {
enc, err = tiktoken.GetEncoding("cl100k_base")
if err != nil {
return nil, err
}
}
tiktokenEncCache[m] = enc
return enc, nil
}
func countDocTokens(text, model string) (int, error) {
enc, err := encodingForTokenizerModel(model)
if err != nil {
return 0, err
}
toks := enc.Encode(text, nil, nil)
return len(toks), nil
}
// normalizeContentFingerprintKey 去重键:trim + 空白折叠(不改动大小写,避免合并仅大小写不同的代码片段)。
func normalizeContentFingerprintKey(s string) string {
s = strings.TrimSpace(s)
var b strings.Builder
b.Grow(len(s))
prevSpace := false
for _, r := range s {
if unicode.IsSpace(r) {
if !prevSpace {
b.WriteByte(' ')
prevSpace = true
}
continue
}
prevSpace = false
b.WriteRune(r)
}
return b.String()
}
func contentNormKey(d *schema.Document) string {
if d == nil {
return ""
}
n := normalizeContentFingerprintKey(d.Content)
if n == "" {
return ""
}
sum := sha256.Sum256([]byte(n))
return hex.EncodeToString(sum[:])
}
// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。
func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document {
if len(docs) < 2 {
return docs
}
seen := make(map[string]struct{}, len(docs))
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil {
continue
}
k := contentNormKey(d)
if k == "" {
out = append(out, d)
continue
}
if _, ok := seen[k]; ok {
continue
}
seen[k] = struct{}{}
out = append(out, d)
}
return out
}
// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。
func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) {
if len(docs) == 0 {
return docs, nil
}
unlimitedChars := maxRunes <= 0
unlimitedTok := maxTokens <= 0
if unlimitedChars && unlimitedTok {
return docs, nil
}
remRunes := maxRunes
remTok := maxTokens
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil || strings.TrimSpace(d.Content) == "" {
continue
}
runes := utf8.RuneCountInString(d.Content)
if !unlimitedChars && runes > remRunes {
break
}
var tok int
var err error
if !unlimitedTok {
tok, err = countDocTokens(d.Content, tokenModel)
if err != nil {
return nil, fmt.Errorf("token count: %w", err)
}
if tok > remTok {
break
}
}
out = append(out, d)
if !unlimitedChars {
remRunes -= runes
}
if !unlimitedTok {
remTok -= tok
}
}
return out, nil
}
// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
if topK < 1 {
topK = 5
}
fetch := topK
if po != nil && po.PrefetchTopK > fetch {
fetch = po.PrefetchTopK
}
if fetch > postRetrieveMaxPrefetchCap {
fetch = postRetrieveMaxPrefetchCap
}
return fetch
}
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
if finalTopK < 1 {
finalTopK = 5
}
if len(docs) == 0 {
return docs, nil
}
maxChars := 0
maxTok := 0
if po != nil {
maxChars = po.MaxContextChars
maxTok = po.MaxContextTokens
}
out := dedupeByNormalizedContent(docs)
var err error
out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel)
if err != nil {
return nil, err
}
if len(out) > finalTopK {
out = out[:finalTopK]
}
return out, nil
}
@@ -0,0 +1,62 @@
package knowledge
import (
"testing"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/schema"
)
func doc(id, content string, score float64) *schema.Document {
d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}}
d.WithScore(score)
return d
}
func TestDedupeByNormalizedContent(t *testing.T) {
a := doc("1", "hello world", 0.9)
b := doc("2", "hello world", 0.8)
c := doc("3", "other", 0.7)
out := dedupeByNormalizedContent([]*schema.Document{a, b, c})
if len(out) != 2 {
t.Fatalf("len=%d want 2", len(out))
}
if out[0].ID != "1" || out[1].ID != "3" {
t.Fatalf("order/ids wrong: %#v", out)
}
}
func TestEffectivePrefetchTopK(t *testing.T) {
if g := EffectivePrefetchTopK(5, nil); g != 5 {
t.Fatalf("got %d", g)
}
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 {
t.Fatalf("got %d", g)
}
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap {
t.Fatalf("cap: got %d", g)
}
}
func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) {
d1 := doc("1", "ab", 0.9)
d2 := doc("2", "cd", 0.8)
d3 := doc("3", "ef", 0.7)
po := &config.PostRetrieveConfig{MaxContextChars: 3}
out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5)
if err != nil {
t.Fatal(err)
}
if len(out) != 1 || out[0].ID != "1" {
t.Fatalf("got %#v", out)
}
out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2)
if err != nil {
t.Fatal(err)
}
if len(out2) != 2 {
t.Fatalf("topk: len=%d", len(out2))
}
}
+305
View File
@@ -0,0 +1,305 @@
package knowledge
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math"
"sort"
"strings"
"sync"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// Retriever 检索器:SQLite 存向量 + Eino 嵌入,**纯向量检索**(余弦相似度、TopK、阈值),
// 实现语义与 [retriever.Retriever] 适配层 [VectorEinoRetriever] 一致。
type Retriever struct {
db *sql.DB
embedder *Embedder
config *RetrievalConfig
logger *zap.Logger
rerankMu sync.RWMutex
reranker DocumentReranker
}
// RetrievalConfig 检索配置
type RetrievalConfig struct {
TopK int
SimilarityThreshold float64
// SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。
SubIndexFilter string
PostRetrieve config.PostRetrieveConfig
}
// NewRetriever 创建新的检索器
func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever {
return &Retriever{
db: db,
embedder: embedder,
config: config,
logger: logger,
}
}
// UpdateConfig 更新检索配置
func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
if cfg != nil {
r.config = cfg
if r.logger != nil {
r.logger.Info("检索器配置已更新",
zap.Int("top_k", cfg.TopK),
zap.Float64("similarity_threshold", cfg.SimilarityThreshold),
zap.String("sub_index_filter", cfg.SubIndexFilter),
zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK),
zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars),
zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens),
)
}
}
}
// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。
func (r *Retriever) SetDocumentReranker(rr DocumentReranker) {
if r == nil {
return
}
r.rerankMu.Lock()
defer r.rerankMu.Unlock()
r.reranker = rr
}
func (r *Retriever) documentReranker() DocumentReranker {
if r == nil {
return nil
}
r.rerankMu.RLock()
defer r.rerankMu.RUnlock()
return r.reranker
}
func cosineSimilarity(a, b []float32) float64 {
if len(a) != len(b) {
return 0.0
}
var dotProduct, normA, normB float64
for i := range a {
dotProduct += float64(a[i] * b[i])
normA += float64(a[i] * a[i])
normB += float64(b[i] * b[i])
}
if normA == 0 || normB == 0 {
return 0.0
}
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
// Search 搜索知识库。统一经 [VectorEinoRetriever]Eino retriever.Retriever 边界)。
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
if req == nil {
return nil, fmt.Errorf("请求不能为空")
}
q := strings.TrimSpace(req.Query)
if q == "" {
return nil, fmt.Errorf("查询不能为空")
}
opts := r.einoRetrieverOptions(req)
docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...)
if err != nil {
return nil, err
}
return documentsToRetrievalResults(docs)
}
func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option {
var opts []retriever.Option
if req.TopK > 0 {
opts = append(opts, retriever.WithTopK(req.TopK))
}
dsl := map[string]any{}
if strings.TrimSpace(req.RiskType) != "" {
dsl[DSLRiskType] = strings.TrimSpace(req.RiskType)
}
if req.Threshold > 0 {
dsl[DSLSimilarityThreshold] = req.Threshold
}
if strings.TrimSpace(req.SubIndexFilter) != "" {
dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter)
}
if len(dsl) > 0 {
opts = append(opts, retriever.WithDSLInfo(dsl))
}
return opts
}
// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。
func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...)
}
func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) {
q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title
FROM knowledge_embeddings e
JOIN knowledge_base_items i ON e.item_id = i.id
WHERE 1=1`
var args []interface{}
if strings.TrimSpace(riskType) != "" {
q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE`
args = append(args, riskType)
}
if tag := strings.TrimSpace(subIndexFilter); tag != "" {
tag = strings.ToLower(strings.ReplaceAll(tag, " ", ""))
q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)`
args = append(args, tag)
}
return q, args
}
// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。
func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
if req.Query == "" {
return nil, fmt.Errorf("查询不能为空")
}
topK := req.TopK
if topK <= 0 && r.config != nil {
topK = r.config.TopK
}
if topK <= 0 {
topK = 5
}
threshold := req.Threshold
if threshold <= 0 && r.config != nil {
threshold = r.config.SimilarityThreshold
}
if threshold <= 0 {
threshold = 0.7
}
subIdxFilter := strings.TrimSpace(req.SubIndexFilter)
if subIdxFilter == "" && r.config != nil {
subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter)
}
queryText := FormatQueryEmbeddingText(req.RiskType, req.Query)
queryEmbedding, err := r.embedder.EmbedText(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("向量化查询失败: %w", err)
}
queryDim := len(queryEmbedding)
expectedModel := ""
if r.embedder != nil {
expectedModel = r.embedder.EmbeddingModelName()
}
sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter)
rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...)
if err != nil {
return nil, fmt.Errorf("查询向量失败: %w", err)
}
defer rows.Close()
type candidate struct {
chunk *KnowledgeChunk
item *KnowledgeItem
similarity float64
}
candidates := make([]candidate, 0)
rowNum := 0
for rows.Next() {
rowNum++
if rowNum%48 == 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string
var chunkIndex, rowDim int
if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil {
r.logger.Warn("扫描向量失败", zap.Error(err))
continue
}
var embedding []float32
if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil {
r.logger.Warn("解析向量失败", zap.Error(err))
continue
}
if rowDim > 0 && len(embedding) != rowDim {
r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding)))
continue
}
if queryDim > 0 && len(embedding) != queryDim {
r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding)))
continue
}
if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel {
r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel))
continue
}
similarity := cosineSimilarity(queryEmbedding, embedding)
candidates = append(candidates, candidate{
chunk: &KnowledgeChunk{
ID: chunkID,
ItemID: itemID,
ChunkIndex: chunkIndex,
ChunkText: chunkText,
Embedding: embedding,
},
item: &KnowledgeItem{
ID: itemID,
Category: category,
Title: title,
},
similarity: similarity,
})
}
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].similarity > candidates[j].similarity
})
filtered := make([]candidate, 0, len(candidates))
for _, c := range candidates {
if c.similarity >= threshold {
filtered = append(filtered, c)
}
}
if len(filtered) > topK {
filtered = filtered[:topK]
}
results := make([]*RetrievalResult, len(filtered))
for i, c := range filtered {
results[i] = &RetrievalResult{
Chunk: c.chunk,
Item: c.item,
Similarity: c.similarity,
Score: c.similarity,
}
}
return results, nil
}
// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
return NewVectorEinoRetriever(r)
}
+51
View File
@@ -0,0 +1,51 @@
package knowledge
import (
"database/sql"
"fmt"
)
// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata.
func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error {
if db == nil {
return fmt.Errorf("db is nil")
}
var n int
if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
return err
}
if n == 0 {
return nil
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes",
`ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil {
return err
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model",
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil {
return err
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim",
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil {
return err
}
return nil
}
func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error {
var colCount int
q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
if err := db.QueryRow(q, column).Scan(&colCount); err != nil {
return err
}
if colCount > 0 {
return nil
}
_, err := db.Exec(alterSQL)
return err
}
// ensureKnowledgeEmbeddingsSubIndexesColumn 向后兼容;请使用 [EnsureKnowledgeEmbeddingsSchema]。
func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error {
return EnsureKnowledgeEmbeddingsSchema(db)
}
+323
View File
@@ -0,0 +1,323 @@
package knowledge
import (
"context"
"encoding/json"
"fmt"
"sort"
"strings"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterKnowledgeTool 注册知识检索工具到MCP服务器
func RegisterKnowledgeTool(
mcpServer *mcp.Server,
retriever *Retriever,
manager *Manager,
logger *zap.Logger,
) {
// 注册第一个工具:获取所有可用的风险类型列表
listRiskTypesTool := mcp.Tool{
Name: builtin.ToolListKnowledgeRiskTypes,
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
ShortDescription: "获取知识库中所有可用的风险类型列表",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
},
}
listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
categories, err := manager.GetCategories()
if err != nil {
logger.Error("获取风险类型列表失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("获取风险类型列表失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(categories) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "知识库中暂无风险类型。",
},
},
}, nil
}
var resultText strings.Builder
resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories)))
for i, category := range categories {
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
}
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: resultText.String(),
},
},
}, nil
}
mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler)
logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name))
// 注册第二个工具:搜索知识库(保持原有功能)
searchTool := mcp.Tool{
Name: builtin.ToolSearchKnowledgeBase,
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
ShortDescription: "搜索知识库中的安全知识(向量语义检索)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "搜索查询内容,描述你想要了解的安全知识主题",
},
"risk_type": map[string]interface{}{
"type": "string",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
},
},
"required": []string{"query"},
},
}
searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
query, ok := args["query"].(string)
if !ok || query == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: 查询参数不能为空",
},
},
IsError: true,
}, nil
}
riskType := ""
if rt, ok := args["risk_type"].(string); ok && rt != "" {
riskType = rt
}
logger.Info("执行知识库检索",
zap.String("query", query),
zap.String("riskType", riskType),
)
// 检索统一走 Retriever.Search → VectorEinoRetrieverEino retriever 语义)。
searchReq := &SearchRequest{
Query: query,
RiskType: riskType,
TopK: 5,
}
results, err := retriever.Search(ctx, searchReq)
if err != nil {
logger.Error("知识库检索失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("检索失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(results) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query),
},
},
}, nil
}
// 格式化结果
var resultText strings.Builder
// 按余弦相似度(Score)降序
sort.Slice(results, func(i, j int) bool {
return results[i].Score > results[j].Score
})
// 按文档分组结果,以便更好地展示上下文
type itemGroup struct {
itemID string
results []*RetrievalResult
maxScore float64 // 该文档块的最高相似度
}
itemGroups := make([]*itemGroup, 0)
itemMap := make(map[string]*itemGroup)
for _, result := range results {
itemID := result.Item.ID
group, exists := itemMap[itemID]
if !exists {
group = &itemGroup{
itemID: itemID,
results: make([]*RetrievalResult, 0),
maxScore: result.Score,
}
itemMap[itemID] = group
itemGroups = append(itemGroups, group)
}
group.results = append(group.results, result)
if result.Score > group.maxScore {
group.maxScore = result.Score
}
}
// 按文档内最高相似度排序
sort.Slice(itemGroups, func(i, j int) bool {
return itemGroups[i].maxScore > itemGroups[j].maxScore
})
// 收集检索到的知识项ID(用于日志)
retrievedItemIDs := make([]string, 0, len(itemGroups))
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results)))
resultIndex := 1
for _, group := range itemGroups {
itemResults := group.results
mainResult := itemResults[0]
maxScore := mainResult.Score
for _, result := range itemResults {
if result.Score > maxScore {
maxScore = result.Score
mainResult = result
}
}
// 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序)
sort.Slice(itemResults, func(i, j int) bool {
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
})
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n",
resultIndex, mainResult.Similarity*100))
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
// 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk)
if len(itemResults) == 1 {
// 只有一个chunk,直接显示
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText))
} else {
// 多个chunk,按逻辑顺序显示
resultText.WriteString("内容片段(按文档顺序):\n")
for i, result := range itemResults {
// 标记主结果
marker := ""
if result.Chunk.ID == mainResult.Chunk.ID {
marker = " [主匹配]"
}
resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText))
}
}
resultText.WriteString("\n")
if !contains(retrievedItemIDs, group.itemID) {
retrievedItemIDs = append(retrievedItemIDs, group.itemID)
}
resultIndex++
}
// 在结果末尾添加元数据(JSON格式,用于提取知识项ID)
// 使用特殊标记,避免影响AI阅读结果
if len(retrievedItemIDs) > 0 {
metadataJSON, _ := json.Marshal(map[string]interface{}{
"_metadata": map[string]interface{}{
"retrievedItemIDs": retrievedItemIDs,
},
})
resultText.WriteString(fmt.Sprintf("\n<!-- METADATA: %s -->", string(metadataJSON)))
}
// 记录检索日志(异步,不阻塞)
// 注意:这里没有conversationID和messageID,需要在Agent层面记录
// 实际的日志记录应该在Agent的progressCallback中完成
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: resultText.String(),
},
},
}, nil
}
mcpServer.RegisterTool(searchTool, searchHandler)
logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name))
}
// contains 检查切片是否包含元素
func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录)
func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) {
if q, ok := args["query"].(string); ok {
query = q
}
if rt, ok := args["risk_type"].(string); ok {
riskType = rt
}
return
}
// FormatRetrievalResults 格式化检索结果为字符串(用于日志)
func FormatRetrievalResults(results []*RetrievalResult) string {
if len(results) == 0 {
return "未找到相关结果"
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results)))
itemIDs := make(map[string]bool)
for i, result := range results {
builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n",
i+1, result.Item.Category, result.Item.Title, result.Similarity*100))
itemIDs[result.Item.ID] = true
}
// 返回知识项ID列表(JSON格式)
ids := make([]string, 0, len(itemIDs))
for id := range itemIDs {
ids = append(ids, id)
}
idsJSON, _ := json.Marshal(ids)
builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON)))
return builder.String()
}
+123
View File
@@ -0,0 +1,123 @@
package knowledge
import (
"encoding/json"
"time"
)
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
func formatTime(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format(time.RFC3339)
}
// KnowledgeItem 知识库项
type KnowledgeItem struct {
ID string `json:"id"`
Category string `json:"category"` // 风险类型(文件夹名)
Title string `json:"title"` // 标题(文件名)
FilePath string `json:"filePath"` // 文件路径
Content string `json:"content"` // 文件内容
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// KnowledgeItemSummary 知识库项摘要(用于列表,不包含完整内容)
type KnowledgeItemSummary struct {
ID string `json:"id"`
Category string `json:"category"`
Title string `json:"title"`
FilePath string `json:"filePath"`
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符)
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItemSummary
aux := &struct {
*Alias
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}{
Alias: (*Alias)(k),
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItem
aux := &struct {
*Alias
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}{
Alias: (*Alias)(k),
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
// KnowledgeChunk 知识块(用于向量化)
type KnowledgeChunk struct {
ID string `json:"id"`
ItemID string `json:"itemId"`
ChunkIndex int `json:"chunkIndex"`
ChunkText string `json:"chunkText"`
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON
CreatedAt time.Time `json:"createdAt"`
}
// RetrievalResult 检索结果
type RetrievalResult struct {
Chunk *KnowledgeChunk `json:"chunk"`
Item *KnowledgeItem `json:"item"`
Similarity float64 `json:"similarity"` // 相似度分数
Score float64 `json:"score"` // 与 Similarity 相同:余弦相似度
}
// RetrievalLog 检索日志
type RetrievalLog struct {
ID string `json:"id"`
ConversationID string `json:"conversationId,omitempty"`
MessageID string `json:"messageId,omitempty"`
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"`
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表
CreatedAt time.Time `json:"createdAt"`
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
type Alias RetrievalLog
return json.Marshal(&struct {
*Alias
CreatedAt string `json:"createdAt"`
}{
Alias: (*Alias)(r),
CreatedAt: formatTime(r.CreatedAt),
})
}
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
type CategoryWithItems struct {
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
}
// SearchRequest 搜索请求
type SearchRequest struct {
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型
SubIndexFilter string `json:"subIndexFilter,omitempty"` // 可选:仅保留 sub_indexes 含该标签的行(含未打标旧数据)
TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7
}
+78
View File
@@ -0,0 +1,78 @@
package project
import (
"fmt"
"sort"
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
)
// AppendSystemPromptBlock 将附加块追加到 system prompt。
func AppendSystemPromptBlock(base, block string) string {
base = strings.TrimSpace(base)
block = strings.TrimSpace(block)
if block == "" {
return base
}
if base == "" {
return block
}
return base + "\n\n" + block
}
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
if db == nil || !cfg.Enabled {
return "", nil
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return "", nil
}
proj, err := db.GetProject(projectID)
if err != nil {
return "", err
}
facts, err := db.ListProjectFactsForIndex(projectID, cfg.DefaultInjectDeprecated)
if err != nil {
return "", err
}
if len(facts) == 0 {
return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil
}
sort.SliceStable(facts, func(i, j int) bool {
if facts[i].Pinned != facts[j].Pinned {
return facts[i].Pinned
}
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
})
maxRunes := cfg.FactIndexMaxRunesEffective()
var b strings.Builder
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s\n", proj.Name, proj.ID))
used := len([]rune(b.String()))
omitted := 0
for _, f := range facts {
line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
lineRunes := len([]rune(line))
if used+lineRunes > maxRunes {
omitted++
continue
}
b.WriteString(line)
used += lineRunes
}
if omitted > 0 {
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
}
b.WriteString("需要完整内容(攻击链、POC、请求响应等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n")
b.WriteString("写入事实时:summary 写「什么+在哪+如何验证」;body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。\n")
return b.String(), nil
}
+100
View File
@@ -0,0 +1,100 @@
package project
import (
"strings"
"cyberstrike-ai/internal/mcp/builtin"
)
// 边渗透边记录:统一节奏文案(agents/*.md 须与 FactRecordingIncrementalRhythmMarkdown 保持一致)。
const (
factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。"
factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。"
factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。"
)
// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。
func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string {
var b strings.Builder
b.WriteString("- **边渗透边记录(强制节奏)**:")
b.WriteString(factRhythmCore)
if coordinator {
b.WriteString(factRhythmCoordinatorSuffix)
}
if subAgent {
b.WriteString(factRhythmSubAgentSuffix)
}
return b.String()
}
func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string {
var b strings.Builder
b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ")
b.WriteString(builtin.ToolUpsertProjectFact)
b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ")
b.WriteString(builtin.ToolRecordVulnerability)
b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。")
if coordinator {
b.WriteString(factRhythmCoordinatorSuffix)
}
if subAgent {
b.WriteString(factRhythmSubAgentSuffix)
}
return b.String()
}
// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。
// coordinatorDelegate 为 true 时追加「协调者代子代理落库」说明(Deep / plan_execute / supervisor)。
func FactRecordingBlackboardSection(coordinatorDelegate bool) string {
var b strings.Builder
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ")
b.WriteString(builtin.ToolGetProjectFact)
b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n")
b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false))
b.WriteString("\n\n")
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ")
b.WriteString(builtin.ToolUpsertProjectFact)
b.WriteString("fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
b.WriteString("- **可交付漏洞**:使用 ")
b.WriteString(builtin.ToolRecordVulnerability)
b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ")
b.WriteString(builtin.ToolListVulnerabilities)
b.WriteString(" 查重,详情用 ")
b.WriteString(builtin.ToolGetVulnerability)
b.WriteString("(id)(默认仅当前项目/会话)。\n")
b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ")
b.WriteString(builtin.ToolDeprecateProjectFact)
b.WriteString(" 或漏洞状态 false_positive。\n")
b.WriteString("- 事实多时用 ")
b.WriteString(builtin.ToolListProjectFacts)
b.WriteString(" / ")
b.WriteString(builtin.ToolSearchProjectFacts)
b.WriteString(" 检索。\n\n")
b.WriteString(FactRecordingGuidanceBlock())
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
return b.String()
}
// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。
func FactRecordingSubAgentSection() string {
return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n"
}
// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。
func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string {
var b strings.Builder
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n")
b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false))
b.WriteString("\n\n")
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n")
b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n")
b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n")
b.WriteString(FactRecordingGuidanceBlock())
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
return b.String()
}
+140
View File
@@ -0,0 +1,140 @@
package project
import (
"fmt"
"strings"
)
// 事实 category 常量(写入 upsert_project_fact 的 category 字段)。
const (
FactCategoryTarget = "target"
FactCategoryAuth = "auth"
FactCategoryInfra = "infra"
FactCategoryBusiness = "business"
FactCategoryFinding = "finding"
FactCategoryChain = "chain"
FactCategoryExploit = "exploit"
FactCategoryPOC = "poc"
FactCategoryNote = "note"
)
// RequiresAttackChainBody 判断该事实是否应携带可复现的攻击链 / exploit 详情(写在 body,非仅 summary)。
func RequiresAttackChainBody(category, factKey string) bool {
c := strings.ToLower(strings.TrimSpace(category))
switch c {
case FactCategoryFinding, FactCategoryChain, FactCategoryExploit, FactCategoryPOC, "vuln":
return true
}
key := strings.ToLower(strings.TrimSpace(factKey))
for _, prefix := range []string{"finding/", "chain/", "exploit/", "poc/"} {
if strings.HasPrefix(key, prefix) {
return true
}
}
return false
}
// IsSparseFactBody 攻击链类事实 body 过短或缺少关键段落时返回 true(软校验,不阻断写入)。
func IsSparseFactBody(category, factKey, body string) bool {
if !RequiresAttackChainBody(category, factKey) {
return false
}
body = strings.TrimSpace(body)
if body == "" {
return true
}
lower := strings.ToLower(body)
// 至少应包含可复现线索:步骤/请求/命令/代码块 之一
hasSteps := strings.Contains(lower, "攻击链") || strings.Contains(lower, "## 攻击") ||
strings.Contains(lower, "## exploit") || strings.Contains(lower, "## poc")
hasHTTP := strings.Contains(lower, "```http") || strings.Contains(lower, "```bash") ||
strings.Contains(lower, "curl ") || strings.Contains(lower, "get ") || strings.Contains(lower, "post ")
hasReq := strings.Contains(lower, "请求") || strings.Contains(lower, "响应") || strings.Contains(lower, "payload")
// 无攻击链/POC/请求等结构线索,视为仅结论性描述(不论长短)
return !(hasSteps || hasHTTP || hasReq)
}
// FactBodyTemplate 按 category 返回建议的 body Markdown 骨架(供 Agent 填入真实内容)。
func FactBodyTemplate(category, factKey string) string {
if RequiresAttackChainBody(category, factKey) {
return attackChainFactBodyTemplate
}
return envFactBodyTemplate
}
const attackChainFactBodyTemplate = `## 结论(可验证,一句话)
<勿仅写「存在漏洞」;写明类型 + 位置 + 触发条件>
## 目标与入口
- 目标: <URL / IP:Port / 主机名>
- 入口: <路径 / 接口 / 参数>
- 前置条件: <匿名 / 角色 / Cookie / 其他依赖>
## 攻击链(逐步可复现)
1. <侦察/发现>
2. <利用/触发>
3. <影响证明(读文件、RCE 回显、越权数据等)>
## Exploit / POC
### 请求
` + "```http\n<METHOD> <path> HTTP/1.1\nHost: ...\n...\n\n<body>\n```" + `
### 响应 / 现象
<关键响应片段、状态码、差异点>
### 命令 / 脚本(如有)
` + "```bash\n<command>\n```" + `
## 关键证据
- <工具输出摘要 / 截图路径 / 会话或消息 ID>
## 关联
- related_vulnerability_id: <可选,对应 record_vulnerability 的 id>
- 依赖事实: <fact_key,如 auth/session_cookie>
## 备注与不确定性
<待验证假设、环境差异、绕过尝试记录>`
const envFactBodyTemplate = `## 摘要
<该事实的核心认知>
## 细节
<端口/版本/路径/凭据特征/业务规则等>
## 来源与证据
<命令输出、响应片段、发现时间>
## 关联
- 相关 fact_key: <可选>`
// FactRecordingGuidanceBlock 写入系统提示:要求事实沉淀攻击链上下文而非仅结论。
func FactRecordingGuidanceBlock() string {
return `### 事实写入规范(审计复现 / 知识沉淀)
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。
- **category / fact_key 建议**
- 环境认知:` + "`target/`" + `` + "`auth/`" + `` + "`infra/`" + `` + "`business/`" + `body 用环境模板即可)
- 发现与利用:` + "`finding/`" + `` + "`chain/`" + `` + "`exploit/`" + `` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
- **与漏洞记录分工**` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。`
}
// SparseBodyWarning 攻击链类事实 body 不足时的工具返回提示(不阻断保存)。
func SparseBodyWarning(category, factKey string) string {
if !IsSparseFactBody(category, factKey, "") {
return ""
}
return fmt.Sprintf(
"\n\n⚠ 提示:category=%q / fact_key=%q 属于攻击链类事实,但 body 为空或过简。请补充完整攻击链与 POC(参考模板),便于后续审计复现。\n建议 body 骨架:\n%s",
category, factKey, FactBodyTemplate(category, factKey),
)
}
// SparseBodyWarningIfNeeded 根据实际 body 判断是否追加警告。
func SparseBodyWarningIfNeeded(category, factKey, body string) string {
if !IsSparseFactBody(category, factKey, body) {
return ""
}
return SparseBodyWarning(category, factKey)
}
+42
View File
@@ -0,0 +1,42 @@
package project
import (
"strings"
"testing"
)
func TestRequiresAttackChainBody(t *testing.T) {
cases := []struct {
cat, key string
want bool
}{
{"finding", "note/misc", true},
{"note", "finding/sqli-login", true},
{"target", "target/primary_domain", false},
{"auth", "auth/admin_cookie", false},
{"chain", "x", true},
{"", "exploit/rce-upload", true},
}
for _, tc := range cases {
if got := RequiresAttackChainBody(tc.cat, tc.key); got != tc.want {
t.Errorf("RequiresAttackChainBody(%q,%q)=%v want %v", tc.cat, tc.key, got, tc.want)
}
}
}
func TestIsSparseFactBody(t *testing.T) {
long := strings.Repeat("x", 150)
if !IsSparseFactBody("finding", "finding/x", "") {
t.Error("empty body should be sparse")
}
if !IsSparseFactBody("finding", "finding/x", long) {
t.Error("body without repro clues should be sparse")
}
body := "## 攻击链\n1. step\n## Exploit\n```http\nGET / HTTP/1.1\n```\n"
if IsSparseFactBody("finding", "finding/x", body) {
t.Error("structured body should not be sparse")
}
if IsSparseFactBody("target", "target/x", "") {
t.Error("env fact empty body is ok")
}
}
+99
View File
@@ -0,0 +1,99 @@
package project
import (
"encoding/json"
"fmt"
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
)
// projectScopePayload 解析 projects.scope_json(约定字段,可扩展)。
type projectScopePayload struct {
Targets []string `json:"targets"`
Exclude []string `json:"exclude"`
Notes string `json:"notes"`
}
// BuildScopeBlock 将项目 scope_json 格式化为 Agent 可读的授权范围块。
func BuildScopeBlock(proj *database.Project) string {
if proj == nil {
return ""
}
raw := strings.TrimSpace(proj.ScopeJSON)
if raw == "" {
return ""
}
var payload projectScopePayload
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
return fmt.Sprintf("## 项目测试范围(project: %s\nscope_json 非合法 JSON,请人工核对配置)\n```\n%s\n```\n"+
"仅对明确授权目标执行测试;超出范围须停止并说明。\n", proj.Name, truncateRunes(raw, 800))
}
var b strings.Builder
b.WriteString(fmt.Sprintf("## 项目测试范围(project: %s, id: %s\n", proj.Name, proj.ID))
b.WriteString("以下为授权边界,**必须遵守**:仅测试列出的 targets,避开 exclude,不得擅自扩大范围。\n")
if len(payload.Targets) > 0 {
b.WriteString("\n**允许测试(targets**\n")
for _, t := range payload.Targets {
t = strings.TrimSpace(t)
if t != "" {
b.WriteString("- " + t + "\n")
}
}
}
if len(payload.Exclude) > 0 {
b.WriteString("\n**明确排除(exclude**\n")
for _, t := range payload.Exclude {
t = strings.TrimSpace(t)
if t != "" {
b.WriteString("- " + t + "\n")
}
}
}
if n := strings.TrimSpace(payload.Notes); n != "" {
b.WriteString("\n**说明(notes**\n" + n + "\n")
}
if len(payload.Targets) == 0 && len(payload.Exclude) == 0 && strings.TrimSpace(payload.Notes) == "" {
b.WriteString("\nscope_json 已配置但未识别 targets/exclude/notes 字段,原始内容供参考)\n```json\n")
b.WriteString(truncateRunes(raw, 1200))
b.WriteString("\n```\n")
}
b.WriteString("\n若目标不在 targets 内或命中 exclude,不得主动扫描/利用;需用户明确扩大授权后再继续。\n")
return b.String()
}
func truncateRunes(s string, max int) string {
r := []rune(s)
if len(r) <= max {
return s
}
return string(r[:max]) + "…"
}
// BuildProjectBlackboardBlock 组合测试范围 + 事实黑板索引。
func BuildProjectBlackboardBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
projectID = strings.TrimSpace(projectID)
if projectID == "" {
return "", nil
}
proj, err := db.GetProject(projectID)
if err != nil {
return "", err
}
parts := []string{}
if scope := strings.TrimSpace(BuildScopeBlock(proj)); scope != "" {
parts = append(parts, scope)
}
index, err := BuildFactIndexBlock(db, projectID, cfg)
if err != nil {
return "", err
}
if strings.TrimSpace(index) != "" {
parts = append(parts, index)
}
return strings.Join(parts, "\n\n"), nil
}
+40
View File
@@ -0,0 +1,40 @@
package project
import (
"strings"
"testing"
"cyberstrike-ai/internal/database"
)
func TestBuildScopeBlock_targetsExcludeNotes(t *testing.T) {
proj := &database.Project{
ID: "p1",
Name: "Acme",
ScopeJSON: `{"targets":["https://app.example.com"],"exclude":["*.cdn.example.com"],"notes":"仅 Web 层"}`,
}
block := BuildScopeBlock(proj)
if !strings.Contains(block, "https://app.example.com") {
t.Fatalf("missing target: %s", block)
}
if !strings.Contains(block, "cdn.example.com") {
t.Fatalf("missing exclude: %s", block)
}
if !strings.Contains(block, "仅 Web 层") {
t.Fatalf("missing notes: %s", block)
}
}
func TestBuildScopeBlock_empty(t *testing.T) {
if BuildScopeBlock(&database.Project{Name: "X"}) != "" {
t.Fatal("expected empty")
}
}
func TestBuildScopeBlock_invalidJSON(t *testing.T) {
proj := &database.Project{Name: "X", ScopeJSON: `{not json`}
block := BuildScopeBlock(proj)
if !strings.Contains(block, "非合法 JSON") {
t.Fatalf("unexpected: %s", block)
}
}
+21
View File
@@ -0,0 +1,21 @@
package project
import "cyberstrike-ai/internal/database"
// GetProjectStats 聚合项目统计(含待补全事实数)。
func GetProjectStats(db *database.DB, projectID string) (*database.ProjectStats, error) {
stats, err := db.GetProjectStatsCounts(projectID)
if err != nil {
return nil, err
}
rows, err := db.ListProjectFactsForSparseCheck(projectID)
if err != nil {
return nil, err
}
for _, r := range rows {
if IsSparseFactBody(r.Category, r.FactKey, r.Body) {
stats.SparseFactCount++
}
}
return stats, nil
}
+22
View File
@@ -0,0 +1,22 @@
package project
import "strings"
// VisionImageAnalysisSection 单/多代理共用的图片分析提示(analyze_image;上下文仅保留文字摘要)。
func VisionImageAnalysisSection() string {
var b strings.Builder
b.WriteString("## 图片分析\n\n")
b.WriteString("- 遇到图片文件(截图、验证码、登录页、报告配图)时,若存在工具 analyze_image,请传入服务器上的文件路径进行分析。\n")
b.WriteString("- 不要对二进制图片使用 read_file 指望理解内容;用户消息中「📎 xxx.png: /path」即为可传给 analyze_image 的路径。\n")
b.WriteString("- 验证码类:若已从页面或接口保存为本地图片(如 captcha.png),用 analyze_imagequestion 写明「只输出验证码字符」;识别失败则刷新验证码后重新保存再识;复杂滑块/行为验证码勿指望单次识图成功。\n")
b.WriteString("- 委派子代理时,若子任务含验证码/截图识读,在 task description 中写明图片路径与期望输出格式。\n")
return b.String()
}
// AppendVisionImageAnalysisIfReady 仅在 vision.enabled 且 model 已配置时追加图片分析提示。
func AppendVisionImageAnalysisIfReady(base string, visionReady bool) string {
if !visionReady {
return base
}
return AppendSystemPromptBlock(base, VisionImageAnalysisSection())
}
+266
View File
@@ -0,0 +1,266 @@
// Package reasoning maps user/config intent to CloudWeGo Eino OpenAI ChatModel fields
// (ReasoningEffort, ExtraFields such as thinking / reasoning_effort / output_config).
package reasoning
import (
"strings"
"cyberstrike-ai/internal/config"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
)
// ClientIntent is optional per-request override from ChatRequest.reasoning.
type ClientIntent struct {
Mode string
Effort string
}
type wireProfile int
const (
wireNone wireProfile = iota
wireClaude
wireDeepseek
wireOpenAI
wireOutputConfig
)
// ApplyToEinoChatModelConfig merges reasoning-related options into cfg.
// Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set.
func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) {
if cfg == nil || oa == nil {
return
}
sr := &oa.Reasoning
allowClient := sr.AllowClientReasoningEffective()
mode := effectiveMode(sr, client, allowClient)
// Claude (Anthropic): merge admin extras first; optional extended thinking maps to top-level `thinking`
// (see internal/openai convertOpenAIToClaude). DeepSeek/OpenAI-style fields are not sent.
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") ||
strings.EqualFold(strings.TrimSpace(oa.Provider), "anthropic") {
if len(sr.ExtraRequestFields) > 0 {
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
for k, v := range sr.ExtraRequestFields {
cfg.ExtraFields[k] = v
}
}
if mode == "off" {
return
}
applyClaudeExtendedThinking(cfg, mode, effectiveEffort(sr, client, allowClient), oa.Model)
return
}
if mode == "off" {
applyThinkingDisabled(cfg)
return
}
effort := effectiveEffort(sr, client, allowClient)
prof := resolveWireProfile(oa, sr)
// Admin-defined extra root fields (merged first; automatic keys may follow).
if len(sr.ExtraRequestFields) > 0 {
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
for k, v := range sr.ExtraRequestFields {
cfg.ExtraFields[k] = v
}
}
switch prof {
case wireClaude, wireNone:
return
case wireDeepseek:
applyDeepseek(cfg, mode, effort)
case wireOutputConfig:
applyOutputConfigEffort(cfg, mode, effort)
default: // wireOpenAI
applyOpenAICompat(cfg, mode, effort)
}
}
// applyClaudeExtendedThinking sets Anthropic Messages API `thinking` when absent from ExtraRequestFields.
// Uses adaptive + summarized display by default (per Anthropic guidance for Claude 4.x); Sonnet 3.7 uses enabled+budget.
func applyClaudeExtendedThinking(cfg *einoopenai.ChatModelConfig, mode, effort, model string) {
if cfg == nil || mode == "off" {
return
}
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
if _, exists := cfg.ExtraFields["thinking"]; exists {
return
}
m := strings.ToLower(strings.TrimSpace(model))
thinking := map[string]any{
"type": "adaptive",
"display": "summarized",
}
// Sonnet 3.7: manual extended thinking is the documented path.
if strings.Contains(m, "claude-3-7-sonnet") || strings.Contains(m, "3-7-sonnet") || strings.Contains(m, "sonnet-3.7") {
thinking = map[string]any{
"type": "enabled",
"budget_tokens": 10000,
"display": "summarized",
}
}
// Opus 4.7+: manual enabled+budget rejected — keep adaptive only.
if strings.Contains(m, "opus-4-7") || strings.Contains(m, "opus-4.7") {
thinking = map[string]any{
"type": "adaptive",
"display": "summarized",
}
}
_ = effort // reserved: map to Anthropic effort / output_config when API stabilizes in one place
cfg.ExtraFields["thinking"] = thinking
}
func effectiveMode(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string {
server := strings.ToLower(strings.TrimSpace(sr.ModeEffective()))
if server == "" || server == "default" {
server = "auto"
}
if !allowClient || client == nil {
return server
}
cm := strings.ToLower(strings.TrimSpace(client.Mode))
if cm == "" || cm == "default" {
return server
}
return cm
}
func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string {
se := normalizeEffort(sr.Effort)
if !allowClient || client == nil {
return se
}
ce := normalizeEffort(client.Effort)
if ce != "" {
return ce
}
return se
}
func normalizeEffort(s string) string {
e := strings.ToLower(strings.TrimSpace(s))
switch e {
case "low", "medium", "high", "max", "xhigh":
return e
default:
return ""
}
}
// usesExtraFieldsReasoningEffort 为 Eino 无枚举的最高档 effort,经 ExtraFields 原样下发(max / xhigh 由网关自行识别,不做互转)。
func usesExtraFieldsReasoningEffort(e string) bool {
return e == "max" || e == "xhigh"
}
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
return wireClaude
}
p := strings.ToLower(strings.TrimSpace(sr.ProfileEffective()))
switch p {
case "output_config", "output_config_effort":
return wireOutputConfig
case "openai", "openai_compat":
return wireOpenAI
case "deepseek", "deepseek_compat":
return wireDeepseek
case "auto", "":
bu := strings.ToLower(oa.BaseURL)
mo := strings.ToLower(oa.Model)
if strings.Contains(bu, "deepseek") || strings.Contains(mo, "deepseek") {
return wireDeepseek
}
return wireOpenAI
default:
return wireOpenAI
}
}
func applyThinkingDisabled(cfg *einoopenai.ChatModelConfig) {
if cfg == nil {
return
}
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
if _, exists := cfg.ExtraFields["thinking"]; exists {
return
}
cfg.ExtraFields["thinking"] = map[string]any{"type": "disabled"}
}
func applyDeepseek(cfg *einoopenai.ChatModelConfig, mode, effort string) {
// auto: enable thinking for DeepSeek line; on: same; auto without effort still opens thinking.
if mode == "auto" || mode == "on" {
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
cfg.ExtraFields["thinking"] = map[string]any{"type": "enabled"}
}
if effort != "" {
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(effort)
}
}
func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) {
if mode == "auto" && effort == "" {
return
}
e := effort
if mode == "on" && e == "" {
e = "medium"
}
if e == "" {
return
}
if usesExtraFieldsReasoningEffort(e) {
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(e)
return
}
switch e {
case "low":
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelLow
case "medium":
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelMedium
case "high":
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelHigh
}
}
func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort string) {
if mode == "auto" && effort == "" {
return
}
e := effort
if mode == "on" && e == "" {
e = "high"
}
if e == "" {
return
}
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
cfg.ExtraFields["output_config"] = map[string]any{"effort": effortStringForAPI(e)}
}
func effortStringForAPI(e string) string {
// 原样透传:OpenAI 官方多为 xhigh,部分兼容网关为 max,由配置/对话 effort 选择。
return strings.ToLower(strings.TrimSpace(e))
}
+82
View File
@@ -0,0 +1,82 @@
package reasoning
import (
"testing"
"cyberstrike-ai/internal/config"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
)
func TestEffortStringForAPI_passthrough(t *testing.T) {
cases := map[string]string{
"max": "max",
"xhigh": "xhigh",
"HIGH": "high",
"Medium": "medium",
}
for in, want := range cases {
if got := effortStringForAPI(in); got != want {
t.Fatalf("%q -> %q, want %q", in, got, want)
}
}
}
func TestNormalizeEffort_maxAndXhigh(t *testing.T) {
if normalizeEffort("xhigh") != "xhigh" {
t.Fatal("xhigh not accepted")
}
if normalizeEffort("max") != "max" {
t.Fatal("max not accepted")
}
}
func TestApplyOpenAICompat_xhighExtraField(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{
Reasoning: config.OpenAIReasoningConfig{
Profile: "openai_compat",
Mode: "on",
Effort: "xhigh",
},
}
ApplyToEinoChatModelConfig(cfg, oa, nil)
if cfg.ExtraFields == nil {
t.Fatal("expected ExtraFields")
}
if got, _ := cfg.ExtraFields["reasoning_effort"].(string); got != "xhigh" {
t.Fatalf("reasoning_effort=%q", got)
}
}
func TestApplyReasoningOff_disablesThinking(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{
BaseURL: "https://api.openai.com/v1",
Model: "gpt-4o",
Reasoning: config.OpenAIReasoningConfig{
Mode: "off",
},
}
ApplyToEinoChatModelConfig(cfg, oa, nil)
th, ok := cfg.ExtraFields["thinking"].(map[string]any)
if !ok || th["type"] != "disabled" {
t.Fatalf("expected thinking disabled, got %#v", cfg.ExtraFields)
}
}
func TestApplyOpenAICompat_maxPassthrough(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{
Reasoning: config.OpenAIReasoningConfig{
Profile: "openai_compat",
Mode: "on",
Effort: "max",
},
}
ApplyToEinoChatModelConfig(cfg, oa, nil)
got, _ := cfg.ExtraFields["reasoning_effort"].(string)
if got != "max" {
t.Fatalf("max effort wire=%q, want max", got)
}
}
+132
View File
@@ -0,0 +1,132 @@
package vision
import (
"context"
"encoding/base64"
"fmt"
"net"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/openai"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/schema"
)
// Client 调用独立 Vision ChatModel(单次 Generate)。
type Client struct {
cfg config.VisionConfig
mainOA config.OpenAIConfig
}
// NewClient 构造视觉客户端。
func NewClient(visionCfg config.VisionConfig, mainOpenAI config.OpenAIConfig) *Client {
return &Client{cfg: visionCfg, mainOA: mainOpenAI}
}
// Analyze 将图片字节送入 VL 模型并返回文本描述。
func (c *Client) Analyze(ctx context.Context, img ImagePayload, question string) (string, error) {
if len(img.Bytes) == 0 {
return "", fmt.Errorf("empty image payload")
}
mime := strings.TrimSpace(img.MIMEType)
if mime == "" {
mime = "image/jpeg"
}
oa := c.cfg.OpenAICfgEffective(c.mainOA)
if strings.TrimSpace(oa.APIKey) == "" {
return "", fmt.Errorf("vision API key is empty (set vision.api_key or openai.api_key)")
}
if strings.TrimSpace(oa.Model) == "" {
return "", fmt.Errorf("vision model is empty")
}
timeout := time.Duration(c.cfg.TimeoutSecondsEffective()) * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
httpClient := &http.Client{
Timeout: timeout + 15*time.Second,
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 60 * time.Second,
KeepAlive: 60 * time.Second,
}).DialContext,
ResponseHeaderTimeout: timeout + 10*time.Second,
},
}
httpClient = openai.NewEinoHTTPClient(&oa, httpClient)
modelCfg := &einoopenai.ChatModelConfig{
APIKey: oa.APIKey,
BaseURL: strings.TrimSuffix(oa.BaseURL, "/"),
Model: oa.Model,
HTTPClient: httpClient,
}
chatModel, err := einoopenai.NewChatModel(ctx, modelCfg)
if err != nil {
return "", fmt.Errorf("vision chat model: %w", err)
}
b64 := base64.StdEncoding.EncodeToString(img.Bytes)
detail := schema.ImageURLDetailLow
switch c.cfg.DetailEffective() {
case "high":
detail = schema.ImageURLDetailHigh
case "auto":
detail = schema.ImageURLDetailAuto
}
prompt := buildVisionPrompt(question)
userMsg := &schema.Message{
Role: schema.User,
UserInputMultiContent: []schema.MessageInputPart{
{Type: schema.ChatMessagePartTypeText, Text: prompt},
{
Type: schema.ChatMessagePartTypeImageURL,
Image: &schema.MessageInputImage{
MessagePartCommon: schema.MessagePartCommon{
Base64Data: &b64,
MIMEType: mime,
},
Detail: detail,
},
},
},
}
resp, err := chatModel.Generate(ctx, []*schema.Message{userMsg})
if err != nil {
return "", fmt.Errorf("vision generate: %w", err)
}
if resp == nil || strings.TrimSpace(resp.Content) == "" {
return "", fmt.Errorf("vision model returned empty content")
}
return strings.TrimSpace(resp.Content), nil
}
func buildVisionPrompt(question string) string {
q := strings.TrimSpace(question)
if q == "" {
q = "请对图片做通用描述,侧重授权安全测试场景(可见文本、表单、按钮、验证码、错误信息、技术栈线索)。"
}
extra := ""
if looksLikeCaptchaQuestion(q) {
extra = "\n若为验证码:仅输出你辨认出的字符序列,不要空格、标点、解释;看不清则明确说无法识别。"
}
return `你是授权安全测试助手请根据图片回答用户问题只描述你能从图中确认的内容不要编造
用户问题` + q + extra
}
func looksLikeCaptchaQuestion(q string) bool {
s := strings.ToLower(q)
for _, kw := range []string{"验证码", "captcha", "verification code", "verify code", "vcode", "图形码"} {
if strings.Contains(s, kw) {
return true
}
}
return strings.Contains(s, "只输出") && (strings.Contains(s, "字符") || strings.Contains(s, "character"))
}
+12
View File
@@ -0,0 +1,12 @@
package vision
import "testing"
func TestLooksLikeCaptchaQuestion(t *testing.T) {
if !looksLikeCaptchaQuestion("识别验证码,只输出字符") {
t.Fatal("expected captcha hint")
}
if looksLikeCaptchaQuestion("描述登录页布局") {
t.Fatal("expected non-captcha")
}
}
+72
View File
@@ -0,0 +1,72 @@
package vision
import (
"fmt"
"os"
"path/filepath"
"strings"
)
var allowedImageExt = map[string]struct{}{
".png": {}, ".jpg": {}, ".jpeg": {}, ".webp": {}, ".gif": {},
".bmp": {}, ".tif": {}, ".tiff": {},
}
// ResolveImagePath 解析并校验可读图片路径(支持任意目录;仍校验扩展名与常规文件)。
func ResolveImagePath(path string, cwd string) (string, error) {
p := strings.TrimSpace(path)
if p == "" {
return "", fmt.Errorf("path is empty")
}
cwdTrim := strings.TrimSpace(cwd)
if cwdTrim == "" {
var err error
cwdTrim, err = os.Getwd()
if err != nil {
return "", fmt.Errorf("getwd: %w", err)
}
}
cwdAbs, err := filepath.Abs(filepath.Clean(cwdTrim))
if err != nil {
return "", err
}
var candidate string
if filepath.IsAbs(p) {
candidate = filepath.Clean(p)
} else {
candidate = filepath.Clean(filepath.Join(cwdAbs, p))
}
resolved := normalizeAbsPath(candidate)
if resolved == "" {
return "", fmt.Errorf("invalid path")
}
ext := strings.ToLower(filepath.Ext(resolved))
if _, ok := allowedImageExt[ext]; !ok {
return "", fmt.Errorf("unsupported image extension %q", ext)
}
st, err := os.Stat(resolved)
if err != nil {
return "", fmt.Errorf("stat: %w", err)
}
if st.IsDir() {
return "", fmt.Errorf("not a regular file")
}
if st.Size() > 0 && st.Size() > 1<<30 {
return "", fmt.Errorf("file too large on disk")
}
return resolved, nil
}
func normalizeAbsPath(p string) string {
abs, err := filepath.Abs(filepath.Clean(p))
if err != nil {
return ""
}
if link, err := filepath.EvalSymlinks(abs); err == nil {
return link
}
return abs
}
+52
View File
@@ -0,0 +1,52 @@
package vision
import (
"os"
"path/filepath"
"testing"
)
func TestResolveImagePath_underCWD(t *testing.T) {
dir := t.TempDir()
img := filepath.Join(dir, "shot.png")
if err := os.WriteFile(img, []byte{0x89, 0x50, 0x4e, 0x47}, 0o644); err != nil {
t.Fatal(err)
}
got, err := ResolveImagePath(img, dir)
if err != nil {
t.Fatal(err)
}
want := normalizeAbsPath(img)
if got != want {
t.Fatalf("got %q want %q", got, want)
}
}
func TestResolveImagePath_absoluteOutsideCWD(t *testing.T) {
dir := t.TempDir()
cwd := t.TempDir()
img := filepath.Join(dir, "remote.png")
if err := os.WriteFile(img, []byte{0x89, 0x50, 0x4e, 0x47}, 0o644); err != nil {
t.Fatal(err)
}
got, err := ResolveImagePath(img, cwd)
if err != nil {
t.Fatalf("expected absolute path outside cwd to be allowed: %v", err)
}
want := normalizeAbsPath(img)
if got != want {
t.Fatalf("got %q want %q", got, want)
}
}
func TestResolveImagePath_rejectsNonImageExt(t *testing.T) {
dir := t.TempDir()
f := filepath.Join(dir, "notes.txt")
if err := os.WriteFile(f, []byte("x"), 0o644); err != nil {
t.Fatal(err)
}
_, err := ResolveImagePath(f, dir)
if err == nil {
t.Fatal("expected error for non-image extension")
}
}
+212
View File
@@ -0,0 +1,212 @@
package vision
import (
"bytes"
"fmt"
"image"
"os"
"strings"
"github.com/disintegration/imaging"
)
// ImagePayload 送入 VL API 的图片字节与 MIME。
type ImagePayload struct {
Bytes []byte
MIMEType string
}
// PreprocessMeta 记录缩放与编码结果,供工具输出与排障。
type PreprocessMeta struct {
OriginalPath string
OriginalBytes int64
OriginalWidth int
OriginalHeight int
OutputWidth int
OutputHeight int
OutputBytes int
OutputMIMEType string
JPEGQuality int // 0 表示未 JPEG 重编码(原图直传)
PreprocessMode string // passthrough | jpeg
}
// PreprocessOptions 图片预处理参数。
type PreprocessOptions struct {
MaxImageBytes int64
MaxDimension int
JPEGQuality int
MaxPayloadBytes int64
SkipPreprocessBelowBytes int64 // 0 = 始终压缩;>0 时小图+尺寸合规可直传
}
// PreprocessImageFile 读取图片;大图或超尺寸走 imaging 缩放+JPEG,否则可原图直传。
func PreprocessImageFile(path string, opt PreprocessOptions) (ImagePayload, PreprocessMeta, error) {
var meta PreprocessMeta
meta.OriginalPath = path
st, err := os.Stat(path)
if err != nil {
return ImagePayload{}, meta, err
}
meta.OriginalBytes = st.Size()
if opt.MaxImageBytes > 0 && st.Size() > opt.MaxImageBytes {
return ImagePayload{}, meta, fmt.Errorf("file size %d exceeds max_image_bytes %d", st.Size(), opt.MaxImageBytes)
}
cfgW, cfgH, format, err := imageDimensions(path)
if err != nil {
return ImagePayload{}, meta, err
}
meta.OriginalWidth = cfgW
meta.OriginalHeight = cfgH
maxDim := opt.MaxDimension
if maxDim <= 0 {
maxDim = 2048
}
maxPayload := opt.MaxPayloadBytes
if maxPayload <= 0 {
maxPayload = 512 * 1024
}
if payload, meta, ok, err := tryPassthrough(path, st.Size(), cfgW, cfgH, format, opt, maxDim, maxPayload); ok {
return payload, meta, err
}
return compressWithImaging(path, opt, maxDim, maxPayload, meta)
}
func tryPassthrough(path string, size int64, w, h int, format string, opt PreprocessOptions, maxDim int, maxPayload int64) (ImagePayload, PreprocessMeta, bool, error) {
var meta PreprocessMeta
meta.OriginalPath = path
meta.OriginalBytes = size
meta.OriginalWidth = w
meta.OriginalHeight = h
threshold := opt.SkipPreprocessBelowBytes
if threshold <= 0 {
return ImagePayload{}, meta, false, nil
}
if size > threshold {
return ImagePayload{}, meta, false, nil
}
longEdge := w
if h > longEdge {
longEdge = h
}
if longEdge > maxDim {
return ImagePayload{}, meta, false, nil
}
if size > maxPayload {
return ImagePayload{}, meta, false, nil
}
raw, err := os.ReadFile(path)
if err != nil {
return ImagePayload{}, meta, false, err
}
mime := mimeFromImageFormat(format)
if mime == "" {
return ImagePayload{}, meta, false, nil
}
meta.OutputWidth = w
meta.OutputHeight = h
meta.OutputBytes = len(raw)
meta.OutputMIMEType = mime
meta.PreprocessMode = "passthrough"
return ImagePayload{Bytes: raw, MIMEType: mime}, meta, true, nil
}
func compressWithImaging(path string, opt PreprocessOptions, maxDim int, maxPayload int64, meta PreprocessMeta) (ImagePayload, PreprocessMeta, error) {
src, err := imaging.Open(path)
if err != nil {
return ImagePayload{}, meta, fmt.Errorf("open image: %w", err)
}
bounds := src.Bounds()
meta.OriginalWidth = bounds.Dx()
meta.OriginalHeight = bounds.Dy()
dst := imaging.Fit(src, maxDim, maxDim, imaging.Lanczos)
outBounds := dst.Bounds()
meta.OutputWidth = outBounds.Dx()
meta.OutputHeight = outBounds.Dy()
quality := opt.JPEGQuality
if quality <= 0 || quality > 100 {
quality = 82
}
dim := maxDim
for attempt := 0; attempt < 6; attempt++ {
if attempt > 0 {
dim = int(float64(dim) * 0.85)
if dim < 256 {
dim = 256
}
dst = imaging.Fit(src, dim, dim, imaging.Lanczos)
outBounds = dst.Bounds()
meta.OutputWidth = outBounds.Dx()
meta.OutputHeight = outBounds.Dy()
}
q := quality
for q >= 60 {
var buf bytes.Buffer
if err := imaging.Encode(&buf, dst, imaging.JPEG, imaging.JPEGQuality(q)); err != nil {
return ImagePayload{}, meta, fmt.Errorf("encode jpeg: %w", err)
}
if int64(buf.Len()) <= maxPayload {
meta.JPEGQuality = q
meta.OutputBytes = buf.Len()
meta.OutputMIMEType = "image/jpeg"
meta.PreprocessMode = "jpeg"
return ImagePayload{Bytes: buf.Bytes(), MIMEType: "image/jpeg"}, meta, nil
}
q -= 5
}
quality = 75
}
return ImagePayload{}, meta, fmt.Errorf("could not compress image under max_payload_bytes %d", maxPayload)
}
func imageDimensions(path string) (w, h int, format string, err error) {
f, err := os.Open(path)
if err != nil {
return 0, 0, "", err
}
defer f.Close()
cfg, format, err := image.DecodeConfig(f)
if err != nil {
return 0, 0, "", fmt.Errorf("decode image config: %w", err)
}
return cfg.Width, cfg.Height, format, nil
}
func mimeFromImageFormat(format string) string {
switch strings.ToLower(strings.TrimSpace(format)) {
case "jpeg", "jpg":
return "image/jpeg"
case "png":
return "image/png"
case "gif":
return "image/gif"
case "webp":
return "image/webp"
case "bmp":
return "image/bmp"
case "tiff":
return "image/tiff"
default:
return ""
}
}
// DecodeImageConfig 用于测试:确认文件可被解码。
func DecodeImageConfig(path string) (image.Config, string, error) {
f, err := os.Open(path)
if err != nil {
return image.Config{}, "", err
}
defer f.Close()
return image.DecodeConfig(f)
}
+109
View File
@@ -0,0 +1,109 @@
package vision
import (
"image"
"image/color"
"image/png"
"os"
"path/filepath"
"testing"
"github.com/disintegration/imaging"
)
func TestPreprocessImageFile_scalesAndLimitsPayload(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "big.png")
img := imaging.New(3000, 2000, color.White)
if err := imaging.Save(img, path); err != nil {
t.Fatal(err)
}
out, meta, err := PreprocessImageFile(path, PreprocessOptions{
MaxImageBytes: 10 * 1024 * 1024,
MaxDimension: 1024,
JPEGQuality: 85,
MaxPayloadBytes: 600 * 1024,
SkipPreprocessBelowBytes: 0,
})
if err != nil {
t.Fatal(err)
}
if len(out.Bytes) == 0 {
t.Fatal("empty output")
}
if meta.PreprocessMode != "jpeg" {
t.Fatalf("mode: %s", meta.PreprocessMode)
}
if meta.OutputWidth > 1024 || meta.OutputHeight > 1024 {
t.Fatalf("expected fit within 1024, got %dx%d", meta.OutputWidth, meta.OutputHeight)
}
if int64(len(out.Bytes)) > 600*1024 {
t.Fatalf("payload %d exceeds max", len(out.Bytes))
}
}
func TestPreprocessImageFile_passthroughSmallPNG(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "small.png")
if err := imaging.Save(imaging.New(400, 300, color.White), path); err != nil {
t.Fatal(err)
}
out, meta, err := PreprocessImageFile(path, PreprocessOptions{
MaxImageBytes: 5 * 1024 * 1024,
MaxDimension: 2048,
MaxPayloadBytes: 512 * 1024,
SkipPreprocessBelowBytes: 2 * 1024 * 1024,
})
if err != nil {
t.Fatal(err)
}
if meta.PreprocessMode != "passthrough" {
t.Fatalf("expected passthrough, got %s", meta.PreprocessMode)
}
if out.MIMEType != "image/png" {
t.Fatalf("mime: %s", out.MIMEType)
}
if meta.OutputWidth != 400 || meta.OutputHeight != 300 {
t.Fatalf("dims: %dx%d", meta.OutputWidth, meta.OutputHeight)
}
}
func TestPreprocessImageFile_passthroughDisabled(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "small.png")
if err := imaging.Save(imaging.New(100, 100, color.White), path); err != nil {
t.Fatal(err)
}
_, meta, err := PreprocessImageFile(path, PreprocessOptions{
MaxDimension: 2048,
MaxPayloadBytes: 512 * 1024,
SkipPreprocessBelowBytes: 0,
})
if err != nil {
t.Fatal(err)
}
if meta.PreprocessMode != "jpeg" {
t.Fatalf("expected jpeg compress, got %s", meta.PreprocessMode)
}
}
func TestPreprocessImageFile_rejectsOversizeFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "tiny.png")
f, err := os.Create(path)
if err != nil {
t.Fatal(err)
}
if err := png.Encode(f, image.NewRGBA(image.Rect(0, 0, 2, 2))); err != nil {
t.Fatal(err)
}
f.Close()
_, _, err = PreprocessImageFile(path, PreprocessOptions{MaxImageBytes: 1})
if err == nil {
t.Fatal("expected error when file exceeds max_image_bytes")
}
}
+125
View File
@@ -0,0 +1,125 @@
package vision
import (
"context"
"fmt"
"os"
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterAnalyzeImageTool 在 vision.enabled 且 model 已配置时注册 MCP 工具 analyze_image。
func RegisterAnalyzeImageTool(mcpServer *mcp.Server, cfg *config.Config, logger *zap.Logger) {
if mcpServer == nil || cfg == nil {
return
}
if !cfg.Vision.Ready() {
if cfg.Vision.Enabled && logger != nil {
logger.Warn("vision.enabled 但 vision.model 为空,跳过注册 analyze_image")
}
return
}
cwd, err := os.Getwd()
if err != nil {
if logger != nil {
logger.Warn("vision: getwd failed, skip analyze_image", zap.Error(err))
}
return
}
preOpt := PreprocessOptions{
MaxImageBytes: cfg.Vision.MaxImageBytesEffective(),
MaxDimension: cfg.Vision.MaxDimensionEffective(),
JPEGQuality: cfg.Vision.JPEGQualityEffective(),
MaxPayloadBytes: cfg.Vision.MaxPayloadBytesEffective(),
SkipPreprocessBelowBytes: cfg.Vision.SkipPreprocessBelowBytesEffective(),
}
client := NewClient(cfg.Vision, cfg.OpenAI)
tool := mcp.Tool{
Name: builtin.ToolAnalyzeImage,
Description: "分析服务器上的本地图片并返回文字描述(验证码、UI 元素、报错、架构图要点等)。" +
"输入为文件路径(如用户上传的 chat_uploads 路径或工具截图路径)。" +
"输出仅为文本,不含图片数据。不要对二进制图片使用 read_file 指望理解内容。",
ShortDescription: "分析本地图片并返回文字描述(验证码/UI/报错等)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"path": map[string]interface{}{
"type": "string",
"description": "图片绝对路径或相对于进程工作目录的路径",
},
"question": map[string]interface{}{
"type": "string",
"description": "可选:希望模型重点回答的问题。验证码图建议:只输出验证码字符,不要空格和解释",
},
},
"required": []string{"path"},
},
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
path, _ := args["path"].(string)
question, _ := args["question"].(string)
abs, err := ResolveImagePath(path, cwd)
if err != nil {
return textResult(fmt.Sprintf("路径校验失败: %v", err), true), nil
}
img, meta, err := PreprocessImageFile(abs, preOpt)
if err != nil {
return textResult(fmt.Sprintf("图片预处理失败: %v", err), true), nil
}
summary, err := client.Analyze(ctx, img, question)
if err != nil {
return textResult(fmt.Sprintf("视觉模型调用失败: %v", err), true), nil
}
body := formatAnalysisResult(abs, meta, summary)
return textResult(body, false), nil
}
mcpServer.RegisterTool(tool, handler)
if logger != nil {
logger.Info("vision: analyze_image 工具已注册", zap.String("model", cfg.Vision.Model))
}
}
func textResult(text string, isError bool) *mcp.ToolResult {
return &mcp.ToolResult{
Content: []mcp.Content{{Type: "text", Text: text}},
IsError: isError,
}
}
func formatAnalysisResult(path string, meta PreprocessMeta, summary string) string {
var b strings.Builder
b.WriteString("## Image analysis\n")
b.WriteString("- **path**: ")
b.WriteString(path)
b.WriteString("\n")
switch meta.PreprocessMode {
case "passthrough":
b.WriteString(fmt.Sprintf("- **preprocess**: passthrough %dx%d, %s, %dKB (original %dKB)\n\n",
meta.OutputWidth, meta.OutputHeight, meta.OutputMIMEType,
(meta.OutputBytes+1023)/1024, (meta.OriginalBytes+1023)/1024))
default:
b.WriteString(fmt.Sprintf("- **preprocess**: %dx%d → %dx%d, jpeg q=%d, %dKB (original %dKB)\n\n",
meta.OriginalWidth, meta.OriginalHeight,
meta.OutputWidth, meta.OutputHeight,
meta.JPEGQuality, (meta.OutputBytes+1023)/1024,
(meta.OriginalBytes+1023)/1024))
}
b.WriteString("### Summary\n")
b.WriteString(strings.TrimSpace(summary))
b.WriteString("\n")
return b.String()
}