mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-25 15:30:15 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
staleRunningMinAge = 45 * time.Second
|
||||
staleRunningReconcileGap = 2 * time.Minute
|
||||
)
|
||||
|
||||
// ExecutionReconciler 在启动或运行期将无对应协程的 running 执行记录收尾为 cancelled。
|
||||
type ExecutionReconciler struct {
|
||||
db *database.DB
|
||||
mcpServer *mcp.Server
|
||||
externalMgr *mcp.ExternalMCPManager
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewExecutionReconciler creates a reconciler for orphaned MCP tool executions.
|
||||
func NewExecutionReconciler(db *database.DB, mcpServer *mcp.Server, externalMgr *mcp.ExternalMCPManager, logger *zap.Logger) *ExecutionReconciler {
|
||||
return &ExecutionReconciler{
|
||||
db: db,
|
||||
mcpServer: mcpServer,
|
||||
externalMgr: externalMgr,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ReconcileOnStartup marks every persisted running row as cancelled (safe right after process start).
|
||||
func (r *ExecutionReconciler) ReconcileOnStartup() {
|
||||
if r == nil || r.db == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
n, err := r.db.CancelOrphanedRunningToolExecutions(now, "执行已中断(服务重启)")
|
||||
if err != nil {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("启动时清理孤儿 running 工具执行记录失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && r.logger != nil {
|
||||
r.logger.Info("启动时已收尾孤儿 running 工具执行记录", zap.Int64("count", n))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ExecutionReconciler) activeExecutionIDs() map[string]struct{} {
|
||||
ids := make(map[string]struct{})
|
||||
if r.mcpServer != nil {
|
||||
for id := range r.mcpServer.ActiveRunningExecutionIDs() {
|
||||
ids[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
if r.externalMgr != nil {
|
||||
for id := range r.externalMgr.ActiveRunningExecutionIDs() {
|
||||
ids[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// ReconcileStaleRunning finalizes running rows that are not tracked in-memory and older than staleRunningMinAge.
|
||||
func (r *ExecutionReconciler) ReconcileStaleRunning() {
|
||||
if r == nil || r.db == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
n, err := r.db.FinalizeStaleRunningToolExecutions(now, staleRunningMinAge, r.activeExecutionIDs(), "执行已中断(会话已结束)")
|
||||
if err != nil {
|
||||
if r.logger != nil {
|
||||
r.logger.Warn("定期收尾 stale running 工具执行记录失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && r.logger != nil {
|
||||
r.logger.Info("已收尾 stale running 工具执行记录", zap.Int64("count", n))
|
||||
}
|
||||
}
|
||||
|
||||
// StartStaleRunningReconcileLoop periodically reconciles orphaned running tool executions.
|
||||
func StartStaleRunningReconcileLoop(r *ExecutionReconciler, logger *zap.Logger) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(staleRunningReconcileGap)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
r.ReconcileStaleRunning()
|
||||
if logger != nil {
|
||||
logger.Debug("monitor stale running reconcile tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestExecutionReconciler_ReconcileOnStartup(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err := db.SaveToolExecution(&mcp.ToolExecution{
|
||||
ID: "run-1", ToolName: "hydra", Status: "running", StartTime: time.Now().Add(-time.Hour),
|
||||
}); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
r := NewExecutionReconciler(db, mcp.NewServer(zap.NewNop()), nil, zap.NewNop())
|
||||
r.ReconcileOnStartup()
|
||||
|
||||
got, err := db.GetToolExecution("run-1")
|
||||
if err != nil {
|
||||
t.Fatalf("GetToolExecution: %v", err)
|
||||
}
|
||||
if got.Status != "cancelled" {
|
||||
t.Fatalf("expected cancelled after startup reconcile, got %s", got.Status)
|
||||
}
|
||||
}
|
||||
@@ -150,6 +150,7 @@ func newEinoSummarizationMiddleware(
|
||||
}
|
||||
if appCfg != nil {
|
||||
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
|
||||
out = refreshUserVerbatimAnchorInMessages(out, db, conversationID, appCfg.MultiAgent.UserVerbatimAnchorMaxRunesEffective(), logger)
|
||||
}
|
||||
return out, nil
|
||||
},
|
||||
@@ -413,6 +414,36 @@ func writeSummarizationTranscript(path string, msgs []adk.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshUserVerbatimAnchorInMessages 压缩后从 messages 表刷新 system 中的用户原文锚点。
|
||||
func refreshUserVerbatimAnchorInMessages(msgs []adk.Message, db *database.DB, conversationID string, maxRunes int, logger *zap.Logger) []adk.Message {
|
||||
if maxRunes < 0 || db == nil {
|
||||
return msgs
|
||||
}
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return msgs
|
||||
}
|
||||
rows, err := db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("summarization: 刷新用户原文锚点失败",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
block := project.BuildUserVerbatimAnchorBlockFromMessages(rows, maxRunes)
|
||||
if block == "" {
|
||||
return msgs
|
||||
}
|
||||
out := project.RefreshUserVerbatimAnchorInMessages(msgs, block)
|
||||
if logger != nil {
|
||||
logger.Info("summarization: 已刷新用户原文锚点", zap.String("conversationId", conversationID))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||
tc := agent.NewTikTokenCounter()
|
||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||
|
||||
@@ -372,8 +372,15 @@ func RunDeepAgent(
|
||||
|
||||
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
||||
deepHandlers := []adk.ChatModelAgentMiddleware{newNoNestedTaskMiddleware()}
|
||||
taskEnrichExtra := systemPromptExtra
|
||||
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunes, taskEnrichExtra); mw != nil {
|
||||
var taskBlackboardSupplement string
|
||||
if appCfg != nil && appCfg.Project.Enabled && db != nil {
|
||||
if pid := strings.TrimSpace(projectID); pid != "" {
|
||||
if block, err := project.BuildFactIndexBlock(db, pid, appCfg.Project); err == nil {
|
||||
taskBlackboardSupplement = strings.TrimSpace(block)
|
||||
}
|
||||
}
|
||||
}
|
||||
if mw := newTaskContextEnrichMiddleware(userMessage, history, ma.SubAgentUserContextMaxRunesEffective(), taskBlackboardSupplement); mw != nil {
|
||||
deepHandlers = append(deepHandlers, mw)
|
||||
}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
|
||||
@@ -3,6 +3,7 @@ package multiagent
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
@@ -11,7 +12,7 @@ import (
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
)
|
||||
|
||||
const defaultSubAgentUserContextMaxRunes = 2000
|
||||
const userContextSupplementHeader = "\n\n## 用户历史输入(原文,子代理必读)\n"
|
||||
|
||||
// taskContextEnrichMiddleware intercepts "task" tool calls on the orchestrator
|
||||
// and appends the user's original conversation messages to the task description.
|
||||
@@ -30,13 +31,14 @@ type taskContextEnrichMiddleware struct {
|
||||
// newTaskContextEnrichMiddleware returns a middleware that enriches task
|
||||
// descriptions with user conversation context. Returns nil if disabled
|
||||
// (maxRunes < 0) or no user messages exist.
|
||||
// projectBlackboard 仅传项目黑板索引块(BuildFactIndexBlock);勿传完整 systemPromptExtra。
|
||||
func newTaskContextEnrichMiddleware(userMessage string, history []agent.ChatMessage, maxRunes int, projectBlackboard string) adk.ChatModelAgentMiddleware {
|
||||
supplement := buildUserContextSupplement(userMessage, history, maxRunes)
|
||||
if bb := strings.TrimSpace(projectBlackboard); bb != "" {
|
||||
if supplement != "" {
|
||||
supplement += "\n\n## 项目黑板索引\n" + bb
|
||||
supplement += "\n\n" + bb
|
||||
} else {
|
||||
supplement = "\n\n## 项目黑板索引\n" + bb
|
||||
supplement = "\n\n" + bb
|
||||
}
|
||||
}
|
||||
if supplement == "" {
|
||||
@@ -86,9 +88,6 @@ func buildUserContextSupplement(userMessage string, history []agent.ChatMessage,
|
||||
if maxRunes < 0 {
|
||||
return ""
|
||||
}
|
||||
if maxRunes == 0 {
|
||||
maxRunes = defaultSubAgentUserContextMaxRunes
|
||||
}
|
||||
|
||||
var userMsgs []string
|
||||
for _, h := range history {
|
||||
@@ -107,12 +106,16 @@ func buildUserContextSupplement(userMessage string, history []agent.ChatMessage,
|
||||
return ""
|
||||
}
|
||||
|
||||
joined := strings.Join(userMsgs, "\n---\n")
|
||||
if len([]rune(joined)) > maxRunes {
|
||||
lines := make([]string, 0, len(userMsgs))
|
||||
for i, msg := range userMsgs {
|
||||
lines = append(lines, fmt.Sprintf("[第%d轮] %s", i+1, msg))
|
||||
}
|
||||
joined := strings.Join(lines, "\n")
|
||||
if maxRunes > 0 && len([]rune(joined)) > maxRunes {
|
||||
joined = truncateKeepFirstLast(userMsgs, maxRunes)
|
||||
}
|
||||
|
||||
return "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n" + joined
|
||||
return userContextSupplementHeader + joined
|
||||
}
|
||||
|
||||
// truncateKeepFirstLast keeps the first and last user messages, giving each
|
||||
|
||||
@@ -74,7 +74,7 @@ func TestBuildUserContextSupplement_DisabledByNegative(t *testing.T) {
|
||||
func TestBuildUserContextSupplement_CustomMaxRunes(t *testing.T) {
|
||||
msg := strings.Repeat("A", 200)
|
||||
result := buildUserContextSupplement(msg, nil, 50)
|
||||
header := "\n\n## 会话上下文(自动补充,确保你了解用户完整意图)\n"
|
||||
header := userContextSupplementHeader
|
||||
body := strings.TrimPrefix(result, header)
|
||||
if len([]rune(body)) > 50 {
|
||||
t.Errorf("body should be capped at 50 runes, got %d", len([]rune(body)))
|
||||
@@ -89,7 +89,7 @@ func TestBuildUserContextSupplement_TruncateKeepsFirstAndLast(t *testing.T) {
|
||||
history = append(history, agent.ChatMessage{Role: "user", Content: strings.Repeat("B", 500)})
|
||||
}
|
||||
last := "最后一条指令"
|
||||
result := buildUserContextSupplement(last, history, 0)
|
||||
result := buildUserContextSupplement(last, history, 800)
|
||||
if !strings.Contains(result, "http://target.com") {
|
||||
t.Error("first message (target URL) should survive truncation")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user